Skip to content

Commit efa4d48

Browse files
committed
finalize tests
1 parent 79e86b4 commit efa4d48

File tree

2 files changed

+43
-20
lines changed

2 files changed

+43
-20
lines changed

src/et_models/convert.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,13 +230,29 @@ function _convert_env_TENV(::Type{<: PolyEnvelope1sR}, envelopes)
230230
return ET.st_transform(f_env, refst)
231231
end
232232

233-
234233
function _eval_env_1sr(r, rcut, p)
235234
_1 = one(r)
236235
s = r / rcut
237236
return (s^(-p) - _1) * (_1 - s) * (s < _1)
238237
end
239238

239+
function _convert_pair_envelope(envelopes)
240+
TENV = typeof(envelopes[1])
241+
for env in envelopes
242+
@assert typeof(env) == TENV
243+
end
244+
env1 = envelopes[1]
245+
@assert env1 isa PolyEnvelope1sR
246+
for env in envelopes
247+
@assert env == env1
248+
end
249+
refst = ( rcut = env1.rcut, p = env1.p )
250+
f_env = ET.dp_transform( (x, st) -> _eval_env_1sr( norm(x.𝐫), st.rcut, st.p ),
251+
refst )
252+
return f_env
253+
end
254+
255+
240256

241257
function convertpair(model)
242258

@@ -263,9 +279,10 @@ function convertpair(model)
263279
rbasis_1 = ET.EmbedDP(dp_agnesi, polys, et_linl)
264280

265281
# 2: envelope
266-
_env_r = _convert_envelope(basis.envelopes)
267-
dp_envelope = ET.dp_transform( (x, st) -> _env_r.f( norm(x.𝐫), st ),
268-
_env_r.refstate )
282+
dp_envelope = _convert_pair_envelope(basis.envelopes)
283+
# _env_r = _convert_envelope(basis.envelopes)
284+
# dp_envelope = ET.dp_transform( (x, st) -> _env_r.f( norm(x.𝐫), st ),
285+
# _env_r.refstate )
269286

270287
# 3. combine into the radial basis
271288
rembed = ET.EdgeEmbed( EnvRBranchL(dp_envelope, rbasis_1) )

test/etmodels/test_etpair.jl

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", ".."))
2-
using TestEnv; TestEnv.activate();
3-
Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "EquivariantTensors.jl"))
4-
# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "Polynomials4ML.jl"))
1+
# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", ".."))
2+
# using TestEnv; TestEnv.activate();
3+
# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "EquivariantTensors.jl"))
4+
# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "Polynomials4ML.jl"))
55
# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "DecoratedParticles"))
66

77
##
@@ -137,6 +137,8 @@ println_slim(@test all(∇E_𝔹_edges .≈ ∂G.edge_data))
137137

138138
# turn off during CI -- need to sort out CI for GPU tests
139139

140+
#=
141+
140142
@info("Check GPU evaluation")
141143
using Metal
142144
dev = Metal.mtl
@@ -153,21 +155,25 @@ G_dev = dev(G_32)
153155
E1, st = et_pair(G_32, ps_32, st_32)
154156
E2_dev, st_dev = et_pair(G_dev, ps_dev, st_dev)
155157
E2 = Array(E2_dev)
158+
println_slim(@test E1 ≈ E2)
156159
160+
g1 = ETM.site_grads(et_pair, G_32, ps_32, st_32)
161+
g2_dev = ETM.site_grads(et_pair, G_dev, ps_dev, st_dev)
162+
g2_edge = Array(g2_dev.edge_data)
163+
println_slim(@test all(g1.edge_data .≈ g2_edge))
157164
158-
g1 = ETM.site_grads(et_V0, G_32, ps_32, st_32)
159-
g2_dev = ETM.site_grads(et_V0, G_dev, ps_dev, st_dev)
160-
g2 = Array(g2_dev)
161-
println_slim(@test g1 == g2)
162-
163-
b1 = ETM.site_basis(et_V0, G_32, ps_32, st_32)
164-
b2_dev = ETM.site_basis(et_V0, G_dev, ps_dev, st_dev)
165+
b1 = ETM.site_basis(et_pair, G_32, ps_32, st_32)
166+
b2_dev = ETM.site_basis(et_pair, G_dev, ps_dev, st_dev)
165167
b2 = Array(b2_dev)
166-
println_slim(@test b1 == b2)
168+
println_slim(@test b1 b2)
167169
168-
b1, ∂db1 = ETM.site_basis_jacobian(et_V0, G_32, ps_32, st_32)
169-
b2_dev, ∂db2_dev = ETM.site_basis_jacobian(et_V0, G_dev, ps_dev, st_dev)
170+
b1, ∂db1 = ETM.site_basis_jacobian(et_pair, G_32, ps_32, st_32)
171+
b2_dev, ∂db2_dev = ETM.site_basis_jacobian(et_pair, G_dev, ps_dev, st_dev)
170172
b2 = Array(b2_dev)
171173
∂db2 = Array(∂db2_dev)
172-
println_slim(@test b1 == b2)
173-
println_slim(@test ∂db1 == ∂db2)
174+
println_slim(@test b1 ≈ b2)
175+
jacerr = norm.(∂db1 .- ∂db2) ./ (1 .+ norm.(∂db1) + norm.(∂db2))
176+
@show maximum(jacerr)
177+
println_slim( @test maximum(jacerr) < 1e-4 )
178+
179+
=#

0 commit comments

Comments
 (0)