|
40 | 40 |
|
41 | 41 | function site_grads(l::ETPairModel, X::ET.ETGraph, ps, st) |
42 | 42 | # Use evaluate_ed to get basis and derivatives, avoiding Zygote thunk issues |
| 43 | + # (Zygote has InplaceableThunk issues with upstream EdgeEmbed rrule) |
43 | 44 | (R, ∂R), _ = ET.evaluate_ed(l.rembed, X, ps.rembed, st.rembed) |
44 | 45 |
|
45 | | - # R has shape (maxneigs, nnodes, nbasis) after embedding |
46 | | - # 𝔹 = sum over neighbours: shape (nnodes, nbasis) |
47 | | - 𝔹 = dropdims(sum(R, dims=1), dims=1) |
48 | | - |
49 | | - # Get readout weights |
| 46 | + # Get readout weights and species indices |
50 | 47 | iZ = l.readout.selector.(X.node_data) |
51 | 48 | WW = ps.readout.W |
52 | 49 |
|
53 | | - # ∂E/∂R = W[1, :, iZ[i]] for each node, broadcast over neighbours |
54 | | - # ∂R has shape (maxneigs, nnodes, nbasis) |
55 | | - nnodes = length(X.node_data) |
56 | | - ∂E_∂𝔹 = reduce(hcat, WW[1, :, iZ[i]] for i in 1:nnodes)' # (nnodes, nbasis) |
57 | | - |
58 | | - # ∂E/∂R[j, i, k] = ∂E/∂𝔹[i, k] (same for all neighbours j) |
59 | | - ∂E_∂R = reshape(∂E_∂𝔹, 1, size(∂E_∂𝔹)...) # (1, nnodes, nbasis) |
60 | | - |
61 | | - # Chain rule: ∂E/∂X = sum over k of (∂E/∂R * ∂R/∂X) |
62 | 50 | # ∂R has shape (maxneigs, nnodes, nbasis), contains VState gradients |
63 | | - ∂E_edges = dropdims(sum(∂E_∂R .* ∂R, dims=3), dims=3) # (maxneigs, nnodes) |
| 51 | + # Compute: ∂E_edges[j, i] = Σₖ WW[1, k, iZ[i]] * ∂R[j, i, k] |
| 52 | + # This is the chain rule through the linear readout |
| 53 | + maxneigs, nnodes, nbasis = size(∂R) |
| 54 | + |
| 55 | + # Compute edge gradients directly without forming intermediate matrix |
| 56 | + # (avoids O(nnodes * nbasis) memory allocation) |
| 57 | + ∂E_edges = zeros(eltype(∂R), maxneigs, nnodes) |
| 58 | + @inbounds for i in 1:nnodes |
| 59 | + iz = iZ[i] |
| 60 | + @inbounds for k in 1:nbasis |
| 61 | + w = WW[1, k, iz] |
| 62 | + @views ∂E_edges[:, i] .+= w .* ∂R[:, i, k] |
| 63 | + end |
| 64 | + end |
64 | 65 |
|
65 | 66 | # Reshape to match edge_data format |
66 | 67 | ∂E_edges_vec = ET.rev_reshape_embedding(∂E_edges, X) |
|
0 commit comments