Skip to content

Commit c92db3c

Browse files
authored
feat: Enable Flash Attention for ONNX bindings (#1431)
* feat: add onnx rocm flash attention support Signed-off-by: Huamin Chen <hchen@redhat.com> * update paper Signed-off-by: Huamin Chen <hchen@redhat.com> * add bench and results Signed-off-by: Huamin Chen <hchen@redhat.com> * lint Signed-off-by: Huamin Chen <hchen@redhat.com> * lint Signed-off-by: Huamin Chen <hchen@redhat.com> --------- Signed-off-by: Huamin Chen <hchen@redhat.com>
1 parent 23578c8 commit c92db3c

File tree

27 files changed

+3471
-88
lines changed

27 files changed

+3471
-88
lines changed

.gitattributes

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,6 @@
1010
*.ico binary
1111
*.svg binary
1212
*.pdf binary
13+
14+
# Prebuilt CK Flash Attention shared library (ROCm/gfx942)
15+
onnx-binding/ort-ck-flash-attn/prebuilt/libort_ck_flash_attn.so* binary

.github/workflows/pre-commit.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ jobs:
3636
- name: Set up Node
3737
uses: actions/setup-node@v5
3838
with:
39-
node-version: 23
39+
node-version: 20
4040

4141
- name: Set up Rust
4242
uses: dtolnay/rust-toolchain@stable
@@ -52,7 +52,7 @@ jobs:
5252
build-essential \
5353
pkg-config \
5454
shellcheck
55-
npm install -g markdownlint-cli
55+
npm install -g markdownlint-cli@0.43.0
5656
pip install --user yamllint codespell
5757
5858
- name: Set up golangci-lint

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ build/
3131
*.out
3232
extproc-server
3333

34+
# Allow prebuilt CK Flash Attention shared library
35+
!onnx-binding/ort-ck-flash-attn/prebuilt/libort_ck_flash_attn.so*
36+
3437
# IDE
3538
.idea/
3639
.vscode/

bench/cpu-vs-gpu/README.md

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# CPU vs GPU / SDPA vs FA Benchmarks
2+
3+
Measures signal extraction latency (jailbreak, PII, domain) for ONNX Runtime on AMD ROCm GPUs via Envoy ext_proc, using Prometheus histograms.
4+
5+
## Prerequisites
6+
7+
- AMD GPU with ROCm 7.0+ (`/dev/kfd`, `/dev/dri`)
8+
- Docker
9+
- `envoyproxy/envoy:v1.33-latest` image
10+
11+
## Setup
12+
13+
Build the router image (includes CK Flash Attention custom op):
14+
15+
```bash
16+
docker build -f tools/docker/Dockerfile.extproc-rocm -t semantic-router:rocm .
17+
```
18+
19+
Download models into `bench/cpu-vs-gpu/models/`:
20+
21+
```bash
22+
pip install huggingface_hub
23+
python3 -c "
24+
from huggingface_hub import snapshot_download
25+
for repo in [
26+
'mmbert32k-intent-classifier-merged',
27+
'mmbert32k-jailbreak-detector-merged',
28+
'mmbert32k-pii-detector-merged',
29+
]:
30+
snapshot_download(
31+
f'llm-semantic-router/{repo}',
32+
local_dir=f'bench/cpu-vs-gpu/models/{repo}-onnx',
33+
allow_patterns=['onnx/*', '*.json'],
34+
ignore_patterns=['*.safetensors', '*.bin', '*.pt'],
35+
)
36+
"
37+
```
38+
39+
Each model dir needs `model_sdpa_fp16.onnx` (for SDPA/CPU) and `model_fa_fp16.onnx` (for FA). Generate FA models with `onnx-binding/ort-ck-flash-attn/scripts/rewrite_graph.py` if not already present.
40+
41+
## Benchmarks
42+
43+
**CPU vs GPU** — compares ONNX CPU vs ROCm GPU across 500/2K/8K/16K token prompts:
44+
45+
```bash
46+
BENCH_IMAGE=semantic-router:rocm REQUESTS_PER_SIZE=10 ./bench-long-context.sh
47+
```
48+
49+
**SDPA vs Flash Attention** — compares standard attention vs CK Flash Attention on GPU:
50+
51+
```bash
52+
BENCH_IMAGE=semantic-router:rocm NUM_REQUESTS=20 ./bench-sdpa-vs-fa.sh
53+
```
54+
55+
Reports are written to `results/`.
56+
57+
## Scripts
58+
59+
| File | Description |
60+
|------|-------------|
61+
| `bench-long-context.sh` | CPU vs GPU, multi token-size, Prometheus metrics |
62+
| `bench-sdpa-vs-fa.sh` | SDPA vs FA on GPU, Prometheus metrics |
63+
| `config-bench.yaml` | Router config template (`USE_CPU_PLACEHOLDER` sed-replaced) |
64+
| `envoy-bench.yaml` | Envoy ext_proc proxy config |

0 commit comments

Comments
 (0)