
add torch.compile test for Float8BlockwiseLinear #4187

Merged
iamzainhuda merged 8 commits into main from torch-compile-kernel-tests on Mar 31, 2026

Conversation

@iamzainhuda
Contributor

@iamzainhuda iamzainhuda commented Mar 26, 2026

Summary

  • Add torch.compile testing to Float8BlockwiseLinear forward and backward passes, verifying fullgraph compilation, numerical correctness, and no recompilation across repeated calls with the same shapes
  • Refactor the existing test_blockwise_quant_linear_fwd_bwd into a shared _run_blockwise_quant_linear_fwd_bwd helper that supports both eager and compiled execution paths
  • Fix the backward pass assertions to correctly compare gradients (input grad vs. input grad, weight grad vs. weight grad)

Testing

pytest test/prototype/blockwise_fp8_training/test_blockwise_linear.py
  • Eager mode (test_blockwise_quant_linear_fwd_bwd): Parametrized across in_features=[4096], out_features=[128256], batch_size=[1, 8], block_size=[128] — validates forward SQNR >= 25.0 and backward grad SQNR >= 30.0 against a reference nn.Linear
  • Compile mode (test_blockwise_quant_linear_compile_fullgraph_fwd_bwd): Runs with torch.compile(fullgraph=True) using CompileCounterWithBackend("inductor") to assert:
    • The model traces into a single compiled frame (no graph breaks)
    • No recompilation on a second forward/backward call with the same input shapes
    • Numerical correctness matches the eager reference (same SQNR thresholds)
  • Both paths check for NaN-free outputs and gradients
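For context on the thresholds above, SQNR is the ratio of reference signal power to quantization-error power, in dB. A minimal, dependency-free sketch of the metric (this is an illustrative re-derivation, not the tensor-based helper torchao's tests actually use):

```python
import math

def sqnr_db(reference, actual):
    """Signal-to-quantization-noise ratio in dB between two equal-length
    sequences: 10 * log10(signal_power / error_power)."""
    signal_power = sum(r * r for r in reference)
    error_power = sum((r - a) ** 2 for r, a in zip(reference, actual))
    if error_power == 0:
        return math.inf  # exact match
    return 10.0 * math.log10(signal_power / error_power)

# A tiny perturbation yields a high SQNR, so thresholds like >= 25 dB
# tolerate small quantization error while catching real numerical bugs.
ref = [1.0, -2.0, 3.0, -4.0]
approx = [1.001, -1.999, 3.002, -4.001]
print(sqnr_db(ref, approx) > 25.0)  # prints True
```

Higher SQNR means the quantized result is closer to the reference; the backward threshold (30 dB) is stricter than the forward one (25 dB).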

@pytorch-bot

pytorch-bot bot commented Mar 26, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4187

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (8 Unrelated Failures)

As of commit 3292946 with merge base f11eff8:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 26, 2026
@iamzainhuda iamzainhuda added the module: training quantize_ api training flow label Mar 26, 2026
@danielvegamyhre
Copy link
Copy Markdown
Contributor

@iamzainhuda I think there was a miscommunication: we don't want to directly wrap an individual Triton custom op in torch.compile and test it. We want to compile a full blockwise linear layer (this test), ensure there are no graph breaks (fullgraph=True), and verify that the numerics of outputs/grads pass the same threshold testing as the eager mode tests.

@iamzainhuda iamzainhuda force-pushed the torch-compile-kernel-tests branch from 6457882 to 8548e82 Compare March 30, 2026 20:19
@iamzainhuda
Contributor Author

> @iamzainhuda i think there was a miscommunication, we don't want to directly wrap an individual triton custom op in torch.compile and test - we want to compile a full blockwise linear layer (this test) and ensure there are no graph breaks (fullgraph=True) and that numerics of outputs/grads undergo the pass with the same threshold testing as the eager mode tests.

That makes sense (especially w.r.t. understanding graph breaks); updated the PR. Added a compile test in blockwise_linear with a full layer, using a smaller shape for compile-time's sake. Verified full-graph compilation, numerical correctness, and no erroneous recompilation across multiple calls.

@iamzainhuda iamzainhuda changed the title from "add torch.compile to blockwise quantized kernel unit tests" to "add torch.compile test for Float8BlockwiseLinear" on Mar 30, 2026
```python
    out_features,
    batch_size,
    block_size,
    compile_mode=True,
```
Contributor


Can we just set this "use compile" bool as a pytest.mark.parametrize parameter, and combine the eager and compile unit tests into one test (rather than duplicating just to change this bool)? That keeps things simple and consistent with the usual pattern in other torchao tests.

or is there any reason these are separated now?

Contributor Author


I had it separated because 1) the compile test uses a smaller shape for compile-time/runtime cost, and 2) it's easier to reason about failures when a "test_linear_compile" is failing vs. "test_linear_fwd_bwd", since I felt they test reasonably different things; so it'd make less sense to have them both as "test_linear_fwd_bwd". Happy to move them together if you'd like!

Contributor


ok sure, sgtm. I think the compile time thing is a bit overkill for a unit test unless it's really long, but I agree the test failures will be more readable at a glance: with a parameterized use_compile bool, the failed config just shows up as something like test_this_thing[4096-4096-1-128-true], which gets confusing, especially if we add more bool test params later.

```python
if compile_mode:
    with torch._dynamo.config.patch(trace_autograd_ops=True):
        torch._dynamo.reset()
        compiled_frame_counter = CompileCounterWithBackend("inductor")
```
Contributor


this is cool, hadn't used this before

@iamzainhuda iamzainhuda merged commit 3ad1067 into main Mar 31, 2026
15 of 23 checks passed
