Skip to content

Commit f491184

Browse files
mikemhenryIAlibaypre-commit-ci[bot]
authored
Disable JAX acceleration by default (#1694)
* Disable JAX acceleration by default * ruff fmt * add logging info * fix url * fix list * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add example of error message * added note about disabling jax acel by default --------- Co-authored-by: Irfan Alibay <IAlibay@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 78d0f90 commit f491184

File tree

3 files changed

+41
-0
lines changed

3 files changed

+41
-0
lines changed

docs/guide/troubleshooting.rst

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,23 @@ If the necessary libraries for GPU acceleration are not installed and JAX detect
7373
7474
This warning does not mean that the *molecular dynamics* simulation will fall back to using the CPU.
7575
The simulation will still use the computing platform specified in the settings.
76+
77+
PYMBAR_DISABLE_JAX
78+
------------------
79+
80+
Due to a suspected memory leak in the JAX acceleration code in ``pymbar`` we disable JAX acceleration by default.
81+
This memory leak may result in the simulation crashing, wasting compute time.
82+
The error message may look like this:
83+
84+
.. code-block:: bash
85+
86+
LLVM compilation error: Cannot allocate memory
87+
LLVM ERROR: Unable to allocate section memory!
88+
89+
We have decided to disable JAX acceleration by default to prevent wasted compute.
90+
However, if you wish to use the JAX acceleration, you may set ``PYMBAR_DISABLE_JAX`` to ``TRUE`` (e.g. put ``export PYMBAR_DISABLE_JAX=FALSE`` in your submission script before running ``openfe quickrun``).
91+
For more information, see these issues on github:
92+
93+
- https://github.com/choderalab/pymbar/issues/564
94+
- https://github.com/OpenFreeEnergy/openfe/issues/1534
95+
- https://github.com/OpenFreeEnergy/openfe/issues/1654

news/jax-warning.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
* Emit a clarifying log message when a user gets a warning from JAX (#1585).
44
Fixes #1499.
55

6+
* Disable JAX acceleration by default, see https://docs.openfree.energy/en/latest/guide/troubleshooting.html#pymbar-disable-jax for more information (#1694).
7+
68
**Changed:**
79

810
* <news item>

openfe/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,22 @@
1+
# Before we do anything else, we want to disable JAX
2+
# acceleration by default but if a user has set
3+
# PYMBAR_DISABLE_JAX to some value, we want to keep
4+
# it
5+
6+
import logging
7+
import os
8+
9+
logger = logging.getLogger(__name__)
10+
11+
if "PYMBAR_DISABLE_JAX" not in os.environ:
12+
logger.warn(
13+
"PYMBAR_DISABLE_JAX not set, setting to TRUE, see https://docs.openfree.energy/en/latest/guide/troubleshooting.html#pymbar-disable-jax for more details"
14+
)
15+
16+
# setdefault will only set PYMBAR_DISABLE_JAX if it is unset
17+
os.environ.setdefault("PYMBAR_DISABLE_JAX", "TRUE")
18+
19+
120
# We need to do this first so that we can set up our
221
# log control since some modules have warnings on import
322
from openfe.utils import logging_control

0 commit comments

Comments
 (0)