@@ -6,7 +6,7 @@ import EquivariantTensors as ET
66import Polynomials4ML as P4ML
77
88import ACEpotentials. Models: LearnableRnlrzzBasis, PolyEnvelope2sX,
9- _i2z, GeneralizedAgnesiTransform
9+ _i2z, GeneralizedAgnesiTransform, PolyEnvelope1sR
1010
1111using LinearAlgebra: norm, dot
1212
@@ -68,6 +68,7 @@ function convert2et(model)
6868end
6969
7070
71+
7172# In ET we currently store an edge xij as a NamedTuple, e.g,
7273# xij = (𝐫ij = ..., zi = ..., zj = ...)
7374# The NTtransform is a wrapper for mapping xij -> y
@@ -85,15 +86,6 @@ function _convert_Rnl_learnable(basis; zlist = ChemicalSpecies.(basis._i2z),
8586 # number of species
8687 NZ = length (zlist)
8788
88- # species z -> index i mapping
89- __z2i = let _i2z = (_i2z = zlist,)
90- z -> _z2i (_i2z, z)
91- end
92-
93- # __zz2i maps a `(Zi, Zj)` pair to a single index `a` representing
94- # (Zi, Zj) in a flattened array
95- __zz2ii = (zi, zj) -> (__z2i (zi) - 1 ) * NZ + __z2i (zj)
96-
9789 selector = let zlist = tuple (zlist... )
9890 xij -> ET. catcat2idx (zlist, xij. z0, xij. z1)
9991 end
@@ -124,8 +116,8 @@ function _convert_Rnl_learnable(basis; zlist = ChemicalSpecies.(basis._i2z),
124116 @assert env. x1 == - 1
125117 @assert env. x2 == 1
126118 end
127-
128119 et_env = y -> (1 - y^ 2 )^ 2
120+ # et_env = _convert_envelope(basis.envelopes)
129121
130122 # the polynomial basis just stays the same
131123 # but needs to be wrapped due to the envelope being applied
@@ -205,4 +197,108 @@ function _convert_agnesi(rbasis::LearnableRnlrzzBasis)
205197 end
206198
207199 return ET. NTtransformST (f_agnesi, st)
208- end
200+ end
201+
202+
203+ function _convert_envelope (envelopes)
204+ TENV = typeof (envelopes[1 ])
205+ for env in envelopes
206+ @assert typeof (env) == TENV
207+ end
208+
209+ @show TENV
210+ return _convert_env_TENV (TENV, envelopes)
211+ end
212+
213+ function _convert_env_TENV (:: Type{<: PolyEnvelope2sX} , envelopes)
214+ for env in envelopes
215+ @assert env isa PolyEnvelope2sX
216+ @assert env. p1 == env. p2 == 2
217+ @assert env. x1 == - 1
218+ @assert env. x2 == 1
219+ end
220+ return y -> (1 - y^ 2 )^ 2
221+ end
222+
223+ function _convert_env_TENV (:: Type{<: PolyEnvelope1sR} , envelopes)
224+ env1 = envelopes[1 ]
225+ for env in envelopes
226+ @assert env == env1
227+ end
228+ f_env = (r, st) -> _eval_env_1sr (r, st. rcut, st. p)
229+ refst = ( rcut = env1. rcut, p = env1. p )
230+ return ET. st_transform (f_env, refst)
231+ end
232+
233+ function _eval_env_1sr (r, rcut, p)
234+ _1 = one (r)
235+ s = r / rcut
236+ return (s^ (- p) - _1) * (_1 - s) * (s < _1)
237+ end
238+
239+ function _convert_pair_envelope (envelopes)
240+ TENV = typeof (envelopes[1 ])
241+ for env in envelopes
242+ @assert typeof (env) == TENV
243+ end
244+ env1 = envelopes[1 ]
245+ @assert env1 isa PolyEnvelope1sR
246+ for env in envelopes
247+ @assert env == env1
248+ end
249+ refst = ( rcut = env1. rcut, p = env1. p )
250+ f_env = ET. dp_transform ( (x, st) -> _eval_env_1sr ( norm (x. 𝐫), st. rcut, st. p ),
251+ refst )
252+ return f_env
253+ end
254+
255+
256+
257+ function convertpair (model)
258+
259+ # extract radial basis information
260+ basis = model. pairbasis
261+ zlist = ChemicalSpecies .(basis. _i2z)
262+ NZ = length (zlist)
263+
264+ # this construction is a little different from the Rnl basis for the
265+ # many-body model because the envelope takes a different input
266+ # and this makes life a little more complicated.
267+
268+ # 1: polynomials without the envelope
269+ #
270+ dp_agnesi = _convert_agnesi (basis)
271+ polys = basis. polys
272+ selector2 = let zlist = zlist
273+ xij -> ET. catcat2idx (zlist, xij. z0, xij. z1)
274+ end
275+ et_linl = ET. SelectLinL (length (polys), # indim
276+ length (basis), # outdim
277+ NZ^ 2 , # num (Zi,Zj) pairs
278+ selector2)
279+ rbasis_1 = ET. EmbedDP (dp_agnesi, polys, et_linl)
280+
281+ # 2: envelope
282+ dp_envelope = _convert_pair_envelope (basis. envelopes)
283+ # _env_r = _convert_envelope(basis.envelopes)
284+ # dp_envelope = ET.dp_transform( (x, st) -> _env_r.f( norm(x.𝐫), st ),
285+ # _env_r.refstate )
286+
287+ # 3. combine into the radial basis
288+ rembed = ET. EdgeEmbed ( EnvRBranchL (dp_envelope, rbasis_1) )
289+
290+ # 4. rembed provides the radial basis for the pair model, now we just
291+ # need the readout layer which is similar to before.
292+ selector1 = let zlist = zlist
293+ x -> ET. cat2idx (zlist, x. z)
294+ end
295+ readout = ET. SelectLinL (
296+ length (basis),
297+ 1 , # output dim (only one site energy per atom)
298+ NZ, # number of categories = num species
299+ selector1)
300+
301+ et_pair = ETPairModel (rembed, readout)
302+
303+ return et_pair
304+ end
0 commit comments