
add hadamard option to low precision attention api #4194

Open
howardzhang-cv wants to merge 5 commits into gh/howardzhang-cv/39/base from gh/howardzhang-cv/39/head

Conversation

@howardzhang-cv
Contributor

@howardzhang-cv howardzhang-cv commented Mar 27, 2026

Stack from ghstack (oldest at bottom):

Summary

  • Added support for applying a Hadamard transform to the QKV tensor in the low precision attention API, passed through from apply_low_precision_attention to select the fused Hadamard kernel
  • Added new Triton kernels (triton_hadamard_qkv_quantization.py and triton_hadamard_rope_qkv_quantization.py) for fused Hadamard and QKV quantization (with RoPE fusion as well)
    • Because of the way per-head quantization works with the sequence chunks, we need to store the Hadamard outputs in a temp buffer, which eliminates some of the benefit of the fusion. However, it still saves one global read of the QKV tensors, and experiments show this still improves runtime considerably, so the fusion remains much better than running Hadamard separately.
  • Added hadamard option to the benchmarks
  • Replaced some duplicate compute_num_chunks code

Results

Single Attention Layer

[image] For a single attention layer, performance got slower (from a 1.36x speedup to a 1.29x speedup at the highest sequence length). The SQNR does not really improve because we're testing with random tensors, which are fairly uniform already (Hadamard is intended to spread out outlier intensity for better quantization accuracy).

LLaMA3 Prefill

[image] Perplexity used to go from 7.54 -> 7.62. Now it is noticeably better, going from 7.54 -> 7.57. The speedup dropped from 1.23x to 1.15x at the highest sequence length.

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Mar 27, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4194

Note: Links to docs will display an error until the docs builds have been completed.

⏳ 1 Pending, 8 Unrelated Failures

As of commit 3b80af2 with merge base f11eff8:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Mar 27, 2026
howardzhang-cv added a commit that referenced this pull request Mar 27, 2026
howardzhang-cv added the module: inference (quantize_ api inference flow) label on Mar 27, 2026
howardzhang-cv added a commit that referenced this pull request Mar 27, 2026
@drisspg
Contributor

drisspg commented Mar 30, 2026

@claude /pr-review detailed

Please do a deep review focused on subtle correctness bugs and regressions that are easy for human reviewers to miss. Prioritize findings over summary.

Specifically, look for:

incorrect edge-case behavior
missing or weak test coverage
backward-compatibility risks
silent behavioral changes
dtype/device/shape/stride/autograd/dispatch issues
aliasing, mutation, lifetime, and ownership bugs
concurrency or thread-safety problems
error-path bugs and partial-state failures
performance traps that come from correctness fixes or hidden extra work
Ignore style, formatting, and nits unless they hide a real bug. Be skeptical of code that “looks fine” and focus on failure modes, invariants, and assumptions not enforced by tests.

@claude

claude bot commented Mar 30, 2026

Claude finished @drisspg's task in 0s — View job


I'll analyze this and get back to you.

@drisspg
Contributor

drisspg commented Mar 30, 2026

@claude
Please do a deep review focused on subtle correctness bugs and regressions that are easy for human reviewers to miss. Prioritize findings over summary.

Specifically, look for:

incorrect edge-case behavior
missing or weak test coverage
backward-compatibility risks
silent behavioral changes
dtype/device/shape/stride/autograd/dispatch issues
aliasing, mutation, lifetime, and ownership bugs
concurrency or thread-safety problems
error-path bugs and partial-state failures
performance traps that come from correctness fixes or hidden extra work
Ignore style, formatting, and nits unless they hide a real bug. Be skeptical of code that “looks fine” and focus on failure modes, invariants, and assumptions not enforced by tests.

@claude

claude bot commented Mar 30, 2026

Claude finished @drisspg's task in 4m 27s — View job


Deep Correctness Review — PR #4194

Reviewed all 16 changed files focusing on subtle correctness bugs, edge cases, and regressions.


Bug: hadamard=True (bool) passed where hadamard: str expected

Severity: Medium — silent wrong-type argument, works by accident

In benchmarks/prototype/attention/benchmark_sdpa.py:57:

return fp8_fa3_sdpa(q, k, v, is_causal=is_causal, hadamard=True)

The _fp8_sdpa function signature is hadamard: str = "NONE" and checks hadamard != "NONE". This accidentally works because True != "NONE" is True, so the hadamard path is taken. But the hadamard parameter is typed as str everywhere — in the custom op registration (custom_ops.py:88), the wrapper (wrapper.py:57), and the API enum. Passing True (a bool) through torch.library.custom_op which declares hadamard: str will likely fail under torch.compile because the custom op schema expects a string.

Should be: hadamard="QKV".
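A minimal sketch illustrates why the bool slips through the inequality gate (fp8_sdpa_stub below is a hypothetical stand-in for _fp8_sdpa, not the real function):

```python
def fp8_sdpa_stub(q, k, v, is_causal=False, hadamard: str = "NONE"):
    # Mirrors the gate in _fp8_sdpa: anything != "NONE" enables the
    # hadamard path, including the (wrongly typed) bool True.
    return hadamard != "NONE"

assert fp8_sdpa_stub(None, None, None, hadamard=True)       # bool accidentally works
assert fp8_sdpa_stub(None, None, None, hadamard="QKV")      # intended spelling
assert not fp8_sdpa_stub(None, None, None, hadamard="NONE")
```

Eager mode never type-checks the annotation, which is why the benchmark runs; only the custom-op schema under torch.compile would reject the bool.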



Bug: kv_temp buffer reuse between K and V kernel launches is unsafe

Severity: High — potential data corruption on some GPUs / stream configurations

In triton_hadamard_qkv_quantization.py:290-291:

# kv_temp can be reused: K's phase1 completed before V's launches
hadamard_single_phase1_kernel[kv_grid](v, v_had, kv_temp, ...)

And identically in triton_hadamard_rope_qkv_quantization.py:458-462.

The comment claims K's phase1 completes before V's launches, but Triton kernel launches are asynchronous on the CUDA default stream. Two consecutive Triton kernel launches on the same stream will be serialized, so this is safe on the default stream. However, if these functions are ever called under a CUDA stream context or with graph capture, the assumption breaks. The same kv_temp buffer is being read/written by two different grid launches with the same dimensions.

This is safe in the current code path but fragile. A comment noting "safe only because both launches are on the same CUDA stream and Triton launches serialize" would be appropriate. Alternatively, using separate temp buffers (at the cost of memory) would eliminate the concern.


Issue: tl.debug_barrier() used for correctness-critical synchronization

Severity: High — debug_barrier may be stripped in release/optimized Triton builds

In triton_hadamard_utils.py:82-84:

tl.store(temp_ptr + temp_base + d_idx, x)
tl.debug_barrier()
x_partner = tl.load(temp_ptr + temp_base + partner_d)
tl.debug_barrier()

The Hadamard butterfly relies on tl.debug_barrier() to synchronize threads within a block before cross-thread loads. tl.debug_barrier() is explicitly a debug tool — its documentation states it may be removed or behave differently in optimized builds, and Triton does not expose a general-purpose block barrier in its public API. The safer approach is to arrange the computation so it does not need cross-thread communication via global memory.

If Triton ever optimizes away tl.debug_barrier(), the Hadamard butterfly will silently produce incorrect results (reading stale data from the temp buffer). This is a correctness-critical data dependency, not a debugging aid.

Recommendation: Verify whether tl.debug_barrier() is guaranteed to work as a synchronization primitive in production Triton, or switch to a shared-memory + proper barrier approach.


Issue: Hadamard butterfly uses global memory temp buffer instead of shared memory

Severity: Medium — performance and correctness coupling

The _hadamard_butterfly_stage function uses a global memory buffer (temp_ptr) for inter-thread communication. Each stage of the butterfly requires D threads to store, synchronize, load partner values, and synchronize again. This means 2 * LOG2_D global memory round-trips per element.

For D=128 (LOG2_D=7), that's 14 global memory barriers per element. Shared memory with __syncthreads would be far more appropriate and is the standard approach for intra-block communication in GPU kernels.

Beyond performance, the global memory approach creates a correctness coupling: the temp buffer must be uniquely indexed per (batch, head, chunk) triple to avoid aliasing between blocks. The current code handles this via temp_base indexing, but it's more error-prone than shared memory.
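For reference, the butterfly recursion itself can be checked in pure Python, independent of the Triton kernels (hadamard_transform below is my own sketch, not code from the PR; it normalizes by 1/sqrt(D) so the transform is its own inverse):

```python
import math

def hadamard_transform(x):
    """Walsh-Hadamard transform via log2(D) butterfly stages.

    Each stage pairs elements at distance h and replaces them with their
    sum and difference; the 1/sqrt(D) scale makes the transform self-inverse.
    """
    d = len(x)
    assert d & (d - 1) == 0, "length must be a power of two"
    y = list(x)
    h = 1
    while h < d:
        for i in range(0, d, h * 2):
            for j in range(i, i + h):
                a, b = y[j], y[j + h]
                y[j], y[j + h] = a + b, a - b
        h *= 2
    scale = 1.0 / math.sqrt(d)
    return [v * scale for v in y]

# Self-inverse property: applying the transform twice recovers the input.
x = [1.0, 2.0, 3.0, 4.0]
once = hadamard_transform(x)            # [5.0, -1.0, -2.0, 0.0]
twice = hadamard_transform(once)
assert all(abs(a - b) < 1e-9 for a, b in zip(twice, x))
```

Each GPU thread in the kernel plays the role of one `j` index per stage, which is exactly where the cross-thread partner exchange (and hence the barrier) comes from.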


Issue: chunk_size loop iterates over full chunk_size even for the last chunk

Severity: Low — masked correctly but wasteful

In all phase1 kernels (e.g., hadamard_single_phase1_kernel, line 98):

for s_offset in range(chunk_size):
    s_idx = s_start + s_offset
    s_mask = s_idx < S

The loop always runs chunk_size iterations. For the last chunk, many iterations may have s_mask=False. The Hadamard butterfly still runs (including all barrier + global memory store/load stages), just with zero input. For the Hadamard path with per-element looping (not BLOCK_SIZE chunked), this means the last chunk can waste significant compute on the butterfly transform of zeros.

The existing non-hadamard kernels use chunk_elements = (s_end - s_start) * D to avoid this. The hadamard kernels can't easily do this because of the D-threaded structure, but it's worth noting the overhead.


Missing test coverage

Severity: High — no tests for any new functionality

There are zero tests for the new hadamard code paths:

  1. No unit tests for triton_fp8_hadamard_sdpa_quantize — the Hadamard butterfly correctness, quantization accuracy, and GQA handling are untested
  2. No unit tests for triton_fp8_hadamard_rope_sdpa_quantize — RoPE + Hadamard fusion is untested
  3. No unit tests for inverse_hadamard_transform — the self-inverse property is not verified
  4. No end-to-end tests with HadamardMode.QKV — the API integration path through apply_low_precision_attention(..., hadamard=HadamardMode.QKV) is untested
  5. No tests for the compile path: torch.compile with hadamard through the custom op is untested

At minimum, the existing test_sdpa_accuracy and test_rope_sdpa_accuracy should be parameterized with hadamard=True/False, and the test_monkey_patch_model should test HadamardMode.QKV.


_compute_num_chunks refactoring — subtle behavioral change

Severity: Low — correctness maintained but chunking may change

The old _compute_num_chunks in triton_rope_qkv_quantization.py extracted B, _, H, _ = tensor.shape from [B, S, H, D] layout. The new unified version in triton_hadamard_utils.py:37 takes (device, B, H, S) scalars. The callers in triton_rope_qkv_quantization.py:559 now pass _compute_num_chunks(q.device, B, H_q, S).

This is functionally equivalent, but the interface change means any external caller of the old _compute_num_chunks (unlikely but not impossible) would break. Since this is prototype code, this is acceptable.


api.py passes str(hadamard) — correct but fragile

Severity: Low

In api.py:113:

return setup_fp8_backend(model, "FA3", hadamard=str(hadamard))

Since HadamardMode is declared as class HadamardMode(str, Enum), the member itself compares equal to "QKV" (string comparison uses the value), but str(HadamardMode.QKV) produces the qualified name "HadamardMode.QKV", not "QKV", because str() goes through Enum.__str__. The current gate hadamard != "NONE" therefore still activates Hadamard, since "HadamardMode.QKV" != "NONE" is True.

This works but is fragile: if someone later adds a check like hadamard == "QKV" instead of hadamard != "NONE", it would silently break. Using hadamard.value, or just passing the enum directly, would be cleaner.
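A self-contained sketch (my own minimal recreation of the enum, assuming only the "NONE"/"QKV" values) demonstrates the mixin behavior:

```python
from enum import Enum

class HadamardMode(str, Enum):  # hypothetical recreation of the API's enum
    NONE = "NONE"
    QKV = "QKV"

# The str mixin means the member IS a str and compares equal to its value...
assert HadamardMode.QKV == "QKV"
assert HadamardMode.QKV.value == "QKV"   # .value is the unambiguous spelling
# ...but str() yields the qualified member name ("HadamardMode.QKV" on
# current CPython), so the `!= "NONE"` gate passes only by inequality.
assert str(HadamardMode.QKV) != "NONE"
```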


Summary

| Finding | Severity | Type |
| --- | --- | --- |
| hadamard=True (bool) in benchmark | Medium | Type bug |
| tl.debug_barrier() for production sync | High | Correctness risk |
| Zero test coverage for hadamard paths | High | Testing gap |
| kv_temp buffer reuse assumption | Medium | Fragile correctness |
| Global memory for intra-block shuffle | Medium | Performance/correctness |
| Last-chunk loop overhead | Low | Performance |
| _compute_num_chunks interface change | Low | Refactoring |
| str(HadamardMode.QKV) fragility | Low | API design |

The core Hadamard butterfly math is correct. The main concerns are the reliance on tl.debug_barrier() for production synchronization, the complete absence of tests, and the type mismatch in the benchmark.


howardzhang-cv added a commit that referenced this pull request Mar 30, 2026
@howardzhang-cv howardzhang-cv requested a review from drisspg March 30, 2026 22:00
@howardzhang-cv
Contributor Author

@drisspg Can you take another look? Just addressed some of the claude review comments:

  1. I fixed the benchmark_sdpa bug.
  2. Added a comment for kv_temp. I think it's better to keep it the same for now, considering this kernel is only called internally and will always be launched on a single stream in a way that avoids conflicts, but I did add a comment acknowledging the potential issue.
  3. Added some tests for hadamard, was going to do this in a separate PR after implementing Hadamard mode for V-only, but doing it here makes sense too.
  4. Changed from str(hadamard) to hadamard.value to avoid the Python 3.11 str() issue.

A lot of the other comments aren't really applicable in my opinion, or are wrong, but I'm happy to look into them more if you think I should.

howardzhang-cv added a commit that referenced this pull request Mar 30, 2026
sqnr = compute_error(out_ref, out_fp8)
self.assertGreater(
    sqnr.item(),
    25.0,
Contributor


should we expect/check sqnr to be higher w/ hadamard?

Contributor Author


The tests here were more just making sure Hadamard was calculated correctly (with inverse and stuff). Because it's uniformly sampled, Hadamard doesn't actually improve SQNR. It only improves when there's outliers. I just added a separate set of tests to check that Hadamard improves accuracy when the inputs have outliers.

try:
    with torch.no_grad():
        out_fp8 = fp8_fa3_rope_sdpa(q, k, v, cos, sin, is_causal=False)
        out_fp8 = fp8_fa3_rope_sdpa(
Contributor


ditto

Contributor Author


see above

howardzhang-cv added a commit that referenced this pull request Mar 31, 2026