add hadamard option to low precision attention api #4194

howardzhang-cv wants to merge 5 commits into gh/howardzhang-cv/39/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4194

Note: Links to docs will display an error until the docs builds have been completed. ⏳ 1 Pending, 8 Unrelated Failures. As of commit 3b80af2 with merge base f11eff8:

FLAKY - The following job failed but was likely due to flakiness present on trunk.
BROKEN TRUNK - The following jobs failed but were already failing on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
@claude /pr-review detailed Please do a deep review focused on subtle correctness bugs and regressions that are easy for human reviewers to miss. Prioritize findings over summary. Specifically, look for: incorrect edge-case behavior
Claude finished @drisspg's task in 4m 27s

Deep Correctness Review (PR #4194): reviewed all 16 changed files, focusing on subtle correctness bugs, edge cases, and regressions. Findings:
| Finding | Severity | Type |
|---|---|---|
| `hadamard=True` (bool) in benchmark | Medium | Type bug |
| `tl.debug_barrier()` for production sync | High | Correctness risk |
| Zero test coverage for hadamard paths | High | Testing gap |
| `kv_temp` buffer reuse assumption | Medium | Fragile correctness |
| Global memory for intra-block shuffle | Medium | Performance/correctness |
| Last-chunk loop overhead | Low | Performance |
| `_compute_num_chunks` interface change | Low | Refactoring |
| `str(HadamardMode.QKV)` fragility | Low | API design |
The core Hadamard butterfly math is correct. The main concerns are the reliance on tl.debug_barrier() for production synchronization, the complete absence of tests, and the type mismatch in the benchmark.
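For context on the butterfly math referenced above, here is a reference fast Walsh-Hadamard transform in plain PyTorch. This is an illustrative sketch only; the PR implements the equivalent butterfly pattern inside a Triton kernel, and the function name here is hypothetical.

```python
import torch


def hadamard_butterfly(x: torch.Tensor) -> torch.Tensor:
    """Orthonormal fast Walsh-Hadamard transform over the last dimension.

    Reference sketch only; `hadamard_butterfly` is a hypothetical name,
    not an API from this PR.
    """
    d = x.shape[-1]
    assert d > 0 and d & (d - 1) == 0, "last dim must be a power of two"
    h = 1
    while h < d:
        # pair up blocks of size h and apply the (a + b, a - b) butterfly
        x = x.reshape(*x.shape[:-1], d // (2 * h), 2, h)
        a, b = x[..., 0, :], x[..., 1, :]
        x = torch.stack((a + b, a - b), dim=-2).reshape(*x.shape[:-3], d)
        h *= 2
    # scale by 1/sqrt(d) so the transform is orthonormal (its own inverse)
    return x / d ** 0.5
```

With the 1/sqrt(d) scaling, applying the transform twice returns the input, which is the property the "inverse" tests mentioned below rely on.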
@drisspg Can you take another look? I just addressed some of the Claude review comments.
A lot of the other comments aren't really applicable, in my opinion, or are wrong, but I'm happy to look into it more if you think I should.
```python
sqnr = compute_error(out_ref, out_fp8)
self.assertGreater(
    sqnr.item(),
    25.0,
)
```
should we expect/check sqnr to be higher w/ hadamard?
The tests here were more just making sure Hadamard was calculated correctly (with the inverse and so on). Because the test input is uniformly sampled, Hadamard doesn't actually improve SQNR; it only helps when there are outliers. I just added a separate set of tests to check that Hadamard improves accuracy when the inputs have outliers.
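The outlier effect described above can be reproduced with a toy experiment. This sketch uses a crude int8-style per-tensor quantizer as a stand-in for the real fp8 attention path; all names here are illustrative, not APIs from this PR.

```python
import torch

torch.manual_seed(0)


def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # signal-to-quantization-noise ratio in dB
    return 10 * torch.log10(ref.pow(2).mean() / (ref - approx).pow(2).mean())


def fake_quant(x: torch.Tensor) -> torch.Tensor:
    # toy symmetric quantizer with a single per-tensor scale
    scale = x.abs().max() / 127.0
    return (x / scale).round().clamp(-127, 127) * scale


d = 64
H = torch.ones(1, 1)
while H.shape[0] < d:  # Sylvester construction of a Hadamard matrix
    H = torch.cat([torch.cat([H, H], 1), torch.cat([H, -H], 1)], 0)
H = H / d ** 0.5  # orthonormal and symmetric, so H is its own inverse

x = torch.randn(1024, d)
x[:, 0] *= 100.0  # one outlier channel dominates the per-tensor scale

plain = sqnr_db(x, fake_quant(x))
rotated = sqnr_db(x, fake_quant(x @ H) @ H)  # rotate, quantize, rotate back
```

The rotation spreads the outlier channel's energy across all channels, shrinking the per-tensor scale, so `rotated` comes out well above `plain`; with uniformly sampled inputs and no outlier channel the two are close, which matches the comment above.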
```diff
 try:
     with torch.no_grad():
-        out_fp8 = fp8_fa3_rope_sdpa(q, k, v, cos, sin, is_causal=False)
+        out_fp8 = fp8_fa3_rope_sdpa(
```
Stack from ghstack (oldest at bottom):
Summary
Results
Single Attention Layer
LLaMA3 Prefill