[X86] Add convert_element_type to smooth quant pattern #3784
cyxlily wants to merge 15 commits into pytorch:main
Signed-off-by: Cui, Lily <lily.cui@intel.com>
Pull request overview
This PR extends the smooth quantization pattern matching to support operations that include convert_element_type nodes. The changes enable the pattern matcher to recognize and handle quantization patterns with additional type conversion operations.
Changes:
- Added `convert_a` parameter to `get_pattern_no_bias` function to generate patterns with extra `convert_element_type` nodes
- Created new pattern variants (`pattern_no_bias_1_c1`, `pattern_no_bias_1_c2`, etc.) to match different combinations of conversion operations
- Updated validation logic to accept additional match node counts (8, 9, 12) for the new patterns
- Modified keyword argument handling to use `x_scale_dtype` and make `dtype` optional with fallback
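To make the idea of flag-driven pattern variants concrete, here is a minimal, torch-free sketch. It is not the PR's implementation: `Node` and `count_nodes` are hypothetical stand-ins for the Inductor pattern-matcher types, and the position of the extra `convert_element_type` node (here, on the activation scale) is an assumption made purely for illustration. Only the function name and the two flags come from the PR.

```python
from dataclasses import dataclass
from itertools import product


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for a pattern node (not torch._inductor's CallFunction)."""
    op: str
    args: tuple = ()


def count_nodes(node) -> int:
    """Count operator nodes in a pattern tree; plain strings are placeholders for inputs."""
    if not isinstance(node, Node):
        return 0
    return 1 + sum(count_nodes(arg) for arg in node.args)


def get_pattern_no_bias(reshape_a: bool = True, convert_a: bool = False) -> Node:
    # Optionally reshape the activation before the integer matmul.
    a = Node("reshape", ("a",)) if reshape_a else "a"
    out = Node("_int_mm", (a, "b"))
    out = Node("convert_element_type", (out,))
    x_scale = "x_scale"
    if convert_a:
        # ASSUMPTION: the extra conversion wraps the activation scale.
        # The real PR may place it elsewhere in the graph.
        x_scale = Node("convert_element_type", (x_scale,))
    out = Node("mul", (out, x_scale))
    out = Node("mul", (out, "w_scale"))
    if reshape_a:
        # Reshape the result back to the original leading dimensions.
        out = Node("reshape", (out,))
    return out


# Enumerate every flag combination, as one might when registering variants.
variants = {
    (reshape_a, convert_a): get_pattern_no_bias(reshape_a, convert_a)
    for reshape_a, convert_a in product([True, False], repeat=2)
}
```

In this sketch each enabled flag changes the node count of the generated pattern, which is why a validator keyed on `len(match.nodes)` has to learn about every new variant.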
Xia-Weiwen left a comment
Thanks for the PR. Please also add a UT.
Moved to the other PR. Signed-off-by: Cui, Lily <lily.cui@intel.com>
Pull request overview
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
def _validate_pattern(match: Match):
    if len(match.nodes) not in [4, 5, 6, 7, 10]:
    # Valid node counts correspond to different pattern variations:
    # 4: pattern1_with_no_outer_or_act_reshape (int_mm + convert + mul + mul)
    # 6: pattern_no_bias_1 (reshape + int_mm + convert + mul + mul + reshape)
    # 7: pattern_with_bias_1 (pattern_no_bias_1 + add)
    # 8: pattern_no_bias_1_with_output_convert (pattern_no_bias_1 with dot scaled + output convert)
    # 9: pattern_with_bias_1_with_output_convert (pattern_with_bias_1 with dot scaled + output convert)
    if len(match.nodes) not in [4, 6, 7, 8, 9]:
        return False
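One way to keep the accepted node counts and their documentation in sync is to store them in a single lookup table. The following is a minimal sketch under stated assumptions: `FakeMatch` and `validate_node_count` are hypothetical names, `FakeMatch` only mimics the `nodes` attribute of the real `Match` object, and the table simply copies the counts and pattern names from the comment in the snippet above. It is not the code in the PR.

```python
from dataclasses import dataclass, field

# Node counts and pattern names copied from the review snippet's comment.
VALID_NODE_COUNTS = {
    4: "pattern1_with_no_outer_or_act_reshape",
    6: "pattern_no_bias_1",
    7: "pattern_with_bias_1",
    8: "pattern_no_bias_1_with_output_convert",
    9: "pattern_with_bias_1_with_output_convert",
}


@dataclass
class FakeMatch:
    """Hypothetical stand-in exposing only the `nodes` attribute of a Match."""
    nodes: list = field(default_factory=list)


def validate_node_count(match) -> bool:
    """Return True only when the match has a node count listed in the table."""
    return len(match.nodes) in VALID_NODE_COUNTS


# Usage: a six-node match is accepted, a five-node match is rejected.
accepted = validate_node_count(FakeMatch(nodes=list(range(6))))
rejected = validate_node_count(FakeMatch(nodes=list(range(5))))
```

A table like this makes it harder for the comment and the membership check to drift apart when a new pattern variant is added.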
# When torch.compile'ing with dynamic=True, the expand node and the two tailing reshape nodes exist
# When torch.compile'ing with dynamic=False, they don't exist
def get_pattern_no_bias(reshape_a: bool = True, convert_a: bool = False):
    int_mm_pattern = CallFunction(
No description provided.