Enables the per_tensor lowering patterns for weight per_packing #2391
choudhary-devang wants to merge 3 commits into pytorch:main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2391
Note: Links to docs will display an error until the docs builds have been completed.
❌ 9 New Failures as of commit bd15048 with merge base ce07646. The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from c698531 to 67d4a79.
Hi @jerryzh168, @fadara01, @Xia-Weiwen, can you please review this PR?
Thanks, can you add some tests in https://github.com/pytorch/ao/tree/main/test/quantization/pt2e?
Force-pushed from 67d4a79 to d863085.
Hi @jerryzh168,
Force-pushed from 2caf61d to e51e9ec.
Thanks for your PR!
Hi @fadara01, thanks for the response. To recreate the experiment: quant script / current setup.
Ahhh that's amazing! I remember doing a PoC for this exact thing back in the day and I had to tweak qlinear/qconv, hence my question. |
Hi @jerryzh168, @fadara01, can you please approve and merge this change?
@pytorchbot rebase

@pytorchbot rebase
Force-pushed from b5a6358 to ab75a9b.
Hi @jerryzh168, @fadara01, can you please approve and merge this change?
Force-pushed from ab75a9b to ad1ff8d.
Hi @jerryzh168, @fadara01, can you please approve and merge this change?
Review comment on the diff:

```python
    X86InductorQuantizer,
)

if TORCH_VERSION_AT_LEAST_2_7:
```

this is deprecated btw, please use
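The suggestion is cut off above, but the updated code later in this thread gates on `torch_version_at_least("2.8.0")`, so the deprecated `TORCH_VERSION_AT_LEAST_2_7` constant is presumably meant to be replaced by the callable form. A minimal sketch (the `torchao.utils` import path is an assumption):

```python
# Assumed import path for the version-check helper used later in this thread.
from torchao.utils import torch_version_at_least

if torch_version_at_least("2.7.0"):
    ...  # register the version-gated passes here
```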
Review comment on the diff:

```python
)

if TORCH_VERSION_AT_LEAST_2_7:
    torch._inductor.config.pre_grad_custom_pass = quant_lift_up
```

what happens when multiple backends set this one?
> what happens when multiple backends set this one?

Previously the last writer won, so we now chain passes instead of overwriting them; multiple backends can safely coexist.

Details:

Previously:

```python
torch._inductor.config.pre_grad_custom_pass = quant_lift_up
```

A single global callable meant the last assignment silently overwrote any prior pass.

Change: chain ARM's pass after any existing pass instead of overwriting it, which guarantees both passes run in a deterministic order. Added a helper function that chains rather than overwrites:

```python
def _chain_pregrad_pass(new_pass):
    prev = getattr(torch._inductor.config, "pre_grad_custom_pass", None)
    if prev is None or prev is new_pass:
        return new_pass

    def _chained(gm):
        # Run the previous pass first, then ours (conservative ordering).
        prev(gm)
        new_pass(gm)

    return _chained
```

Replacing the direct assignment with chaining:

```python
if torch_version_at_least("2.8.0"):
    torch._inductor.config.pre_grad_custom_pass = _chain_pregrad_pass(quant_lift_up)
```

Now both passes (prev -> ARM) will execute.
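To make the chaining behavior concrete, here is a minimal, self-contained demonstration; the two toy passes are hypothetical stand-ins for the x86 and ARM pre-grad passes, not code from this PR:

```python
import torch._inductor.config  # ensure the inductor config submodule is loaded

def _chain_pregrad_pass(new_pass):
    prev = getattr(torch._inductor.config, "pre_grad_custom_pass", None)
    if prev is None or prev is new_pass:
        return new_pass

    def _chained(gm):
        prev(gm)       # run the previously registered pass first
        new_pass(gm)   # then the newly registered one

    return _chained

calls = []

def x86_pass(gm):  # stand-in for the existing x86 pass
    calls.append("x86")

def arm_pass(gm):  # stand-in for the new ARM pass
    calls.append("arm")

# Each backend registers through the helper instead of assigning directly.
torch._inductor.config.pre_grad_custom_pass = _chain_pregrad_pass(x86_pass)
torch._inductor.config.pre_grad_custom_pass = _chain_pregrad_pass(arm_pass)

# Inductor invokes this on the pre-grad graph; a dummy argument suffices here.
torch._inductor.config.pre_grad_custom_pass(None)
assert calls == ["x86", "arm"]  # both passes ran, in registration order

# The `prev is new_pass` guard avoids double-wrapping an already-registered pass.
torch._inductor.config.pre_grad_custom_pass = None
torch._inductor.config.pre_grad_custom_pass = _chain_pregrad_pass(arm_pass)
assert _chain_pregrad_pass(arm_pass) is arm_pass
```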
Force-pushed from ad1ff8d to c9417fa.
Hi @jerryzh168, can you check this once?
Review comment on the diff:

```python
from torchao.quantization.pt2e.inductor_passes.arm import (
    _register_quantization_weight_pack_pass,
)
from torchao.quantization.pt2e.inductor_passes.x86 import (
```

this seems to be introducing a dependency between arm and x86, is it possible to remove?

if you are really reusing this, it might be better to refactor it into a separate file and have both x86 and arm depend on it, I think
I’ve removed the ARM→x86 import and refactored quant_lift_up into a shared file (utils.py), so both backends depend on a neutral module instead of each other.
Path: ao/torchao/quantization/pt2e/inductor_passes/utils.py
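So, after the refactor, both backend modules share a single backend-neutral import (only the utils.py path above is confirmed; this line just illustrates the dependency direction):

```python
# In both arm.py and x86.py after the refactor:
from torchao.quantization.pt2e.inductor_passes.utils import quant_lift_up
```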
Review comment on the diff:

```python
    _register_quantization_weight_pack_pass,
)
from torchao.quantization.pt2e.inductor_passes.x86 import (
    quant_lift_up,
```

I thought this is a prev_pass?
In the chaining helper, prev is the existing torch._inductor.config.pre_grad_custom_pass (if any). We now set:

```python
torch._inductor.config.pre_grad_custom_pass = _chain_pregrad_pass(quant_lift_up)
```

which composes prev (if present) with quant_lift_up (the new pass). If prev is already quant_lift_up, we skip wrapping to avoid running it twice.
Force-pushed from c9417fa to 9f01b51.
Force-pushed from 6be5e87 to 9f01b51.
Force-pushed from 9f01b51 to bd15048.
This PR is an extension of PR #2139.

Major changes:
1) Introduced lowering patterns for "per_tensor" quantized weights.
2) Modified the original API get_default_arm_inductor_quantization_config to add a user choice between "per_tensor" and "per_channel" granularity for weight quantization.

Supported shapes:
Tested and verified for different models:
Example script for reference:
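The original script is collapsed in the rendered page. As a stand-in, here is a minimal sketch of the intended PT2E flow; the ArmInductorQuantizer class name, the import paths, and the is_per_channel keyword are assumptions made by analogy with the x86 flow, not confirmed by this PR:

```python
import torch
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

# Hypothetical import path and class name, mirroring X86InductorQuantizer;
# only get_default_arm_inductor_quantization_config is named in this PR.
from torchao.quantization.pt2e.arm_inductor_quantizer import (
    ArmInductorQuantizer,
    get_default_arm_inductor_quantization_config,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()
example_inputs = (torch.randn(8, 64),)

quantizer = ArmInductorQuantizer()
# The PR adds a per_tensor/per_channel choice for weights; the exact
# keyword name is assumed here.
quantizer.set_global(
    get_default_arm_inductor_quantization_config(is_per_channel=False)
)

exported = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration
converted = convert_pt2e(prepared)

# torch.compile lowers the quantized graph; with per_tensor weights the new
# weight-prepack patterns from this PR should match.
compiled = torch.compile(converted)
compiled(*example_inputs)
```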
Results
All times are in seconds, taken on an AWS Graviton 3E 32-core instance.
Pip list
cc: @jerryzh168, @fadara01, @Xia-Weiwen