
Commit 72122ce

Vmap + pmap support for convolution primitives. (#182)
* Making progress with vmap.
* More progress on vmap.
* Making progress with a test.
* Finished writing tests.
* More testing.
* Updated docs.
* Fixed the stream issues.
1 parent 3dda80c commit 72122ce

File tree

10 files changed: +287 −59 lines changed

docs/tests_and_benchmarks.rst

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ To set up an editable install and run our tests, use the following code:
     pytest --jax tests/example_test.py
     pytest --jax tests/batch_test.py
     pytest --jax tests/conv_test.py
+    pytest --jax tests/vmap_test.py
 
 Browse the ``tests`` directory to run specific components.
 
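The new test file exercises jax.vmap over the tensor-product primitives. As a generic illustration of the transform these primitives now support (a minimal sketch using a stand-in kernel, not openequivariance's actual API — see tests/vmap_test.py for the real usage):

    import jax
    import jax.numpy as jnp

    # Stand-in kernel: any function built from primitives that have
    # batching rules can be transformed by jax.vmap; this commit
    # registers such rules for the convolution primitives.
    def kernel(x, w):
        return jnp.tanh(x @ w)

    xs = jnp.ones((8, 16, 4))  # batch of 8 inputs
    w = jnp.ones((4, 4))       # shared weights

    # Map kernel over the leading axis of xs, reusing w for each element;
    # jax.pmap applies the same idea across devices.
    ys = jax.vmap(kernel, in_axes=(0, None))(xs, w)
    print(ys.shape)  # (8, 16, 4)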

openequivariance/openequivariance/core/utils.py

Lines changed: 5 additions & 0 deletions
@@ -9,6 +9,7 @@
 import tempfile
 
 from enum import IntEnum
+import hashlib
 
 
 class DTypeEnum(IntEnum):
@@ -199,3 +200,7 @@ def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]):
         time_millis[i] = kernel_time
 
     return time_millis
+
+
+def hash_str_64(s: str) -> int:
+    return int.from_bytes(hashlib.sha256(s.encode()).digest()[:7], "big")
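The new helper derives a stable integer identifier from a kernel's source string. Taking the first 7 bytes of the SHA-256 digest yields a 56-bit value, so the result always fits in a signed 64-bit integer; unlike Python's built-in string hash, it is identical across processes and runs. A quick sanity check (the input string here is illustrative):

    import hashlib

    def hash_str_64(s: str) -> int:
        # First 7 bytes of SHA-256 -> 56 bits, guaranteed to fit in int64.
        return int.from_bytes(hashlib.sha256(s.encode()).digest()[:7], "big")

    h = hash_str_64("example kernel source")
    assert 0 <= h < 2**56
    print(hex(h))  # same value on every run and every machine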

openequivariance/openequivariance/extension/convolution.hpp

Lines changed: 42 additions & 36 deletions
@@ -15,9 +15,10 @@ template<typename JIT_IMPL>
 class __attribute__ ((visibility ("default"))) JITConvImpl {
 public:
     JIT_IMPL jit;
-    KernelLaunchConfig forward_config;
-    KernelLaunchConfig backward_config;
-    KernelLaunchConfig double_backward_config;
+
+    KernelLaunchConfig forward_config_ref;
+    KernelLaunchConfig backward_config_ref;
+    KernelLaunchConfig double_backward_config_ref;
     int opt_level;
 
     JITConvImpl(
@@ -27,25 +28,25 @@ class __attribute__ ((visibility ("default"))) JITConvImpl {
         KernelLaunchConfig double_backward_config_i,
         int opt_level_i) :
             jit(jit_kernel),
-            forward_config(forward_config_i),
-            backward_config(backward_config_i),
-            double_backward_config(double_backward_config_i),
+            forward_config_ref(forward_config_i),
+            backward_config_ref(backward_config_i),
+            double_backward_config_ref(double_backward_config_i),
             opt_level(opt_level_i) {
 
         vector<string> kernels = {"forward", "backward", "fixup_forward", "fixup_backward", "double_backward_A", "double_backward_B", "fixup_double_backwardB"};
         jit.compile(kernels, {{}, {}, {}, {}, {}, {}, {}}, opt_level);
 
-        if(forward_config.smem > 0) {
-            jit.set_max_smem(0, forward_config.smem);
-            jit.set_max_smem(4, forward_config.smem);
+        if(forward_config_ref.smem > 0) {
+            jit.set_max_smem(0, forward_config_ref.smem);
+            jit.set_max_smem(4, forward_config_ref.smem);
         }
 
-        if(backward_config.smem > 0) {
-            jit.set_max_smem(1, backward_config.smem);
+        if(backward_config_ref.smem > 0) {
+            jit.set_max_smem(1, backward_config_ref.smem);
         }
 
-        if(double_backward_config.smem > 0) {
-            jit.set_max_smem(5, double_backward_config.smem);
+        if(double_backward_config_ref.smem > 0) {
+            jit.set_max_smem(5, double_backward_config_ref.smem);
         }
     }
 
@@ -89,16 +90,16 @@ class __attribute__ ((visibility ("default"))) JITConvImpl {
         ConvData conv_data = {rows, cols, nnz, node_count};
 
         void *args[] = {&L1_in, &L2_in, &weights, &L3_out, &conv_data, &workspace};
-        forward_config.hStream = stream;
-        jit.execute(0, args, forward_config);
+        jit.execute(0, args, with_stream(forward_config_ref, stream));
 
         if(reinterpret_cast<uint64_t>(workspace) != 0) {
             void *fixup_args[] = {&workspace, &L3_out};
 
-            KernelLaunchConfig fixup_config;
-            fixup_config.num_blocks = forward_config.num_blocks;
-            fixup_config.num_threads = forward_config.num_threads;
-            fixup_config.smem = 0;
+            KernelLaunchConfig fixup_config(
+                forward_config_ref.num_blocks,
+                forward_config_ref.num_threads,
+                0
+            );
             fixup_config.hStream = stream;
 
             jit.execute(2, fixup_args, fixup_config);
@@ -118,16 +119,17 @@ class __attribute__ ((visibility ("default"))) JITConvImpl {
 
         ConvData conv_data = {rows, cols, nnz, node_count};
         void *args[] = {&L1_in, &L1_grad, &L2_in, &L2_grad, &weight, &weight_grad, &L3_grad, &conv_data, &workspace, &transpose_perm};
-        backward_config.hStream = stream;
-        jit.execute(1, args, backward_config);
+        jit.execute(1, args, with_stream(backward_config_ref, stream));
 
         if(reinterpret_cast<uint64_t>(workspace) != 0) {
             void *fixup_args[] = {&workspace, &L1_grad};
 
-            KernelLaunchConfig fixup_config;
-            fixup_config.num_blocks = backward_config.num_blocks;
-            fixup_config.num_threads = backward_config.num_threads;
-            fixup_config.smem = 0; fixup_config.hStream = stream;
+            KernelLaunchConfig fixup_config(
+                backward_config_ref.num_blocks,
+                backward_config_ref.num_threads,
+                0
+            );
+            fixup_config.hStream = stream;
 
             jit.execute(3, fixup_args, fixup_config);
         }
@@ -147,24 +149,28 @@ class __attribute__ ((visibility ("default"))) JITConvImpl {
             &L1_in, &L2_in, &W, &L3_grad, &L1_dgrad, &L2_dgrad, &w_dgrad,
             &L1_grad, &L2_grad, &W_grad, &L3_dgrad, &conv_data, &wspace, &transpose_perm
         };
-        double_backward_config.hStream = stream;
-        jit.execute(4, args, forward_config);
+
+        jit.execute(4, args, with_stream(forward_config_ref, stream));
         if(reinterpret_cast<uint64_t>(wspace) != 0) {
             void *fixup_args[] = {&wspace, &L3_dgrad};
-            KernelLaunchConfig fixup_config;
-            fixup_config.num_blocks = forward_config.num_blocks;
-            fixup_config.num_threads = forward_config.num_threads;
-            fixup_config.smem = 0; fixup_config.hStream = stream;
+            KernelLaunchConfig fixup_config(
+                forward_config_ref.num_blocks,
+                forward_config_ref.num_threads,
+                0
+            );
+            fixup_config.hStream = stream;
             jit.execute(2, fixup_args, fixup_config);
         }
 
-        jit.execute(5, args, double_backward_config);
+        jit.execute(5, args, with_stream(double_backward_config_ref, stream));
         if(reinterpret_cast<uint64_t>(wspace) != 0) {
             void *fixup_args[] = {&wspace, &L1_grad};
-            KernelLaunchConfig fixup_config;
-            fixup_config.num_blocks = double_backward_config.num_blocks;
-            fixup_config.num_threads = double_backward_config.num_threads;
-            fixup_config.smem = 0; fixup_config.hStream = stream;
+            KernelLaunchConfig fixup_config(
+                double_backward_config_ref.num_blocks,
+                double_backward_config_ref.num_threads,
+                0
+            );
+            fixup_config.hStream = stream;
             jit.execute(6, fixup_args, fixup_config);
         }
     }
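A pattern worth noting in this file: each main kernel optionally accumulates per-block partial results into workspace, and a fixup kernel launched with the same grid dimensions then folds those partials into the output. A schematic Python rendering of that two-phase reduction (shapes and names are illustrative, not the project's actual workspace layout):

    import numpy as np

    def main_kernel(rows, values, num_blocks, workspace):
        # Each "block" accumulates into its own workspace slice rather
        # than contending on shared output rows.
        for b, idx in enumerate(np.array_split(np.arange(len(rows)), num_blocks)):
            for i in idx:
                workspace[b, rows[i]] += values[i]

    def fixup_kernel(workspace, out):
        # Second pass: reduce per-block partials into the final output.
        out += workspace.sum(axis=0)

    rows = np.array([0, 1, 0, 2])
    values = np.array([1.0, 2.0, 3.0, 4.0])
    workspace = np.zeros((2, 3))
    out = np.zeros(3)
    main_kernel(rows, values, 2, workspace)
    fixup_kernel(workspace, out)
    print(out)  # [4. 2. 4.]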

openequivariance/openequivariance/extension/tensorproducts.hpp

Lines changed: 20 additions & 18 deletions
@@ -10,7 +10,11 @@ template<typename JIT_IMPL>
 class __attribute__ ((visibility ("default"))) JITTPImpl {
 public:
     JIT_IMPL jit;
-    KernelLaunchConfig forward_config, backward_config, double_backward_config;
+
+    // Configs are suffixed with _ref because they
+    // need to be copied and modified with the stream. In-place
+    // modification not possible due to concurrency requirements.
+    KernelLaunchConfig forward_config_ref, backward_config_ref, double_backward_config_ref;
     int opt_level;
 
     JITTPImpl(
@@ -20,25 +24,25 @@ class __attribute__ ((visibility ("default"))) JITTPImpl {
         KernelLaunchConfig double_backward_config_i,
         int opt_level_i) :
             jit(jit_kernel),
-            forward_config(forward_config_i),
-            backward_config(backward_config_i),
-            double_backward_config(double_backward_config_i),
+            forward_config_ref(forward_config_i),
+            backward_config_ref(backward_config_i),
+            double_backward_config_ref(double_backward_config_i),
             opt_level(opt_level_i) {
 
         vector<string> kernels = {"forward", "backward", "double_backward_A", "double_backward_B"};
         jit.compile(kernels, {{}, {}, {}, {}}, opt_level);
 
-        if(forward_config.smem > 0) {
-            jit.set_max_smem(0, forward_config.smem);
-            jit.set_max_smem(2, forward_config.smem);
+        if(forward_config_ref.smem > 0) {
+            jit.set_max_smem(0, forward_config_ref.smem);
+            jit.set_max_smem(2, forward_config_ref.smem);
         }
 
-        if(backward_config.smem > 0) {
-            jit.set_max_smem(1, backward_config.smem);
+        if(backward_config_ref.smem > 0) {
+            jit.set_max_smem(1, backward_config_ref.smem);
 
         }
-        if(double_backward_config.smem > 0) {
-            jit.set_max_smem(3, double_backward_config.smem);
+        if(double_backward_config_ref.smem > 0) {
+            jit.set_max_smem(3, double_backward_config_ref.smem);
         }
     }
 
@@ -77,8 +81,7 @@ class __attribute__ ((visibility ("default"))) JITTPImpl {
         Stream stream) {
 
         void *args[] = { &num_products, &L1_in, &L2_in, &L3_out, &weights};
-        forward_config.hStream = stream;
-        jit.execute(0, args, forward_config);
+        jit.execute(0, args, with_stream(forward_config_ref, stream));
     }
 
     void backward(
@@ -88,8 +91,7 @@ class __attribute__ ((visibility ("default"))) JITTPImpl {
         void* weight, void* weight_grad,
         void* L3_grad, Stream stream) {
         void *args[] = { &num_products, &L1_in, &L1_grad, &L2_in, &L2_grad, &weight, &weight_grad, &L3_grad};
-        backward_config.hStream = stream;
-        jit.execute(1, args, backward_config);
+        jit.execute(1, args, with_stream(backward_config_ref, stream));
     }
 
     void double_backward(
@@ -102,9 +104,9 @@ class __attribute__ ((visibility ("default"))) JITTPImpl {
             &num_products, &L1_in, &L2_in, &W, &L3_grad, &L1_dgrad, &L2_dgrad, &w_dgrad,
             &L1_grad, &L2_grad, &W_grad, &L3_dgrad
         };
-        double_backward_config.hStream = stream;
-        jit.execute(2, args, forward_config);
-        jit.execute(3, args, double_backward_config);
+        double_backward_config_ref.hStream = stream;
+        jit.execute(2, args, with_stream(forward_config_ref, stream));
+        jit.execute(3, args, with_stream(double_backward_config_ref, stream));
     }
 
     ~JITTPImpl() = default;

openequivariance/openequivariance/extension/util/backend_cuda.hpp

Lines changed: 7 additions & 1 deletion
@@ -358,4 +358,10 @@ class __attribute__((visibility("default"))) CUJITKernel {
         }
         NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog));
     }
-};
+};
+
+KernelLaunchConfig with_stream(const KernelLaunchConfig& config, Stream stream) {
+    KernelLaunchConfig new_config = config;
+    new_config.hStream = stream;
+    return new_config;
+}
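The comment added in tensorproducts.hpp gives the rationale: when several launches (e.g. under JAX's vmap/pmap, or from multiple host threads) share one JIT implementation object, writing hStream into a shared config is a data race; with_stream instead returns a per-launch copy, leaving the _ref config untouched. A minimal Python analogy of the same copy-on-launch pattern (the dataclass is illustrative, not the project's actual type):

    from dataclasses import dataclass, replace

    @dataclass(frozen=True)
    class KernelLaunchConfig:
        num_blocks: int
        num_threads: int
        smem: int
        hStream: int = 0  # stream handle; 0 = default stream

    def with_stream(config: KernelLaunchConfig, stream: int) -> KernelLaunchConfig:
        # Return a copy bound to the given stream; the shared reference
        # config is never mutated, so concurrent launches on different
        # streams cannot clobber each other's hStream.
        return replace(config, hStream=stream)

    ref = KernelLaunchConfig(num_blocks=64, num_threads=128, smem=0)
    cfg = with_stream(ref, stream=7)
    assert ref.hStream == 0 and cfg.hStream == 7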

openequivariance/openequivariance/extension/util/backend_hip.hpp

Lines changed: 7 additions & 1 deletion
@@ -310,4 +310,10 @@ class __attribute__((visibility("default"))) HIPJITKernel {
     ~HIPJITKernel() {
         HIPRTC_SAFE_CALL(hiprtcDestroyProgram(&prog));
     }
-};
+};
+
+KernelLaunchConfig with_stream(const KernelLaunchConfig& config, Stream stream) {
+    KernelLaunchConfig new_config = config;
+    new_config.hStream = stream;
+    return new_config;
+}

openequivariance/openequivariance/jax/TensorProduct.py

Lines changed: 2 additions & 1 deletion
@@ -3,6 +3,7 @@
 from openequivariance.jax import extlib
 from openequivariance.core.e3nn_lite import TPProblem
 from openequivariance.core.LoopUnrollTP import LoopUnrollTP
+from openequivariance.core.utils import hash_str_64
 from openequivariance.jax.utils import reorder_jax
 from openequivariance.jax.jvp.tp_prim import tp_fwd_p
 import json
@@ -30,7 +31,7 @@ def __init__(self, problem: TPProblem):
                 "kernel_prop": self.kernelProp,
             }
         )
-        self.hash = self.kernel.__hash__()
+        self.hash = hash_str_64(self.kernel)
 
         self.weight_numel = problem.weight_numel
         self.L3_dim = self.config.irreps_out.dim
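The switch away from self.kernel.__hash__() matters for reproducibility: Python salts str.__hash__ per interpreter process (hash randomization, controlled by PYTHONHASHSEED), so the old value differed between runs and between worker processes — presumably a problem once the kernel identifier must agree across processes under pmap. A small demonstration of the difference:

    import hashlib, subprocess, sys

    def hash_str_64(s: str) -> int:
        return int.from_bytes(hashlib.sha256(s.encode()).digest()[:7], "big")

    src = "kernel source text"

    # str hashes are salted per process: a child interpreter usually
    # reports a different value than the parent.
    child = subprocess.run(
        [sys.executable, "-c", f"print(hash({src!r}))"],
        capture_output=True, text=True,
    ).stdout.strip()
    print(hash(src), child)  # typically differ

    # The SHA-256-based hash is identical in every process.
    print(hash_str_64(src))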

openequivariance/openequivariance/jax/TensorProductConv.py

Lines changed: 3 additions & 1 deletion
@@ -5,6 +5,8 @@
 from typing import Optional
 from openequivariance.jax import extlib
 
+
+from openequivariance.core.utils import hash_str_64
 from openequivariance.core.e3nn_lite import TPProblem
 from openequivariance.core.LoopUnrollConv import LoopUnrollConv
 from openequivariance.jax.utils import reorder_jax
@@ -60,7 +62,7 @@ def __init__(
                 "kernel_prop": self.kernel_prop,
             }
         )
-        self.hash = self.kernel.__hash__()
+        self.hash = hash_str_64(self.kernel)
 
         self.weight_numel = config.weight_numel
         self.L3_dim = self.config.irreps_out.dim
