Memory and Performance Issues with JAX solver #564

@badisa

Description

I have been generating forward and reverse dG plots (example below) from a large number of U_klns in parallel and noticed that memory consumption balloons over time.

[Figure: example of forward and reverse dG plots (Figure 5 from https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4420631/)]

The issue is that when JAX JIT-compiles a function, the compiled function is cached keyed on the input shapes (https://stackoverflow.com/a/75642593, https://docs.jax.dev/en/latest/export/shape_poly.html). As the shapes of the U_klns change over time, the number of cached compilations grows, leading to large amounts (> 1 GB) of memory being allocated and never freed.
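A minimal sketch of this behavior (assuming only jax is installed; the function and shapes here are illustrative, not pymbar's actual solver): the Python body of a jitted function runs only when JAX traces a new input signature, so a counter incremented in the body shows that a repeated shape hits the cache while a new shape adds a new compiled entry.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def total_energy(u):
    global trace_count
    trace_count += 1  # runs only when JAX traces/compiles a new shape
    return jnp.sum(u)

total_energy(jnp.ones((5, 100)))   # first shape: traced and compiled
total_energy(jnp.zeros((5, 100)))  # same shape/dtype: served from cache
total_energy(jnp.ones((5, 101)))   # new shape: traced again, cache grows
```

With many distinct U_kln shapes, each adds its own compiled entry, which is why memory grows with the variety of shapes rather than the size of any one array.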

Below is a comparison of the memory usage of MBAR with and without JAX:
Test Code: https://github.com/badisa/pymbar/blob/experiment/jax-vs-numpy-memory-usage/pymbar/tests/test_mbar.py#L550-L568

MBAR Numpy Memory Consumption

MBAR Jax Memory Consumption

Performance with JAX is also significantly worse than with NumPy. The test linked above takes 5 seconds with NumPy (https://github.com/choderalab/pymbar/actions/runs/17384102513/job/49347464100?pr=563#step:6:228) and 111 seconds with JAX (https://github.com/choderalab/pymbar/actions/runs/17384102513/job/49347464103?pr=563#step:6:222), which is more than a 22x slowdown.

It is possible to use jax.clear_caches() to resolve the memory issue, but this would harm the performance of other JAX code I have. The best solution would likely be to merge #509 or to provide some way to disable JIT in pymbar (there is the force_no_jax variable, but it is not clear how to modify it cleanly at runtime).
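To illustrate the trade-off (a sketch assuming only jax is installed; the function is illustrative): jax.clear_caches() drops every compiled function process-wide, so the very next call, even with a shape that was already compiled, must be retraced. That bounds memory growth but re-pays compilation cost for all jitted code, not just pymbar's.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def total_energy(u):
    global trace_count
    trace_count += 1  # runs only when JAX traces/compiles
    return jnp.sum(u)

total_energy(jnp.ones(3))  # traced and compiled once
jax.clear_caches()         # frees all compiled functions in the process
total_energy(jnp.ones(3))  # same shape, but must be retraced after clearing
```

This is why clearing the cache between U_kln batches would fix the memory growth at the cost of recompiling unrelated JAX code elsewhere in the process.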
