add torch.compile test for Float8BlockwiseLinear #4187
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4187
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (8 Unrelated Failures) As of commit 3292946 with merge base f11eff8 (BROKEN TRUNK: the following jobs failed but were present on the merge base). 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@iamzainhuda I think there was a miscommunication: we don't want to directly wrap an individual Triton custom op in torch.compile and test that. We want to compile a full blockwise linear layer (this test), ensure there are no graph breaks (fullgraph=True), and check that the numerics of outputs/grads pass the same threshold testing as the eager-mode tests.
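The pattern being requested might look like the following sketch: compile a whole layer with `fullgraph=True` and compare outputs and gradients against eager execution. Here `nn.Linear` stands in for `Float8BlockwiseLinear`, and the `sqnr` helper, threshold, and shapes are illustrative assumptions, not the PR's actual code.

```python
import torch
import torch.nn as nn

def sqnr(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    """Signal-to-quantization-noise ratio in dB (hypothetical helper)."""
    return 10 * torch.log10(ref.pow(2).mean() / (ref - actual).pow(2).mean())

torch.manual_seed(0)
layer = nn.Linear(64, 32)  # stand-in for a blockwise-quantized linear layer
x = torch.randn(8, 64, requires_grad=True)

# Eager reference pass
ref_out = layer(x)
ref_out.sum().backward()
ref_grad = x.grad.detach().clone()
x.grad = None

# fullgraph=True raises on any graph break instead of silently splitting.
# backend="eager" keeps this sketch cheap; the real test would use inductor.
compiled = torch.compile(layer, backend="eager", fullgraph=True)
out = compiled(x)
out.sum().backward()

assert sqnr(ref_out, out) >= 25.0
assert sqnr(ref_grad, x.grad) >= 25.0
```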
Force-pushed from 6457882 to 8548e82
That makes sense (especially w.r.t. understanding graph breaks); updated the PR. Added a compile test in blockwise_linear with a full layer, using a smaller shape for compile-time's sake. Verified full-graph compilation, numerical correctness, and no erroneous recompilation across multiple calls.
    out_features,
    batch_size,
    block_size,
    compile_mode=True,
Can we just set this "use compile" bool as a pytest.mark.parametrize parameter and combine the eager and compile unit tests into one test (rather than duplicating just to change this bool)? That keeps things simple and is consistent with the usual pattern in other torchao tests.
Or is there any reason these are separated now?
I had them separated because 1) the compile test uses a smaller shape to keep compile-time/runtime cost down, and 2) it's easier to reason about failures when a "test_linear_compile" is failing vs. a "test_linear_fwd_bwd", as I felt they were testing reasonably different things, so it would make less sense to have them both under "test_linear_fwd_bwd". Happy to move them together if you'd like!
OK sure, sgtm. I think the compile-time concern is a bit overkill for a unit test unless it's really long, but I agree the test failures will be more readable at a glance: with a parametrized use_compile bool, the failed config just shows up as something like test_this_thing[4096-4096-1-128-true], which gets confusing, especially if we add more bool test params later.
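For illustration, the parametrized shape being discussed might look like the sketch below. The test name, shapes, SQNR threshold, and the use of `nn.Linear` as a stand-in for the blockwise-quantized layer are all hypothetical.

```python
import pytest
import torch
import torch.nn as nn

def _sqnr(ref, actual):
    # Hypothetical SQNR helper in dB
    return 10 * torch.log10(ref.pow(2).mean() / (ref - actual).pow(2).mean())

@pytest.mark.parametrize("batch_size", [1, 8])
@pytest.mark.parametrize("use_compile", [False, True])
def test_linear_fwd_bwd(batch_size, use_compile):
    torch.manual_seed(0)
    ref = nn.Linear(64, 32)
    model = ref  # a real test would convert this to a blockwise-quantized layer
    if use_compile:
        # backend="eager" keeps the sketch cheap; the real test uses inductor
        model = torch.compile(model, backend="eager", fullgraph=True)

    x = torch.randn(batch_size, 64, requires_grad=True)
    out = model(x)
    out.sum().backward()

    ref_x = x.detach().clone().requires_grad_()
    ref_out = ref(ref_x)
    ref_out.sum().backward()

    assert _sqnr(ref_out, out) >= 25.0
    assert _sqnr(ref_x.grad, x.grad) >= 25.0
```

As noted above, the trade-off is that a failing config reports as e.g. `test_linear_fwd_bwd[True-8]` rather than a separately named compile test.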
    if compile_mode:
        with torch._dynamo.config.patch(trace_autograd_ops=True):
            torch._dynamo.reset()
            compiled_frame_counter = CompileCounterWithBackend("inductor")
this is cool, hadn't used this before
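For reference, `CompileCounterWithBackend` (from `torch._dynamo.testing`) can be passed to `torch.compile` as the backend; it forwards to the wrapped backend and counts compiled frames, which makes the no-recompilation assertion straightforward. A minimal sketch, wrapping the eager backend instead of inductor to keep it cheap:

```python
import torch
from torch._dynamo.testing import CompileCounterWithBackend

torch._dynamo.reset()
counter = CompileCounterWithBackend("eager")  # the PR's test wraps "inductor"

def fn(x):
    return torch.relu(x) + 1

compiled = torch.compile(fn, backend=counter, fullgraph=True)

x = torch.randn(4, 4)
for _ in range(3):  # repeated calls with the same input shape...
    compiled(x)

# ...should hit the compiled-code cache, so exactly one frame is compiled
assert counter.frame_count == 1
```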
Summary
- Adds `torch.compile` testing to `Float8BlockwiseLinear` forward and backward passes, verifying fullgraph compilation, numerical correctness, and no recompilation across repeated calls with the same shapes
- Refactors `test_blockwise_quant_linear_fwd_bwd` into a shared `_run_blockwise_quant_linear_fwd_bwd` helper that supports both eager and compiled execution paths

Testing
- Eager test (`test_blockwise_quant_linear_fwd_bwd`): parametrized across `in_features=[4096]`, `out_features=[128256]`, `batch_size=[1, 8]`, `block_size=[128]`; validates forward SQNR >= 25.0 and backward grad SQNR >= 30.0 against a reference `nn.Linear`
- Compile test (`test_blockwise_quant_linear_compile_fullgraph_fwd_bwd`): runs with `torch.compile(fullgraph=True)` using `CompileCounterWithBackend("inductor")` to assert: