Skip to content

Commit 3dda80c

Browse files
JVP Support (#181)
* Added a JSN parser. * Debugging in progress. * Bound the double backward pass. * Ready to try conv next. * Adding the convolution primitives. * Convolution working. * JITted. * Tests JIT also. * Third derivative is failing. * Fixed failing third derivative. * Fixed up some more stuff. * Added VJP version for debugging. * Fixed zero buffers. * Zero'd some more buffers. * More things are working. * Fixed things up. * Precommit. * Merge changes.
1 parent 3c2dd77 commit 3dda80c

File tree

15 files changed

+2271
-425
lines changed

15 files changed

+2271
-425
lines changed

CHANGELOG.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,22 @@
11
## Latest Changes
22

3+
### v0.5.4 (2025-02-01)
4+
Improvements to JAX frontend.
5+
6+
**Added**:
7+
- Jacobian Vector Products (JVP)
8+
for both `TensorProduct` and `TensorProductConv` via custom primitives, in addition to VJP.
9+
- Arbitrary higher-order derivatives in JAX.
10+
- JAX JIT support; in particular, support for
11+
Phonon Fine Tuning in [Nequix](https://github.com/atomicarchitects/nequix).
12+
13+
**Fixed**:
14+
- Zero'd all output buffers in the backwards and double-backwards implementations of convolution
15+
before calling kernels.
16+
17+
### v0.5.1-0.5.3 (2025-02-01)
18+
Minor bugfixes related to packaging and JAX.
19+
320
### v0.5.0 (2025-12-25)
421
JAX support is now available in
522
OpenEquivariance for BOTH NVIDIA and

openequivariance/openequivariance/core/ComputationSchedule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def __init__(self, num_blocks, num_threads, warp_size, smem):
320320
self.num_blocks = num_blocks
321321
self.num_threads = num_threads
322322
self.warp_size = warp_size
323-
self.smem = smem
323+
self.smem = int(smem)
324324

325325

326326
class ComputationSchedule:

openequivariance/openequivariance/core/utils.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import json
99
import tempfile
10-
import hashlib
1110

1211
from enum import IntEnum
1312

@@ -200,13 +199,3 @@ def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]):
200199
time_millis[i] = kernel_time
201200

202201
return time_millis
203-
204-
205-
def hash_attributes(attrs):
206-
m = hashlib.sha256()
207-
208-
for key in sorted(attrs.keys()):
209-
m.update(attrs[key].__repr__().encode("utf-8"))
210-
211-
hash = int(m.hexdigest()[:16], 16) >> 1
212-
attrs["hash"] = hash

0 commit comments

Comments
 (0)