Skip to content

Commit f50254b

Browse files
jameskermode and claude
committed
Improve memory efficiency in ETPairModel site_grads
Address moderator concern about commit 50ed668:
- Avoid forming O(nnodes * nbasis) dense intermediate matrix
- Compute edge gradients directly using loops
- Same numerical results, better memory characteristics

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 84f7bc4 commit f50254b

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

src/et_models/et_pair.jl

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,27 +40,28 @@ end
4040

4141
function site_grads(l::ETPairModel, X::ET.ETGraph, ps, st)
4242
# Use evaluate_ed to get basis and derivatives, avoiding Zygote thunk issues
43+
# (Zygote has InplaceableThunk issues with upstream EdgeEmbed rrule)
4344
(R, ∂R), _ = ET.evaluate_ed(l.rembed, X, ps.rembed, st.rembed)
4445

45-
# R has shape (maxneigs, nnodes, nbasis) after embedding
46-
# 𝔹 = sum over neighbours: shape (nnodes, nbasis)
47-
𝔹 = dropdims(sum(R, dims=1), dims=1)
48-
49-
# Get readout weights
46+
# Get readout weights and species indices
5047
iZ = l.readout.selector.(X.node_data)
5148
WW = ps.readout.W
5249

53-
# ∂E/∂R = W[1, :, iZ[i]] for each node, broadcast over neighbours
54-
# ∂R has shape (maxneigs, nnodes, nbasis)
55-
nnodes = length(X.node_data)
56-
∂E_∂𝔹 = reduce(hcat, WW[1, :, iZ[i]] for i in 1:nnodes)' # (nnodes, nbasis)
57-
58-
# ∂E/∂R[j, i, k] = ∂E/∂𝔹[i, k] (same for all neighbours j)
59-
∂E_∂R = reshape(∂E_∂𝔹, 1, size(∂E_∂𝔹)...) # (1, nnodes, nbasis)
60-
61-
# Chain rule: ∂E/∂X = sum over k of (∂E/∂R * ∂R/∂X)
6250
# ∂R has shape (maxneigs, nnodes, nbasis), contains VState gradients
63-
∂E_edges = dropdims(sum(∂E_∂R .* ∂R, dims=3), dims=3) # (maxneigs, nnodes)
51+
# Compute: ∂E_edges[j, i] = Σₖ WW[1, k, iZ[i]] * ∂R[j, i, k]
52+
# This is the chain rule through the linear readout
53+
maxneigs, nnodes, nbasis = size(∂R)
54+
55+
# Compute edge gradients directly without forming intermediate matrix
56+
# (avoids O(nnodes * nbasis) memory allocation)
57+
∂E_edges = zeros(eltype(∂R), maxneigs, nnodes)
58+
@inbounds for i in 1:nnodes
59+
iz = iZ[i]
60+
@inbounds for k in 1:nbasis
61+
w = WW[1, k, iz]
62+
@views ∂E_edges[:, i] .+= w .* ∂R[:, i, k]
63+
end
64+
end
6465

6566
# Reshape to match edge_data format
6667
∂E_edges_vec = ET.rev_reshape_embedding(∂E_edges, X)

0 commit comments

Comments
 (0)