Skip to content

fix: Metal GPU compatibility for Apple Silicon (jax-metal 0.1.1)#207

Open
elmariachi111 wants to merge 1 commit into
sokrypton:mainfrom
elmariachi111:apple-silicon-metal-jax-metal-compat
Open

fix: Metal GPU compatibility for Apple Silicon (jax-metal 0.1.1)#207
elmariachi111 wants to merge 1 commit into
sokrypton:mainfrom
elmariachi111:apple-silicon-metal-jax-metal-compat

Conversation

@elmariachi111
Copy link
Copy Markdown

Summary

Two patches that make ColabDesign's AlphaFold2 forward pass run on Apple Silicon Macs using jax-metal 0.1.1.

jnp.linalg.eigh and jnp.linalg.svd are not implemented for the Metal platform in jax-metal 0.1.1, blocking the AF2 structure module entirely on macOS ARM64.


Patch 1 — colabdesign/af/alphafold/model/quat_affine.py

jnp.linalg.eigh is called in rot_to_quat() to find the largest eigenvector of a 4×4 symmetric matrix (the Shepperd K matrix → quaternion conversion). On Metal this raises:

NotImplementedError: MLIR translation rule for primitive 'eigh' not found for platform METAL

Fix: Replace with a 50-iteration power iteration using jax.lax.fori_loop (only requires matmul + norm, both of which work on Metal). Validated to <1e-7 error vs numpy.linalg.eigh on 20 random rotation matrices.


Patch 2 — colabdesign/shared/protein.py

jnp.linalg.svd is called in _np_kabsch() for Kabsch structural alignment. On Metal, jnp.linalg.svd internally lowers to the eigh primitive and fails with the same error.

Fix: Add _metal_safe_svd(A) — SVD via power iteration on AᵀA with deflation (3 eigenvectors via power method + cross product), using only jax.lax.fori_loop + matmul + norm. Validated to <2e-4 error vs numpy.linalg.svd on random 3×3 matrices. The result is wrapped in jax.lax.stop_gradient (standard practice for Kabsch alignment, prevents Metal from attempting to differentiate through the fori_loop).


What this enables / what remains blocked

AF2 forward pass / structure prediction: Works on Metal. pLDDT / pTM confidence scoring: Works on Metal. Full binder hallucination (value_and_grad): Blocked by separate jax-metal compiler bug.


Test environment

  • macOS 26.3 ARM64 (Apple paravirtual GPU via Virtualization Framework)
  • Python 3.10
  • jax==0.5.0 / jaxlib==0.5.0 / jax-metal==0.1.1
  • Benchmark: AF2 inference 3–9 s/forward pass on 56–164 residue proteins

A companion PR to cytokineking/FreeBindCraft addresses the higher-level Metal detection fixes also needed for FreeBindCraft on Apple Silicon.

jnp.linalg.eigh is not implemented on the Metal platform in jax-metal 0.1.1,
blocking the AF2 structure module (quat_affine.py) on Apple Silicon Macs.
jnp.linalg.svd internally calls eigh and is also blocked.

Fix 1 (colabdesign/af/alphafold/model/quat_affine.py):
  Replace jnp.linalg.eigh in rot_to_quat() with power iteration (50 iter
  fori_loop). The dominant eigenvector of the 4x4 symmetric K matrix is found
  iteratively; validated to <1e-7 error vs numpy eigh on 20 random rotation
  matrices. Canonical sign convention applied so the largest-magnitude
  component is always positive.

Fix 2 (colabdesign/shared/protein.py):
  Add _metal_safe_svd(): a power iteration SVD on A^T A with eigenvalue
  deflation for the 2nd vector and cross product for the 3rd (valid for the
  3x3 Kabsch input). Validated to <2e-4 error vs numpy SVD on random 3x3
  matrices (60-iteration power method). Replaces jnp.linalg.svd in
  _np_kabsch() when use_jax=True and wraps the result in
  jax.lax.stop_gradient to prevent gradient flow through fori_loop (which
  causes a separate jax-metal compiler bug).

Tested on: macOS 26.3 ARM64, jax==0.5.0 / jaxlib==0.5.0 / jax-metal==0.1.1 / Python 3.10

Enables AF2 forward inference on Apple Silicon Metal. Note: full backprop
(value_and_grad + haiku RNG) is blocked by a separate unresolved jax-metal
compiler bug and is not addressed here.

References:
  https://github.com/cytokineking/FreeBindCraft (Apple Silicon porting notes)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant