Fixes INT4 quantization pack/unpack on TPU #22338
hertschuh merged 11 commits into keras-team:master
Conversation
Code Review
This pull request refactors the int4 quantization logic in the Dense layer, which appears to improve numerical stability by performing dequantization before the matrix multiplication. The changes in dense.py are a good improvement. However, a significant amount of diagnostic and debugging code, including print statements, has been added to dense_test.py. This code should be removed before the pull request is merged. Other changes are minor stylistic improvements to satisfy linting rules.
Codecov Report: ✅ All modified and coverable lines are covered by tests.

@@ Coverage Diff @@
## master #22338 +/- ##
==========================================
- Coverage 82.87% 82.86% -0.01%
==========================================
Files 594 594
Lines 65753 65751 -2
Branches 10276 10277 +1
==========================================
- Hits 54492 54485 -7
- Misses 8640 8646 +6
+ Partials 2621 2620 -1
This reverts commit ce16774.
Summary
- `pack_int4` was producing corrupted values on TPU due to `int8` bitwise `left_shift` overflow
- Moved the `pack_int4` bitwise operations from device-side (XLA) to numpy
8 `test_int4_quantization_block_size` tests (across `Dense` and `EinsumDense`) were failing on TPU with MSE values of 0.2–1.2 (threshold: 0.01), while passing on CPU.

Root cause: `pack_int4` uses `left_shift(int8, 4)` to pack two nibbles into one byte. When the high nibble value is >= 8 (i.e., the original value was negative), this shift produces 128+, which overflows signed `int8` (max 127). On CPU, JAX/numpy wraps correctly via two's complement. On TPU, XLA produces incorrect results; the hypothesis is an `int8` -> `int32` promotion during the bitwise op followed by a saturating (rather than truncating) narrowing cast back to `int8`. This corrupted some of the packed int4 weight values.

Note: I still need to reliably reproduce this issue outside of the quantization code path; I will probably open a bug with XLA after that.
Fix
Moved all packing logic to numpy. Since
pack_int4is only called during quantization (not inference), there is no performance impact. Numpy correctly handles int8 overflow with two's complement wrapping.
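As a rough illustration of the host-side approach (a sketch under assumed names, not the actual Keras implementation), a vectorized numpy pack/unpack for an even-length int4 array could look like:

```python
import numpy as np

def pack_int4_numpy(values):
    """Pack a flat int8 array of int4 values (even length) into bytes."""
    v = values.astype(np.int8) & np.int8(0x0F)  # two's-complement nibbles
    low, high = v[0::2], v[1::2]
    # numpy wraps the int8 overflow from the shift via two's complement,
    # so the packed bytes are bit-exact.
    return np.left_shift(high, 4) | low

def unpack_int4_numpy(packed):
    """Unpack bytes back into the original int4 values."""
    low = packed.astype(np.int8) & np.int8(0x0F)
    low = np.where(low >= 8, low - 16, low).astype(np.int8)  # sign-extend
    high = np.right_shift(packed.astype(np.int8), 4)  # arithmetic shift
    out = np.empty(packed.size * 2, dtype=np.int8)
    out[0::2], out[1::2] = low, high
    return out

vals = np.array([-8, 7, -1, 0, 3, -5], dtype=np.int8)
print(np.array_equal(unpack_int4_numpy(pack_int4_numpy(vals)), vals))  # True
```

Because this runs once at quantization time on the host, the XLA lowering of the bitwise ops never comes into play.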