Skip to content

Commit 4f0d1c2

Browse files
add hadamard option to low precision attention api
ghstack-source-id: a1c5c5d | Pull-Request: #4194
1 parent 02105d4 commit 4f0d1c2

File tree

16 files changed

+1289
-59
lines changed

16 files changed

+1289
-59
lines changed

benchmarks/prototype/attention/benchmark_sdpa.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,34 +27,34 @@
2727
from torchao.prototype.attention.fp8_fa3.attention import fp8_fa3_sdpa
2828
from torchao.quantization.utils import compute_error as compute_sqnr
2929

30-
BACKENDS = ["fa2", "fa3", "fa3_fp8"]
30+
BACKENDS = ["fa2", "fa3", "fa3_fp8", "fa3_fp8_hadamard"]
3131

3232
BACKEND_LABELS = {
3333
"fa2": "FA2 BF16",
3434
"fa3": "FA3 BF16",
3535
"fa3_fp8": "FA3 FP8",
36+
"fa3_fp8_hadamard": "FA3 FP8 Hadamard",
3637
}
3738

3839

3940
@contextmanager
4041
def _activate_backend(backend: str):
4142
"""Context manager that activates the appropriate flash attention impl."""
42-
if backend in ("fa3", "fa3_fp8"):
43+
if backend in ("fa3", "fa3_fp8", "fa3_fp8_hadamard"):
4344
activate_flash_attention_impl("FA3")
44-
else:
45-
# fa2 is the default, no activation needed
46-
pass
4745
try:
4846
yield
4947
finally:
50-
if backend in ("fa3", "fa3_fp8"):
48+
if backend in ("fa3", "fa3_fp8", "fa3_fp8_hadamard"):
5149
restore_flash_attention_impl()
5250

5351

5452
def _run_attention(backend: str, q, k, v, is_causal: bool):
5553
"""Run a single attention call for the given backend."""
5654
if backend == "fa3_fp8":
5755
return fp8_fa3_sdpa(q, k, v, is_causal=is_causal)
56+
elif backend == "fa3_fp8_hadamard":
57+
return fp8_fa3_sdpa(q, k, v, is_causal=is_causal, hadamard=True)
5858
else:
5959
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
6060
return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)

benchmarks/prototype/attention/eval_flux_model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
from torchao.prototype.attention import (
3434
AttentionBackend,
35+
HadamardMode,
3536
apply_low_precision_attention,
3637
)
3738

@@ -43,6 +44,12 @@
4344
"fp8": True,
4445
"fp8_backend": AttentionBackend.FP8_FA3,
4546
},
47+
"fa3_fp8_hadamard": {
48+
"flash_impl": "FA3",
49+
"fp8": True,
50+
"fp8_backend": AttentionBackend.FP8_FA3,
51+
"hadamard": HadamardMode.QKV,
52+
},
4653
}
4754

4855
IMAGE_SIZE = (512, 512) # (width, height) - resize for consistent LPIPS
@@ -72,6 +79,7 @@ def setup_backend(
7279
pipe.transformer = apply_low_precision_attention(
7380
pipe.transformer,
7481
backend=cfg["fp8_backend"],
82+
hadamard=cfg.get("hadamard", HadamardMode.NONE),
7583
)
7684
if compile_flag:
7785
print(f"Compiling transformer with torch.compile ({backend_name})...")

benchmarks/prototype/attention/eval_llama3_model.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
from torchao.prototype.attention import (
3333
AttentionBackend,
34+
HadamardMode,
3435
apply_low_precision_attention,
3536
)
3637
from torchao.prototype.attention.shared_utils.fusion_utils import (
@@ -57,6 +58,13 @@
5758
"fp8_backend": AttentionBackend.FP8_FA3,
5859
"label": "FA3 FP8",
5960
},
61+
"fa3_fp8_hadamard": {
62+
"flash_impl": "FA3",
63+
"fp8": True,
64+
"fp8_backend": AttentionBackend.FP8_FA3,
65+
"hadamard": HadamardMode.QKV,
66+
"label": "FA3 FP8 Hadamard",
67+
},
6068
}
6169

6270
RANDOM_SEED = 42
@@ -116,6 +124,7 @@ def setup_backend(orig_model, backend_name, compile_flag):
116124
model = apply_low_precision_attention(
117125
orig_model,
118126
backend=cfg["fp8_backend"],
127+
hadamard=cfg.get("hadamard", HadamardMode.NONE),
119128
)
120129
if compile_flag:
121130
print(f" Compiling model with torch.compile ({backend_name})...")

torchao/prototype/attention/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212

1313
from torchao.prototype.attention.api import (
1414
AttentionBackend,
15+
HadamardMode,
1516
apply_low_precision_attention,
1617
)
1718

1819
__all__ = [
1920
"AttentionBackend",
21+
"HadamardMode",
2022
"apply_low_precision_attention",
2123
]

torchao/prototype/attention/api.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@
2525
)
2626

2727

28+
class HadamardMode(str, Enum):
29+
"""Hadamard transform mode for improved FP8 quantization quality."""
30+
31+
NONE = "NONE" # No Hadamard transform
32+
QKV = "QKV" # Apply Hadamard to Q, K, and V
33+
34+
2835
class AttentionBackend(str, Enum):
2936
"""Backend kernel for computing attention."""
3037

@@ -60,6 +67,7 @@ def _check_backend_available(backend: AttentionBackend) -> None:
6067
def apply_low_precision_attention(
6168
model: nn.Module,
6269
backend: Optional[AttentionBackend] = None,
70+
hadamard: HadamardMode = HadamardMode.NONE,
6371
) -> nn.Module:
6472
"""Apply low-precision attention to a model.
6573
@@ -71,6 +79,15 @@ def apply_low_precision_attention(
7179
for eager execution and sets a global pre-grad pass so that
7280
``torch.compile`` will automatically fuse RoPE where detected.
7381
82+
Args:
83+
model: The model to apply low-precision attention to.
84+
backend: Backend to use. If None, auto-detected.
85+
hadamard: Hadamard transform mode. ``HadamardMode.QKV`` applies
86+
the Hadamard transform to Q, K, and V before FP8 quantization,
87+
spreading outliers across the head dimension for better
88+
dynamic range utilization. Requires D to be a power of 2
89+
and <= 256.
90+
7491
Example:
7592
7693
.. literalinclude:: ../../examples/prototype/low_precision_attention.py
@@ -93,6 +110,6 @@ def apply_low_precision_attention(
93110
_check_backend_available(backend)
94111

95112
if backend == AttentionBackend.FP8_FA3:
96-
return setup_fp8_backend(model, "FA3")
113+
return setup_fp8_backend(model, "FA3", hadamard=str(hadamard))
97114

98115
raise ValueError(f"Unknown backend: {backend}")

torchao/prototype/attention/quantization/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,15 @@
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from torchao.prototype.attention.quantization.triton_hadamard_qkv_quantization import (
8+
triton_fp8_hadamard_sdpa_quantize as _fp8_hadamard_sdpa_quantize,
9+
)
10+
from torchao.prototype.attention.quantization.triton_hadamard_rope_qkv_quantization import (
11+
triton_fp8_hadamard_rope_sdpa_quantize as _fp8_hadamard_rope_sdpa_quantize,
12+
)
13+
from torchao.prototype.attention.quantization.triton_hadamard_utils import (
14+
inverse_hadamard_transform as _inverse_hadamard_transform,
15+
)
716
from torchao.prototype.attention.quantization.triton_qkv_quantization import (
817
triton_fp8_sdpa_quantize as _fp8_sdpa_quantize,
918
)
@@ -14,4 +23,7 @@
1423
__all__ = [
1524
"_fp8_sdpa_quantize",
1625
"_fp8_rope_sdpa_quantize",
26+
"_fp8_hadamard_sdpa_quantize",
27+
"_fp8_hadamard_rope_sdpa_quantize",
28+
"_inverse_hadamard_transform",
1729
]

0 commit comments

Comments (0)