Description
I have been generating forward and reverse dG plots (example below) from a large number of u_kln matrices in parallel and noticed that memory consumption balloons over time.
Example of forward and reverse dG plots (Figure 5 from https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4420631/)
The issue is that when JAX JIT-compiles a function, the compiled version is cached keyed on the input shapes (https://stackoverflow.com/a/75642593, https://docs.jax.dev/en/latest/export/shape_poly.html). So as the shapes of the u_kln matrices change over time, the number of cached compilations grows without bound, leading to gigabytes of memory being allocated and never freed.
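To make the failure mode concrete, here is a minimal plain-Python model (not JAX itself; the names are illustrative) of a jit-style wrapper that keeps one "compiled" artifact per distinct input shape. A stream of differently shaped inputs grows the cache indefinitely, which is exactly what happens with the varying u_klns:

```python
# Minimal model (not the JAX API) of a jit-style cache keyed on input shape.

def make_jit_like(fn):
    cache = {}  # shape tuple -> "compiled" artifact

    def wrapper(matrix):
        shape = (len(matrix), len(matrix[0]))
        if shape not in cache:
            # In real JAX this step is a trace + XLA compile; the compiled
            # executable (and its buffers) stays alive as long as the cache.
            cache[shape] = fn
        return cache[shape](matrix)

    wrapper.cache = cache
    return wrapper


def total_energy(matrix):
    return sum(sum(row) for row in matrix)


total = make_jit_like(total_energy)

# Every new shape adds a cache entry; repeated shapes reuse one.
for n in (3, 4, 5, 5, 4):
    total([[1.0] * n for _ in range(2)])

print(len(total.cache))  # 3 distinct shapes -> 3 cached entries
```

With fixed-shape inputs the cache stays at one entry; with shapes that keep changing, every call can pay both the compile cost and the retained memory.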
Below is a comparison of the memory usage of MBAR with and without JAX.
Test Code: https://github.com/badisa/pymbar/blob/experiment/jax-vs-numpy-memory-usage/pymbar/tests/test_mbar.py#L550-L568
MBAR NumPy Memory Consumption
MBAR JAX Memory Consumption
JAX also performs significantly worse than NumPy here. The test linked above takes 5 seconds with NumPy (https://github.com/choderalab/pymbar/actions/runs/17384102513/job/49347464100?pr=563#step:6:228) and 111 seconds with JAX (https://github.com/choderalab/pymbar/actions/runs/17384102513/job/49347464103?pr=563#step:6:222), more than a 22x slowdown.
It is possible to use jax.clear_caches() to resolve the memory issue, but this would harm the performance of other JAX code I have. The best solution would likely be to merge #509, or to provide some way to disable JIT in pymbar at runtime (there is the force_no_jax variable, but it is not clear how to modify it cleanly at runtime).
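For anyone hitting this before a fix lands: a common JAX-side mitigation (not something pymbar provides) is to pad inputs up to a small set of bucket sizes so only a few shapes are ever compiled. Whether MBAR tolerates padded u_kln input depends on how the padded samples are accounted for, so this only sketches the bucketing step itself:

```python
def bucket_size(n: int) -> int:
    # Round the sample count up to the next power of two, so arbitrarily
    # many raw input lengths map to only O(log n) distinct padded shapes
    # (and therefore O(log n) JIT cache entries).
    size = 1
    while size < n:
        size *= 2
    return size


# Many raw sample counts collapse to a handful of padded shapes:
counts = [100, 130, 200, 250, 500, 510]
print(sorted({bucket_size(n) for n in counts}))  # [128, 256, 512]
```

Padding the sample axis of each u_kln to bucket_size(N) would bound the cache growth, at the cost of some wasted compute on the padding; calling jax.clear_caches() between batches is the blunter alternative mentioned above.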