Skip to content

Commit cc7ec27

Browse files
committed
chore: final hardening and documentation polish for release
- Robust kernel path discovery using importlib.resources - Added platform guardrails and import-time checks - Added MHC_MLX_DISABLE_METAL escape hatch - Neutral README tone and reproducible benchmarks - Added CONTRIBUTING.md and SECURITY.md - Bumped version to 0.5.4
1 parent 1581b89 commit cc7ec27

File tree

9 files changed

+80
-16
lines changed

9 files changed

+80
-16
lines changed

CONTRIBUTING.md

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Contributing to mhc-mlx
2+
3+
We welcome contributions to improve speed, stability, and compatibility!
4+
5+
## Development Setup
6+
7+
1. Clone the repo:
8+
```bash
9+
git clone https://github.com/svdrecbd/mhc-mlx.git
10+
cd mhc-mlx
11+
```
12+
13+
2. Install dependencies:
14+
```bash
15+
pip install -e ".[dev,bench]"
16+
```
17+
18+
## Running Tests
19+
20+
Run the full suite:
21+
```bash
22+
python -m pytest
23+
```
24+
25+
Note: Some tests require Apple Silicon to execute the Metal paths.
26+
27+
## Benchmarking
28+
29+
Run the main benchmark:
30+
```bash
31+
PYTHONPATH=. python mhc_mlx/benchmark.py --mode latency
32+
```
33+
34+
## Auto-Tuning
35+
36+
If you've modified kernels, run the tuner:
37+
```bash
38+
PYTHONPATH=. python scripts/tune.py
39+
```

README.md

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
**High-performance MLX implementation of Manifold-Constrained Hyper-Connections (mHC)** for Apple Silicon.
44

5-
mHC improves training stability and performance in deep architectures by constraining residual connections to the Birkhoff polytope (doubly stochastic matrices). This library provides optimized Metal kernels that achieve massive speedups over standard Python-based implementations.
5+
mHC improves training stability and performance in deep architectures by constraining residual connections to the Birkhoff polytope (doubly stochastic matrices). This library provides optimized Metal kernels that achieve significant speedups over standard baseline implementations.
66

77
**Original Paper:** [mHC: Manifold-Constrained Hyper-Connections](https://arxiv.org/abs/2512.24880) (DeepSeek-AI)
88

@@ -15,7 +15,7 @@ pip install mhc-mlx
1515
## Compatibility
1616
- **Hardware:** Apple Silicon (M1, M2, M3, M4).
1717
- **Software:** macOS, MLX >= 0.30.0.
18-
- **Fallback:** Automatically falls back to a compiled pure-MLX path if Metal kernels are unavailable.
18+
- **Fallback:** Automatically falls back to a compiled pure-MLX path on other platforms.
1919

2020
## Quick Start (30-second Demo)
2121

@@ -25,14 +25,13 @@ import mlx.nn as nn
2525
from mhc_mlx import MHCRewire
2626

2727
# 1. Take any standard MLX layer
28-
layer = nn.Linear(512, 512)
28+
layer = nn.Linear(2048, 2048)
2929

3030
# 2. Wrap it with mHC stability (automatically uses optimized Metal kernels)
31-
# This computes: H_post * (Linear(H_pre * x) + M * H_pre * x)
32-
model = MHCRewire(layer, dims=512, n=16)
31+
model = MHCRewire(layer, dims=2048, n=32)
3332

3433
# 3. Run forward pass
35-
x = mx.random.normal((1, 512))
34+
x = mx.random.normal((1, 2048))
3635
y = model(x)
3736
mx.eval(y)
3837

@@ -41,7 +40,7 @@ loss_fn = lambda m, x: mx.sum(m(x))
4140
grads = mx.grad(loss_fn)(model, x)
4241
mx.eval(grads)
4342

44-
print(f"Output shape: {y.shape}") # (1, 512)
43+
print(f"Output shape: {y.shape}") # (1, 2048)
4544
```
4645

4746
*Note: You can also use `from mlx_mhc import MHCRewire` for a community-friendly alias.*
@@ -52,9 +51,9 @@ print(f"Output shape: {y.shape}") # (1, 512)
5251

5352
### Comparative Benchmarks
5453

55-
Comparison with other standard MLX implementations of mHC ($C=512$):
54+
Comparison with a standard MLX implementation of mHC ($C=512$):
5655

57-
| Metric | mhc-mlx (Ours) | Standard Impl | Speedup |
56+
| Metric | mhc-mlx | Baseline Impl | Speedup |
5857
|---|---|---|---|
5958
| **Inference Latency** ($B=1$) | **392 us** | 1120 us | **2.86x** |
6059
| **Training Throughput** ($B=32$) | **105 us** | 866 us | **8.25x** |
@@ -63,7 +62,7 @@ Comparison with other standard MLX implementations of mHC ($C=512$):
6362

6463
| Approach | Architecture | Impact |
6564
|---|---|---|
66-
| **Standard** | Multiple kernel launches | High memory overhead, low GPU occupancy |
65+
| **Baseline** | Multiple kernel launches | High memory overhead, low GPU occupancy |
6766
| **mhc-mlx** | Fused Metal Kernels | Minimal memory round-trips, maximal bandwidth |
6867

6968
### Reproduce Benchmarks
@@ -86,4 +85,4 @@ mhc-mlx-info
8685
```
8786

8887
## License
89-
MIT
88+
MIT

SECURITY.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Security Policy
2+
3+
## Reporting a Vulnerability
4+
5+
If you discover a security vulnerability within this project, please open an Issue on GitHub or contact the maintainer directly if sensitive.
6+
7+
We aim to address all security concerns promptly. As this is a research-focused implementation, the primary surface is the Metal kernel execution and model weight handling.

mhc_mlx/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,15 @@
33
from .utils import residual_add_agg
44
from .patching import AutoPatcher
55
from importlib.metadata import version, PackageNotFoundError
6+
import platform
7+
import warnings
8+
9+
# Platform check
10+
if platform.system() != "Darwin":
11+
warnings.warn(
12+
"mhc-mlx is optimized for macOS + Apple Silicon (Metal). "
13+
"Falling back to pure-MLX path on this platform."
14+
)
615

716
try:
817
__version__ = version("mhc-mlx")

mhc_mlx/kernels/__init__.py

Whitespace-only changes.

mhc_mlx/layer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import mlx.core as mx
2424
import mlx.nn as nn
25+
import os
2526

2627
from .metal import (
2728
mhc_forward_fused_metal_autograd,
@@ -152,6 +153,8 @@ def mixing_matrix(self) -> mx.array:
152153
return mixing_matrix_from_logits(self.H_res_raw, iters=self.sinkhorn_iters, eps=self.eps)
153154

154155
def _should_use_metal(self, B: int, n: int, C: int) -> bool:
156+
if os.getenv("MHC_MLX_DISABLE_METAL", "0") == "1":
157+
return False
155158
if not self.use_metal:
156159
return False
157160
if not self.auto_dispatch:

mhc_mlx/metal.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from functools import lru_cache
1616

1717
import mlx.core as mx
18+
from importlib import resources
1819

1920
# Config loading
2021
_TUNING_CONFIG = {}
@@ -34,7 +35,13 @@
3435
def _get_tuned_tpg(kernel_name: str, default: int) -> int:
3536
return _TUNING_CONFIG.get(kernel_name, default)
3637

37-
_KERNEL_DIR = os.path.join(os.path.dirname(__file__), "kernels")
38+
# Robust kernel path discovery using importlib.resources
39+
try:
40+
_KERNEL_DIR = str(resources.files("mhc_mlx") / "kernels")
41+
except Exception:
42+
# Fallback for older python or non-package installs
43+
_KERNEL_DIR = os.path.join(os.path.dirname(__file__), "kernels")
44+
3845
_STREAM_MIX_ADD_PATH = os.path.join(_KERNEL_DIR, "stream_mix_add.metal")
3946
_STREAM_MIX_ADD_RMS_PATH = os.path.join(_KERNEL_DIR, "stream_mix_add_rms.metal")
4047
_STREAM_MIX_ADD_RMS_FP16_PATH = os.path.join(_KERNEL_DIR, "stream_mix_add_rms_fp16.metal")

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "mhc-mlx"
7-
version = "0.5.3"
7+
version = "0.5.4"
88
description = "High-performance MLX implementation of Manifold-Constrained Hyper-Connections (mHC)"
99
requires-python = ">=3.10"
1010
readme = "README.md"
@@ -23,7 +23,7 @@ classifiers = [
2323
"Topic :: Scientific/Engineering :: Artificial Intelligence",
2424
]
2525
dependencies = [
26-
"mlx>=0.30.0",
26+
"mlx>=0.30.0; platform_system == 'Darwin'",
2727
]
2828

2929
[project.urls]
@@ -47,7 +47,7 @@ mhc-mlx-info = "mhc_mlx.diagnostics:main"
4747
mhc-mlx-bench = "mhc_mlx.benchmark:main"
4848

4949
[tool.setuptools]
50-
packages = ["mhc_mlx", "mlx_mhc"]
50+
packages = ["mhc_mlx", "mhc_mlx.kernels", "mlx_mhc"]
5151
include-package-data = true
5252

5353
[tool.setuptools.package-data]

tests/test_packaging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def test_kernels_present():
1515
# This works for python 3.10+
1616
try:
1717
kernel_files = [
18-
p.name for p in resources.files("mhc_mlx.kernels").iterdir()
18+
p.name for p in (resources.files("mhc_mlx") / "kernels").iterdir()
1919
if p.name.endswith(".metal")
2020
]
2121
except (AttributeError, TypeError):

0 commit comments

Comments
 (0)