Skip to content

Commit 140b39d

Browse files
committed
Merged.
2 parents dc45a8f + e723ea7 commit 140b39d

File tree

5 files changed

+246
-97
lines changed

5 files changed

+246
-97
lines changed

.github/workflows/verify_extension_build.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,16 @@ jobs:
3333
3434
- name: Test extension build via import
3535
run: |
36+
<<<<<<< HEAD
3637
pytest tests/import_test.py -k test_import
3738

3839
- name: Test JAX extension build
3940
run: |
4041
pip install "jax[cuda12]"
4142
pip install -e ./openequivariance[jax]
42-
pip install -e ./openequivariance_extjax --no-build-isolation
43+
pip install -e ./openequivariance_extjax --no-build-isolation
44+
=======
45+
pytest \
46+
tests/import_test.py::test_extension_built \
47+
tests/import_test.py::test_torch_extension_built
48+
>>>>>>> main

openequivariance/__init__.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# ruff: noqa: F401
2+
import sys
3+
import torch
4+
import numpy as np
5+
from pathlib import Path
6+
from importlib.metadata import version
7+
8+
import openequivariance.extlib
9+
10+
from openequivariance.extlib import (
11+
LINKED_LIBPYTHON,
12+
LINKED_LIBPYTHON_ERROR,
13+
BUILT_EXTENSION,
14+
BUILT_EXTENSION_ERROR,
15+
TORCH_COMPILE,
16+
TORCH_COMPILE_ERROR,
17+
)
18+
19+
from openequivariance.implementations.e3nn_lite import (
20+
TPProblem,
21+
Irrep,
22+
Irreps,
23+
_MulIr,
24+
Instruction,
25+
)
26+
from openequivariance.implementations.TensorProduct import TensorProduct
27+
from openequivariance.implementations.convolution.TensorProductConv import (
28+
TensorProductConv,
29+
)
30+
from openequivariance.implementations.utils import torch_to_oeq_dtype
31+
32+
__version__ = None
33+
try:
34+
__version__ = version("openequivariance")
35+
except Exception as e:
36+
print(f"Warning: Could not determine oeq version: {e}", file=sys.stderr)
37+
38+
39+
def _check_package_editable():
40+
import json
41+
from importlib.metadata import Distribution
42+
43+
direct_url = Distribution.from_name("openequivariance").read_text("direct_url.json")
44+
return json.loads(direct_url).get("dir_info", {}).get("editable", False)
45+
46+
47+
_editable_install_output_path = Path(__file__).parent.parent / "outputs"
48+
49+
50+
def torch_ext_so_path():
51+
"""
52+
:returns: Path to a ``.so`` file that must be linked to use OpenEquivariance
53+
from the PyTorch C++ Interface.
54+
"""
55+
return openequivariance.extlib.torch_module.__file__
56+
57+
58+
torch.serialization.add_safe_globals(
59+
[
60+
TensorProduct,
61+
TensorProductConv,
62+
TPProblem,
63+
Irrep,
64+
Irreps,
65+
_MulIr,
66+
Instruction,
67+
np.float32,
68+
np.float64,
69+
]
70+
)
71+
72+
__all__ = [
73+
"TPProblem",
74+
"Irreps",
75+
"TensorProduct",
76+
"TensorProductConv",
77+
"torch_to_oeq_dtype",
78+
"_check_package_editable",
79+
"torch_ext_so_path",
80+
]

openequivariance/openequivariance/_torch/extlib/__init__.py

Lines changed: 97 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,25 @@
55
import sysconfig
66
from pathlib import Path
77

8-
global torch
98
import torch
109

1110
from openequivariance.benchmark.logging_utils import getLogger
1211

1312
oeq_root = str(Path(__file__).parent.parent.parent)
1413

15-
build_ext = True
16-
TORCH_COMPILE = True
17-
TORCH_VERSION_CUDA_OR_HIP = torch.version.cuda or torch.version.hip
18-
torch_module, generic_module = None, None
19-
postprocess_kernel = lambda kernel: kernel # noqa : E731
14+
BUILT_EXTENSION = False
15+
BUILT_EXTENSION_ERROR = None
16+
17+
TORCH_COMPILE = False
18+
TORCH_COMPILE_ERROR = None
2019

2120
LINKED_LIBPYTHON = False
2221
LINKED_LIBPYTHON_ERROR = None
22+
23+
torch_module, generic_module = None, None
24+
postprocess_kernel = lambda kernel: kernel # noqa : E731
25+
26+
2327
try:
2428
python_lib_dir = sysconfig.get_config_var("LIBDIR")
2529
major, minor = sys.version_info.major, sys.version_info.minor
@@ -33,112 +37,109 @@
3337
)
3438

3539
LINKED_LIBPYTHON = True
36-
3740
except Exception as e:
3841
LINKED_LIBPYTHON_ERROR = f"Error linking libpython:\n{e}\nSysconfig variables:\n{sysconfig.get_config_vars()}"
3942

40-
generic_module = None
41-
if not build_ext:
42-
import openequivariance._torch.extlib.generic_module
4343

44+
if BUILT_EXTENSION:
45+
import openequivariance._torch.extlib.generic_module
4446
generic_module = openequivariance._torch.extlib.generic_module
45-
46-
elif TORCH_VERSION_CUDA_OR_HIP:
47-
from torch.utils.cpp_extension import library_paths, include_paths
48-
49-
extra_cflags = ["-O3"]
50-
generic_sources = ["generic_module.cpp"]
51-
torch_sources = ["libtorch_tp_jit.cpp"]
52-
53-
include_dirs, extra_link_args = (["util"], ["-Wl,--no-as-needed"])
54-
55-
if LINKED_LIBPYTHON:
56-
extra_link_args.pop()
57-
extra_link_args.extend(
58-
[
59-
f"-Wl,--no-as-needed,-rpath,{python_lib_dir}",
60-
f"-L{python_lib_dir}",
61-
f"-l{python_lib_name}",
62-
],
63-
)
64-
65-
if torch.version.cuda:
66-
extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"])
67-
68-
try:
69-
torch_libs, cuda_libs = library_paths("cuda")
47+
elif torch.version.cuda or torch.version.hip:
48+
try:
49+
from torch.utils.cpp_extension import library_paths, include_paths
50+
51+
extra_cflags = ["-O3"]
52+
generic_sources = ["generic_module.cpp"]
53+
torch_sources = ["libtorch_tp_jit.cpp"]
54+
55+
include_dirs, extra_link_args = (["util"], ["-Wl,--no-as-needed"])
56+
57+
if LINKED_LIBPYTHON:
58+
extra_link_args.pop()
59+
extra_link_args.extend(
60+
[
61+
f"-Wl,--no-as-needed,-rpath,{python_lib_dir}",
62+
f"-L{python_lib_dir}",
63+
f"-l{python_lib_name}",
64+
],
65+
)
66+
if torch.version.cuda:
67+
extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"])
68+
69+
try:
70+
torch_libs, cuda_libs = library_paths("cuda")
71+
extra_link_args.append("-Wl,-rpath," + torch_libs)
72+
extra_link_args.append("-L" + cuda_libs)
73+
if os.path.exists(cuda_libs + "/stubs"):
74+
extra_link_args.append("-L" + cuda_libs + "/stubs")
75+
except Exception as e:
76+
getLogger().info(str(e))
77+
78+
extra_cflags.append("-DCUDA_BACKEND")
79+
elif torch.version.hip:
80+
extra_link_args.extend(["-lhiprtc"])
81+
torch_libs = library_paths("cuda")[0]
7082
extra_link_args.append("-Wl,-rpath," + torch_libs)
71-
extra_link_args.append("-L" + cuda_libs)
72-
if os.path.exists(cuda_libs + "/stubs"):
73-
extra_link_args.append("-L" + cuda_libs + "/stubs")
74-
except Exception as e:
75-
getLogger().info(str(e))
76-
77-
extra_cflags.append("-DCUDA_BACKEND")
78-
elif torch.version.hip:
79-
extra_link_args.extend(["-lhiprtc"])
80-
torch_libs = library_paths("cuda")[0]
81-
extra_link_args.append("-Wl,-rpath," + torch_libs)
82-
83-
def postprocess(kernel):
84-
kernel = kernel.replace("__syncwarp();", "__threadfence_block();")
85-
kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(")
86-
kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd")
87-
return kernel
88-
89-
postprocess_kernel = postprocess
90-
91-
extra_cflags.append("-DHIP_BACKEND")
92-
93-
generic_sources = [oeq_root + "/extension/" + src for src in generic_sources]
94-
torch_sources = [oeq_root + "/extension/" + src for src in torch_sources]
95-
include_dirs = [oeq_root + "/extension/" + d for d in include_dirs] + include_paths(
96-
"cuda"
97-
)
98-
99-
torch_compile_exception = None
100-
with warnings.catch_warnings():
101-
warnings.simplefilter("ignore")
10283

103-
try:
104-
torch_module = torch.utils.cpp_extension.load(
105-
"libtorch_tp_jit",
106-
torch_sources,
84+
def postprocess(kernel):
85+
kernel = kernel.replace("__syncwarp();", "__threadfence_block();")
86+
kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(")
87+
kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd")
88+
return kernel
89+
90+
postprocess_kernel = postprocess
91+
92+
extra_cflags.append("-DHIP_BACKEND")
93+
94+
generic_sources = [oeq_root + "/extension/" + src for src in generic_sources]
95+
torch_sources = [oeq_root + "/extension/" + src for src in torch_sources]
96+
include_dirs = [
97+
oeq_root + "/extension/" + d for d in include_dirs
98+
] + include_paths("cuda")
99+
100+
with warnings.catch_warnings():
101+
warnings.simplefilter("ignore")
102+
103+
try:
104+
torch_module = torch.utils.cpp_extension.load(
105+
"libtorch_tp_jit",
106+
torch_sources,
107+
extra_cflags=extra_cflags,
108+
extra_include_paths=include_dirs,
109+
extra_ldflags=extra_link_args,
110+
)
111+
torch.ops.load_library(torch_module.__file__)
112+
TORCH_COMPILE = True
113+
except Exception as e:
114+
# If compiling torch fails (e.g. low gcc version), we should fall back to the
115+
# version that takes integer pointers as args (but is untraceable to PyTorch JIT / export).
116+
TORCH_COMPILE_ERROR = e
117+
118+
generic_module = torch.utils.cpp_extension.load(
119+
"generic_module",
120+
generic_sources,
107121
extra_cflags=extra_cflags,
108122
extra_include_paths=include_dirs,
109123
extra_ldflags=extra_link_args,
110124
)
111-
torch.ops.load_library(torch_module.__file__)
112-
except Exception as e:
113-
# If compiling torch fails (e.g. low gcc version), we should fall back to the
114-
# version that takes integer pointers as args (but is untraceable to PyTorch JIT / export).
115-
TORCH_COMPILE = False
116-
torch_compile_exception = e
117-
118-
generic_module = torch.utils.cpp_extension.load(
119-
"generic_module",
120-
generic_sources,
121-
extra_cflags=extra_cflags,
122-
extra_include_paths=include_dirs,
123-
extra_ldflags=extra_link_args,
124-
)
125-
if "generic_module" not in sys.modules:
126-
sys.modules["generic_module"] = generic_module
125+
if "generic_module" not in sys.modules:
126+
sys.modules["generic_module"] = generic_module
127127

128-
if not TORCH_COMPILE:
129-
warnings.warn(
130-
"Could not compile integrated PyTorch wrapper. Falling back to Pybind11"
131-
+ f", but JITScript, compile fullgraph, and export will fail.\n {torch_compile_exception}"
132-
)
128+
if not TORCH_COMPILE:
129+
warnings.warn(
130+
"Could not compile integrated PyTorch wrapper. Falling back to Pybind11"
131+
+ f", but JITScript, compile fullgraph, and export will fail.\n {TORCH_COMPILE_ERROR}"
132+
)
133+
BUILT_EXTENSION = True
134+
except Exception as e:
135+
BUILT_EXTENSION_ERROR = f"Error building OpenEquivariance Extension: {e}"
133136
else:
134-
TORCH_COMPILE = False
137+
BUILT_EXTENSION_ERROR = "OpenEquivariance extension build not attempted"
135138

136139

137140
def _raise_import_error_helper(import_target: str):
138-
if not TORCH_VERSION_CUDA_OR_HIP:
139-
raise ImportError(
140-
f"Could not import {import_target}: OpenEquivariance's torch extension was not built because torch.version.cuda || torch.version.hip is false"
141-
)
141+
if not BUILT_EXTENSION:
142+
raise ImportError(f"Could not import {import_target}: {BUILT_EXTENSION_ERROR}")
142143

143144

144145
def torch_ext_so_path():

openequivariance/openequivariance/benchmark/problems.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,54 @@ def nequip_problems():
151151
]
152152

153153

154+
# https://github.com/atomicarchitects/nequix/blob/main/configs/nequix-mp-1.yml
155+
def nequix_problems():
156+
return [
157+
CTPP(
158+
"89x0e",
159+
"1x0e+1x1o+1x2e+1x3o",
160+
"89x0e+89x1o+89x2e+89x3o",
161+
"nequix-mp-1-first_layer",
162+
),
163+
CTPP(
164+
"128x0e+64x1o+32x2e+32x3o",
165+
"1x0e+1x1o+1x2e+1x3o",
166+
"128x0e+128x1o+128x2e+128x3o+64x1o+64x0e+64x2e+64x1o+64x3o+64x2e+32x2e+32x1o+32x3o+32x0e+32x2e+32x1o+32x3o+32x3o+32x2e+32x1o+32x3o+32x0e+32x2e",
167+
"nequix-mp-1-main_layers",
168+
),
169+
CTPP(
170+
"128x0e+64x1o+32x2e+32x3o",
171+
"1x0e+1x1o+1x2e+1x3o",
172+
"128x0e+64x0e+32x0e+32x0e",
173+
"nequix-mp-1-last_layer",
174+
),
175+
]
176+
177+
178+
# https://github.com/MDIL-SNU/SevenNet/tree/main/sevenn/pretrained_potentials/SevenNet_l3i5
179+
def seven_net_problems():
180+
return [
181+
CTPP(
182+
"128x0e",
183+
"1x0e+1x1e+1x2e+1x3e",
184+
"128x0e+128x1e+128x2e+128x3e",
185+
"SevenNet_l3i5-first-layer",
186+
),
187+
CTPP(
188+
"128x0e+64x1e+32x2e+32x3e",
189+
"1x0e+1x1e+1x2e+1x3e",
190+
"128x0e+64x0e+32x0e+32x0e+128x1e+64x1e+64x1e+64x1e+32x1e+32x1e+32x1e+32x1e+32x1e+128x2e+64x2e+64x2e+64x2e+32x2e+32x2e+32x2e+32x2e+32x2e+32x2e+32x2e+128x3e+64x3e+64x3e+32x3e+32x3e+32x3e+32x3e+32x3e+32x3e+32x3e",
191+
"SevenNet_l3i5-main-layers",
192+
),
193+
CTPP(
194+
"128x0e+64x1e+32x2e+32x3e",
195+
"1x0e+1x1e+1x2e+1x3e",
196+
"128x0e+64x0e+32x0e+32x0e",
197+
"SevenNet_l3i5-last-layer",
198+
),
199+
]
200+
201+
154202
def e3tools_problems():
155203
return [
156204
FCTPP(in1, in2, out, label=label, shared_weights=sw, internal_weights=iw)

0 commit comments

Comments
 (0)