Skip to content
70 changes: 34 additions & 36 deletions keras/src/quantizers/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,48 +707,47 @@ def pack_int4(arr, axis=0, dtype="int8"):
f"{backend.standardize_dtype(arr.dtype)}."
)

rank = getattr(arr.shape, "rank", None) or len(arr.shape)

# Perform packing in numpy. Packing is only called during
# quantization (not inference), and numpy correctly handles int8
# overflow in bitwise operations. Some accelerators (e.g. TPU) may
# produce incorrect results for int8 left_shift that overflows, so
# using numpy avoids device-specific issues.
arr_np = ops.convert_to_numpy(arr)
np_dtype = np.dtype(dtype)

rank = len(arr_np.shape)
if axis < 0:
axis += rank

# 1. Bring `axis` to the front.
perm = [axis] + [i for i in range(rank) if i != axis]
inv_perm = [perm.index(i) for i in range(rank)]
transposed = ops.transpose(arr, perm)

# 2. Pad to even length.
rows = ops.shape(transposed)[0]
needs_pad = ops.equal(ops.mod(rows, 2), 1)
# Move the pack axis to the front for uniform handling.
arr_np = np.moveaxis(arr_np, axis, 0)

# Always append one zero row so the tensor shape is static for JAX. If no
# padding is actually needed, we'll slice it away later.
zero_row = transposed[:1, ...] * 0 # same dtype/shape (1, ...)
padded_full = ops.concatenate([transposed, zero_row], axis=0)

# Number of valid rows after (possible) padding:
# rows + (1 if needs_pad else 0)
rows_packed = rows + ops.cast(needs_pad, "int32")
# Pad to even length along the front axis.
n = arr_np.shape[0]
if n % 2 == 1:
pad_shape = (1,) + arr_np.shape[1:]
arr_np = np.concatenate(
[arr_np, np.zeros(pad_shape, dtype=arr_np.dtype)], axis=0
)

# Slice to keep only the valid rows. This keeps the shape rank static while
# allowing the row count to be dynamic.
padded = padded_full[:rows_packed, ...]
# Group in pairs and pack nibbles.
low = arr_np[::2]
high = arr_np[1::2]

# 3-4. Group in pairs and pack.
low = padded[::2, ...]
high = padded[1::2, ...]
mask = np.array(0x0F, dtype=np_dtype)
low_u = np.bitwise_and(low.astype(np_dtype), mask)
high_u = np.bitwise_and(high.astype(np_dtype), mask)

mask = ops.array(0x0F, dtype=dtype)
low_u = ops.bitwise_and(low, mask)
high_u = ops.bitwise_and(high, mask)
packed_np = np.bitwise_or(
low_u, np.left_shift(high_u, np.array(4, dtype=np_dtype))
)
packed_np = packed_np.astype(np_dtype)

packed = ops.bitwise_or(low_u, ops.left_shift(high_u, 4))
packed = ops.cast(packed, dtype)
# Move the pack axis back to its original position.
packed_np = np.moveaxis(packed_np, 0, axis)

# 5-6. Restore shape.
packed = ops.transpose(packed, inv_perm) # back to original order
orig_len = rows # number of slices before padding
return packed, ops.shape(packed), orig_len
packed = ops.convert_to_tensor(packed_np)
return packed, tuple(packed_np.shape), n


@keras_export("keras.quantizers.unpack_int4")
Expand Down Expand Up @@ -852,9 +851,8 @@ def to_signed(x):
if axis < 0:
axis += rank

# Fast path for the most common case in Dense layers
# Fast path for axis==0 (common case in Dense layers)
if axis == 0 and rank == 2:
# The result of the bitwise op is a wider dtype (e.g., int32).
mask = ops.array(0x0F, dtype=packed.dtype)
low_unpacked = ops.bitwise_and(packed, mask)
high_unpacked = ops.bitwise_and(ops.right_shift(packed, 4), mask)
Expand All @@ -866,7 +864,7 @@ def to_signed(x):
low_final = ops.cast(low_unpacked, dtype)
high_final = ops.cast(high_unpacked, dtype)

# Interleave and reshape
# Interleave along axis 0 and reshape
stacked = ops.stack([low_final, high_final], axis=1)
unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(packed)[1:]))

Expand Down