
Fixes INT4 quantization pack/unpack on TPU #22338

Merged
hertschuh merged 11 commits into keras-team:master from JyotinderSingh:tpu-debug
Mar 3, 2026

Conversation

@JyotinderSingh (Collaborator) commented Mar 3, 2026

Summary

  • Fix pack_int4 producing corrupted values on TPU due to int8 bitwise left_shift overflow
  • Move pack_int4 bitwise operations from device-side (XLA) to numpy

Problem

Eight test_int4_quantization_block_size tests (across Dense and EinsumDense) were failing on TPU with MSE values of 0.2–1.2 (threshold: 0.01), while passing on CPU.

Root cause: pack_int4 uses left_shift(int8, 4) to pack two nibbles into one byte. When the high nibble is >= 8 (i.e., the original value was negative), the shift produces 128 or more, which overflows signed int8 (max 127). On CPU, JAX/numpy wraps correctly via two's complement. On TPU, XLA produces incorrect results; the working hypothesis is an int8 -> int32 promotion during the bitwise op followed by a saturating (rather than truncating) narrowing cast back to int8. This corrupted some of the packed int4 weight values.
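The wrapping behavior described above can be demonstrated in numpy directly (a minimal sketch with illustrative values, not the Keras implementation):

```python
import numpy as np

# A negative int4 value, e.g. -2, is stored as the unsigned nibble
# 0b1110 = 14, i.e. >= 8 -- exactly the problematic case.
high = np.int8(-2) & np.int8(0x0F)   # 14
low = np.int8(3)

# 14 << 4 = 224 exceeds the int8 maximum of 127. numpy wraps via
# two's complement (224 - 256 = -32) instead of saturating.
packed = (high << np.int8(4)) | low  # bit pattern 0b11100011, i.e. -29

# Unpacking recovers both nibbles: an arithmetic right shift restores
# the sign of the high nibble, and a mask isolates the low nibble.
assert (packed >> np.int8(4)) == -2
assert (packed & np.int8(0x0F)) == 3
```

If the shift result were instead saturated to 127 (0b01111111), as hypothesized for the TPU path, the subsequent unpack would recover a high nibble of 7 and a low nibble of 15, corrupting both values in the byte.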

Note: I still need to reliably reproduce this issue outside of the quantization code path; I will probably open a bug with XLA once I can.

Fix

Moved all packing logic to numpy, which handles int8 overflow correctly via two's-complement wrapping. Since pack_int4 is only called during quantization (not inference), there is no performance impact.
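A numpy-side pack/unpack along these lines might look like the following sketch. `pack_pairs` and `unpack_pairs` are illustrative names, not the Keras functions, and the real pack_int4 additionally handles arbitrary tensor shapes and pack axes:

```python
import numpy as np

def pack_pairs(values):
    """Pack an even-length int8 array of int4 values (range [-8, 7])
    into half as many bytes, two nibbles per byte."""
    values = np.asarray(values, dtype=np.int8)
    low = values[0::2] & np.int8(0x0F)   # mask to unsigned nibbles [0, 15]
    high = values[1::2] & np.int8(0x0F)
    # numpy wraps int8 overflow via two's complement, so the shift is safe
    # even when the high nibble is >= 8.
    return (high << np.int8(4)) | low

def unpack_pairs(packed):
    """Invert pack_pairs, recovering signed int4 values."""
    packed = np.asarray(packed, dtype=np.int8)
    low = packed & np.int8(0x0F)
    high = packed >> np.int8(4)          # arithmetic shift keeps the sign
    # Sign-extend the low nibble: values >= 8 represent negatives.
    low = np.where(low >= 8, low - 16, low).astype(np.int8)
    return np.stack([low, high], axis=-1).reshape(-1)

vals = np.array([-8, 7, -1, 0, 3, -5], dtype=np.int8)
assert np.array_equal(unpack_pairs(pack_pairs(vals)), vals)
```

Running this on any numpy build exercises the full int4 range, including the negative high-nibble case that triggered the TPU failure.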

@gemini-code-assist (bot) commented

Summary of Changes


This pull request introduces diagnostic tools and refactors the int4 quantization logic within Keras. The primary goal appears to be to aid in debugging and verifying the behavior of int4 quantization, particularly in the context of the Dense layer. It includes new tests to ensure the integrity of int4 data handling and kernel values, alongside a functional adjustment to how quantized kernels are processed during the forward pass. Additionally, several minor code formatting improvements were applied to lambda functions for consistency.

Highlights

  • Added Int4 Quantization Diagnostics: New diagnostic tests were introduced in dense_test.py to verify the correctness of int4 quantization. These tests include roundtrip checks for unpack_int4, validation of kernel values within the int4 range, and a comparison of different dequantization approaches.
  • Refactored Int4 Quantization Logic: The forward pass logic for the Dense layer in dense.py was modified to pre-calculate the float_kernel and adjust the matrix multiplication and scaling for int4 quantized inputs, improving the handling of quantized operations.
  • Minor Code Formatting: Consistent formatting was applied to lambda functions across several backend files (trainer.py, core.py for numpy and torch, and layer.py) by wrapping their bodies in parentheses for improved readability.


Changelog
  • keras/src/backend/jax/trainer.py
    • Updated lambda function formatting within _get_distributed_iterator.
  • keras/src/backend/numpy/core.py
    • Updated lambda function formatting within while_loop.
  • keras/src/backend/torch/core.py
    • Updated lambda function formatting within while_loop.
  • keras/src/layers/core/dense.py
    • Refactored int4 quantization logic to pre-calculate float_kernel.
    • Adjusted the computation of x when inputs_quantizer is present.
  • keras/src/layers/core/dense_test.py
    • Added diagnostic tests to verify unpack_int4 roundtrip on device.
    • Added diagnostic tests to verify kernel values are within the int4 range.
    • Added diagnostic tests to compare the dequantize-first approach.
  • keras/src/layers/layer.py
    • Updated lambda function formatting within _initialize_tracker.
Ignored Files
  • Ignored by pattern: .github/workflows/** (1)
    • .github/workflows/tpu_tests.yml
Activity
  • The pull request is marked as "[Do not review]", indicating it is likely a work-in-progress or an internal diagnostic branch. No specific review comments or approvals have been recorded.

@JyotinderSingh added the run_tpu_tests (Trigger TPU tests on the PR) and kokoro:force-run labels, and removed the run_tpu_tests label Mar 3, 2026
@gemini-code-assist bot left a comment

Code Review

This pull request refactors the int4 quantization logic in the Dense layer, which appears to improve numerical stability by performing dequantization before the matrix multiplication. The changes in dense.py are a good improvement. However, a significant amount of diagnostic and debugging code, including print statements, has been added to dense_test.py. This code should be removed before the pull request is merged. Other changes are minor stylistic improvements to satisfy linting rules.


@codecov-commenter commented Mar 3, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 82.86%. Comparing base (0eb7db9) to head (4ada855).

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #22338      +/-   ##
==========================================
- Coverage   82.87%   82.86%   -0.01%     
==========================================
  Files         594      594              
  Lines       65753    65751       -2     
  Branches    10276    10277       +1     
==========================================
- Hits        54492    54485       -7     
- Misses       8640     8646       +6     
+ Partials     2621     2620       -1     
Flag Coverage Δ
keras 82.69% <100.00%> (-0.01%) ⬇️
keras-jax 60.91% <100.00%> (-0.01%) ⬇️
keras-numpy 55.10% <100.00%> (-0.01%) ⬇️
keras-openvino 49.06% <0.00%> (+<0.01%) ⬆️
keras-tensorflow 62.15% <100.00%> (-0.01%) ⬇️
keras-torch 61.02% <100.00%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.


@JyotinderSingh JyotinderSingh reopened this Mar 3, 2026
@JyotinderSingh JyotinderSingh changed the title [Do not review] Adds diagnostics for tpu debugging Fixes INT4 quantization pack/unpack on TPU Mar 3, 2026
@hertschuh (Collaborator) left a comment

Thanks for the fix!

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Mar 3, 2026
@hertschuh hertschuh merged commit 0d5c2f7 into keras-team:master Mar 3, 2026
16 checks passed
@google-ml-butler google-ml-butler bot removed awaiting review ready to pull Ready to be merged into the codebase labels Mar 3, 2026
@JyotinderSingh JyotinderSingh deleted the tpu-debug branch March 3, 2026 18:00
