Skip to content

Commit eb435dd

Browse files
authored
Merge pull request #316 from ACEsuit/co/etpair
ET Pair Potential Model
2 parents 4277cc8 + efa4d48 commit eb435dd

File tree

8 files changed

+424
-19
lines changed

8 files changed

+424
-19
lines changed

src/et_models/convert.jl

Lines changed: 108 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import EquivariantTensors as ET
66
import Polynomials4ML as P4ML
77

88
import ACEpotentials.Models: LearnableRnlrzzBasis, PolyEnvelope2sX,
9-
_i2z, GeneralizedAgnesiTransform
9+
_i2z, GeneralizedAgnesiTransform, PolyEnvelope1sR
1010

1111
using LinearAlgebra: norm, dot
1212

@@ -68,6 +68,7 @@ function convert2et(model)
6868
end
6969

7070

71+
7172
# In ET we currently store an edge xij as a NamedTuple, e.g,
7273
# xij = (𝐫ij = ..., zi = ..., zj = ...)
7374
# The NTtransform is a wrapper for mapping xij -> y
@@ -85,15 +86,6 @@ function _convert_Rnl_learnable(basis; zlist = ChemicalSpecies.(basis._i2z),
8586
# number of species
8687
NZ = length(zlist)
8788

88-
# species z -> index i mapping
89-
__z2i = let _i2z = (_i2z = zlist,)
90-
z -> _z2i(_i2z, z)
91-
end
92-
93-
# __zz2i maps a `(Zi, Zj)` pair to a single index `a` representing
94-
# (Zi, Zj) in a flattened array
95-
__zz2ii = (zi, zj) -> (__z2i(zi) - 1) * NZ + __z2i(zj)
96-
9789
selector = let zlist = tuple(zlist...)
9890
xij -> ET.catcat2idx(zlist, xij.z0, xij.z1)
9991
end
@@ -124,8 +116,8 @@ function _convert_Rnl_learnable(basis; zlist = ChemicalSpecies.(basis._i2z),
124116
@assert env.x1 == -1
125117
@assert env.x2 == 1
126118
end
127-
128119
et_env = y -> (1 - y^2)^2
120+
# et_env = _convert_envelope(basis.envelopes)
129121

130122
# the polynomial basis just stays the same
131123
# but needs to be wrapped due to the envelope being applied
@@ -205,4 +197,108 @@ function _convert_agnesi(rbasis::LearnableRnlrzzBasis)
205197
end
206198

207199
return ET.NTtransformST(f_agnesi, st)
208-
end
200+
end
201+
202+
203+
function _convert_envelope(envelopes)
204+
TENV = typeof(envelopes[1])
205+
for env in envelopes
206+
@assert typeof(env) == TENV
207+
end
208+
209+
@show TENV
210+
return _convert_env_TENV(TENV, envelopes)
211+
end
212+
213+
function _convert_env_TENV(::Type{<: PolyEnvelope2sX}, envelopes)
214+
for env in envelopes
215+
@assert env isa PolyEnvelope2sX
216+
@assert env.p1 == env.p2 == 2
217+
@assert env.x1 == -1
218+
@assert env.x2 == 1
219+
end
220+
return y -> (1 - y^2)^2
221+
end
222+
223+
function _convert_env_TENV(::Type{<: PolyEnvelope1sR}, envelopes)
224+
env1 = envelopes[1]
225+
for env in envelopes
226+
@assert env == env1
227+
end
228+
f_env = (r, st) -> _eval_env_1sr(r, st.rcut, st.p)
229+
refst = ( rcut = env1.rcut, p = env1.p )
230+
return ET.st_transform(f_env, refst)
231+
end
232+
233+
function _eval_env_1sr(r, rcut, p)
234+
_1 = one(r)
235+
s = r / rcut
236+
return (s^(-p) - _1) * (_1 - s) * (s < _1)
237+
end
238+
239+
function _convert_pair_envelope(envelopes)
240+
TENV = typeof(envelopes[1])
241+
for env in envelopes
242+
@assert typeof(env) == TENV
243+
end
244+
env1 = envelopes[1]
245+
@assert env1 isa PolyEnvelope1sR
246+
for env in envelopes
247+
@assert env == env1
248+
end
249+
refst = ( rcut = env1.rcut, p = env1.p )
250+
f_env = ET.dp_transform( (x, st) -> _eval_env_1sr( norm(x.𝐫), st.rcut, st.p ),
251+
refst )
252+
return f_env
253+
end
254+
255+
256+
257+
function convertpair(model)
258+
259+
# extract radial basis information
260+
basis = model.pairbasis
261+
zlist = ChemicalSpecies.(basis._i2z)
262+
NZ = length(zlist)
263+
264+
# this construction is a little different from the Rnl basis for the
265+
# many-body model because the envelope takes a different input
266+
# and this makes life a little more complicated.
267+
268+
# 1: polynomials without the envelope
269+
#
270+
dp_agnesi = _convert_agnesi(basis)
271+
polys = basis.polys
272+
selector2 = let zlist = zlist
273+
xij -> ET.catcat2idx(zlist, xij.z0, xij.z1)
274+
end
275+
et_linl = ET.SelectLinL(length(polys), # indim
276+
length(basis), # outdim
277+
NZ^2, # num (Zi,Zj) pairs
278+
selector2)
279+
rbasis_1 = ET.EmbedDP(dp_agnesi, polys, et_linl)
280+
281+
# 2: envelope
282+
dp_envelope = _convert_pair_envelope(basis.envelopes)
283+
# _env_r = _convert_envelope(basis.envelopes)
284+
# dp_envelope = ET.dp_transform( (x, st) -> _env_r.f( norm(x.𝐫), st ),
285+
# _env_r.refstate )
286+
287+
# 3. combine into the radial basis
288+
rembed = ET.EdgeEmbed( EnvRBranchL(dp_envelope, rbasis_1) )
289+
290+
# 4. rembed provides the radial basis for the pair model, now we just
291+
# need the readout layer which is similar to before.
292+
selector1 = let zlist = zlist
293+
x -> ET.cat2idx(zlist, x.z)
294+
end
295+
readout = ET.SelectLinL(
296+
length(basis),
297+
1, # output dim (only one site energy per atom)
298+
NZ, # number of categories = num species
299+
selector1)
300+
301+
et_pair = ETPairModel(rembed, readout)
302+
303+
return et_pair
304+
end

src/et_models/et_envbranch.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
2+
3+
using ConcreteStructs
4+
import Polynomials4ML: evaluate, evaluate_ed
5+
import LuxCore: AbstractLuxContainerLayer
6+
import ChainRulesCore: NoTangent, rrule, unthunk
7+
8+
"""
9+
struct EnvRBranchL
10+
11+
An auxiliary layer that is basically a branch layer needed to build
12+
radial bases, with additional evaluate_ed functionality, needed for
13+
Jacobians.
14+
"""
15+
@concrete struct EnvRBranchL <: AbstractLuxContainerLayer{(:envelope, :rbasis)}
16+
envelope
17+
rbasis
18+
end
19+
20+
(l::EnvRBranchL)(X, ps, st) = _apply_envrbranchl(l, X, ps, st), st
21+
22+
evaluate(l::EnvRBranchL, X, ps, st) = l(X, ps, st)
23+
24+
function _apply_envrbranchl(l::EnvRBranchL, X, ps, st)
25+
ee, _ = l.envelope(X, ps.envelope, st.envelope)
26+
P, _ = l.rbasis(X, ps.rbasis, st.rbasis)
27+
return ee .* P
28+
end
29+
30+
function evaluate_ed(l::EnvRBranchL, X, ps, st)
31+
(ee, d_ee), _ = evaluate_ed(l.envelope, X, ps.envelope, st.envelope)
32+
(P, d_P), _ = evaluate_ed(l.rbasis,X, ps.rbasis, st.rbasis)
33+
34+
# product rule
35+
pP = ee .* P
36+
∂_pP = d_ee .* P .+ ee .* d_P
37+
38+
return (pP, ∂_pP), st
39+
end
40+
41+
function rrule(::typeof(_apply_envrbranchl),
42+
l::EnvRBranchL, X, ps, st)
43+
44+
(P, dP), st = evaluate_ed(l, X, ps, st)
45+
46+
function _pb_embeddp(_∂P)
47+
∂P = unthunk(_∂P)
48+
∂X = dropdims( sum(∂P .* dP, dims = 2), dims = 2)
49+
return NoTangent(), NoTangent(), ∂X, NoTangent(), NoTangent()
50+
end
51+
52+
return P, _pb_embeddp
53+
end
54+

src/et_models/et_models.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11

22
module ETModels
33

4+
# utility layers : these should likely be moved into ET or be removed
5+
# if more convenient implementations can be found.
6+
#
7+
include("et_envbranch.jl")
8+
9+
# ET based ACE model components
410
include("et_ace.jl")
511
include("onebody.jl")
12+
include("et_pair.jl")
613

14+
# converstion utilities: convert from 0.8 style ACE models to ET based models
715
include("convert.jl")
816

917

src/et_models/et_pair.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#
2+
# This is a temporary model implementation needed due to the fact that
3+
# ETACEModel has Rnl, Ylm hard-coded. In the future it could be tested
4+
# whether the pair model could simply be taken as another ACE model
5+
# with a single embedding rather than several, This would need generalization
6+
# of a fair few methods in both ACEpotentials and EquivariantTensors.
7+
#
8+
9+
10+
import EquivariantTensors as ET
11+
import Zygote
12+
import LuxCore: AbstractLuxContainerLayer
13+
using ConcreteStructs: @concrete
14+
15+
16+
@concrete struct ETPairModel <: AbstractLuxContainerLayer{(:rembed, :readout)}
17+
rembed # radial embedding layer = basis
18+
readout # normally a selectlinl readout layer
19+
end
20+
21+
22+
(l::ETPairModel)(X::ET.ETGraph, ps, st) = _apply_etpairmodel(l, X, ps, st), st
23+
24+
25+
function _apply_etpairmodel(l::ETPairModel, X::ET.ETGraph, ps, st)
26+
# evaluate the basis
27+
𝔹 = site_basis(l, X, ps, st)
28+
29+
# readout layer
30+
φ, _ = l.readout((𝔹, X.node_data), ps.readout, st.readout)
31+
32+
return φ
33+
end
34+
35+
# -----------------------------------------------------------
36+
37+
38+
function site_grads(l::ETPairModel, X::ET.ETGraph, ps, st)
39+
∂X = Zygote.gradient( X -> sum(_apply_etpairmodel(l, X, ps, st)), X)[1]
40+
return ∂X
41+
end
42+
43+
44+
# -----------------------------------------------------------
45+
# basis and jacobian evaluation
46+
47+
48+
function site_basis(l::ETPairModel, X::ET.ETGraph, ps, st)
49+
# embed edges
50+
Rnl, _ = l.rembed(X, ps.rembed, st.rembed)
51+
52+
# the basis is obtain by summing over the neighbours of each node,
53+
# which is just a sum over the first dimension of Rnl
54+
𝔹 = dropdims(sum(Rnl, dims=1), dims=1)
55+
56+
return 𝔹
57+
end
58+
59+
60+
function site_basis_jacobian(l::ETPairModel, X::ET.ETGraph, ps, st)
61+
(R, ∂R), _ = ET.evaluate_ed(l.rembed, X, ps.rembed, st.rembed)
62+
𝔹 = dropdims(sum(R, dims=1), dims=1)
63+
# ∂𝔹 == ∂R
64+
return 𝔹, ∂R
65+
end
66+

test/etmodels/test_etbackend.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
# using Pkg; Pkg.activate(joinpath(@__DIR__(), ".."))
1+
# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", ".."))
22
# using TestEnv; TestEnv.activate();
3-
# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "EquivariantTensors.jl"))
3+
# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "EquivariantTensors.jl"))
44
# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "Polynomials4ML.jl"))
55
# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "DecoratedParticles"))
66

@@ -190,7 +190,7 @@ WW = et_ps_2.readout.W
190190
println_slim(@test 𝔹1 𝔹2)
191191
Ei_a = [ dot(𝔹2[i, :], WW[1, :, iZ[i]]) for (i, iz) in enumerate(iZ) ]
192192
Ei_b = et_model_2(G, et_ps_2, et_st_2)[1][:]
193-
println(@test Ei_a Ei_b)
193+
println_slim(@test Ei_a Ei_b)
194194

195195
##
196196

@@ -240,8 +240,6 @@ println_slim( @test abs(E1 - E4) / (abs(E1) + abs(E4) + 1e-7) < 1e-5 )
240240
241241
##
242242
# gradients on GPU
243-
# currently failing because somehow the transform is still
244-
# accessing some Float64 values somewhere ....
245243
246244
@info("Check Evaluation of gradient on GPU")
247245
g1 = ETM.site_grads(et_model_2, G_32, ps_32_2, st_32_2)
@@ -268,6 +266,8 @@ println_slim( @test 𝔹1 ≈ 𝔹2 )
268266
269267
println_slim( @test 𝔹1 ≈ 𝔹2 )
270268
err_jac = norm.(∂𝔹1 - ∂𝔹2) ./ (norm.(∂𝔹1) + norm.(∂𝔹2) .+ 0.1)
271-
println_slim( @test maximum(err_jac) < 1e-5 )
269+
println_slim( @test maximum(err_jac) < 1e-4 )
270+
@show maximum(err_jac)
271+
@info("The jacobian error feels a bit large. This may need further investigation.")
272272
273-
=#
273+
=#

test/etmodels/test_etonebody.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ G_dev = dev(G_32)
139139
E1, st = et_V0(G_32, ps_32, st_32)
140140
E2_dev, st_dev = et_V0(G_dev, ps_dev, st_dev)
141141
E2 = Array(E2_dev)
142+
# TODO: add E1 ≈ E2 test??
142143
143144
g1 = ETM.site_grads(et_V0, G_32, ps_32, st_32)
144145
g2_dev = ETM.site_grads(et_V0, G_dev, ps_dev, st_dev)

0 commit comments

Comments
 (0)