2525 )
2626
2727
class HadamardMode(str, Enum):
    """Selects whether a Hadamard transform precedes FP8 quantization.

    The transform spreads activation outliers across the head dimension,
    which improves dynamic-range utilization when quantizing to FP8.
    """

    # Skip the transform entirely.
    NONE = "NONE"
    # Transform the query, key, and value projections.
    QKV = "QKV"
34+
2835class AttentionBackend (str , Enum ):
2936 """Backend kernel for computing attention."""
3037
@@ -60,6 +67,7 @@ def _check_backend_available(backend: AttentionBackend) -> None:
6067def apply_low_precision_attention (
6168 model : nn .Module ,
6269 backend : Optional [AttentionBackend ] = None ,
70+ hadamard : HadamardMode = HadamardMode .NONE ,
6371) -> nn .Module :
6472 """Apply low-precision attention to a model.
6573
@@ -71,6 +79,15 @@ def apply_low_precision_attention(
7179 for eager execution and sets a global pre-grad pass so that
7280 ``torch.compile`` will automatically fuse RoPE where detected.
7381
82+ Args:
83+ model: The model to apply low-precision attention to.
84+ backend: Backend to use. If None, auto-detected.
85+ hadamard: Hadamard transform mode. ``HadamardMode.QKV`` applies
86+ the Hadamard transform to Q, K, and V before FP8 quantization,
87+ spreading outliers across the head dimension for better
88+ dynamic range utilization. Requires D to be a power of 2
89+ and <= 256.
90+
7491 Example:
7592
7693 .. literalinclude:: ../../examples/prototype/low_precision_attention.py
@@ -93,6 +110,6 @@ def apply_low_precision_attention(
93110 _check_backend_available (backend )
94111
95112 if backend == AttentionBackend .FP8_FA3 :
96- return setup_fp8_backend (model , "FA3" )
113+ return setup_fp8_backend (model , "FA3" , hadamard = str ( hadamard ) )
97114
98115 raise ValueError (f"Unknown backend: { backend } " )
0 commit comments