Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 223 additions & 12 deletions src/liquidonnx/lfm2/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,66 @@

logger = logging.getLogger(__name__)

# === INT4 Block Quantization ===

INT4_BITS = 4
INT4_MAX = (1 << INT4_BITS) - 1  # 15, max value for unsigned 4-bit
DEFAULT_BLOCK_SIZE = 32  # elements per quantization block; must be even for nibble packing
SCALE_EPS = 1e-10  # blocks whose value range is below this are treated as constant


def quantize_int4_block(
    weight: np.ndarray, block_size: int = DEFAULT_BLOCK_SIZE
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Quantize weight tensor to INT4 with block-wise scales and zero points.

    Each run of ``block_size`` consecutive elements along the last axis gets
    its own scale and zero point. The quantization range is clamped to include
    0.0 (as in the community onnxruntime block-wise quantizer) so that the
    zero point always lands in [0, INT4_MAX]; dequantization is
    ``w ≈ (q - zero_point) * scale``.

    Args:
        weight: FP32 weight tensor of shape [..., K] where K is quantized dimension
        block_size: Number of elements per quantization block; must be a
            positive even integer so two 4-bit values pack into one byte.

    Returns:
        quant: UINT8 tensor with packed INT4 values (2 per byte, low nibble first)
        scales: FP32 scales, one per block
        zero_points: UINT8 packed zero points (2 per byte)

    Raises:
        ValueError: If block_size is not a positive even integer.
    """
    if block_size <= 0 or block_size % 2 != 0:
        raise ValueError(f"block_size must be a positive even integer, got {block_size}")

    *batch_dims, K = weight.shape
    n_blocks = (K + block_size - 1) // block_size

    # Zero-pad the quantized axis up to a whole number of blocks.
    pad_K = n_blocks * block_size
    if pad_K != K:
        pad_shape = list(weight.shape)
        pad_shape[-1] = pad_K - K
        weight = np.concatenate([weight, np.zeros(pad_shape, dtype=weight.dtype)], axis=-1)

    weight_blocked = weight.reshape(*batch_dims, n_blocks, block_size)

    # Clamp the range to include 0.0. Without this, an all-positive block
    # (e.g. [100, 101]) computes zero_point = round(-w_min/scale) far below 0,
    # clips it to 0, and then every quantized value saturates at INT4_MAX —
    # dequantizing with huge error. Clamping guarantees 0 <= -w_min/scale <= 15.
    w_min = np.minimum(weight_blocked.min(axis=-1, keepdims=True), 0.0)
    w_max = np.maximum(weight_blocked.max(axis=-1, keepdims=True), 0.0)

    scale = (w_max - w_min) / float(INT4_MAX)
    # Degenerate (all-zero) blocks get scale 1.0 to avoid division by zero.
    scale = np.where(scale < SCALE_EPS, 1.0, scale)
    zero_point = np.round(-w_min / scale).clip(0, INT4_MAX).astype(np.uint8)

    # q = round(w/s + zp) to match community
    quant = np.round(weight_blocked / scale + zero_point).clip(0, INT4_MAX).astype(np.uint8)

    # Pack two INT4 values into one UINT8 (low nibble first)
    quant_packed = quant[..., 0::2] | (quant[..., 1::2] << 4)

    scales = scale.squeeze(-1).astype(np.float32)

    # Pack zero points; pad the block axis to an even length first.
    zero_point = zero_point.squeeze(-1)
    if n_blocks % 2 == 1:
        zp_shape = list(zero_point.shape)
        zp_shape[-1] = 1
        zero_point = np.concatenate([zero_point, np.zeros(zp_shape, dtype=np.uint8)], axis=-1)
    zp_packed = zero_point[..., 0::2] | (zero_point[..., 1::2] << 4)

    quant_final = quant_packed.reshape(*batch_dims, -1)

    return quant_final, scales, zp_packed


@dataclass
class LFM2Config:
Expand Down Expand Up @@ -86,7 +146,12 @@ class LFM2Builder(ONNXBuilderBase):
"""

def __init__(
self, config: LFM2Config, use_integrated_rope: bool = False, vl_naming: bool = False
self,
config: LFM2Config,
use_integrated_rope: bool = False,
vl_naming: bool = False,
use_q4: bool = False,
q4_block_size: int = DEFAULT_BLOCK_SIZE,
):
"""
Args:
Expand All @@ -96,17 +161,114 @@ def __init__(
vl_naming: Use VL-style node naming (Shape, Gather_1) instead of
LFM2-style (Shape_for_slice, Gather_for_slice). Community VL and LFM2
models use different conventions.
use_q4: Use INT4 quantized embedding (GatherBlockQuantized) and lm_head
(MatMulNBits). Other MatMul layers are left as FP32 for post-export
quantization.
q4_block_size: Block size for INT4 quantization (default: 32).
"""
super().__init__()
self.config = config
self.head_dim = config.hidden_size // config.num_attention_heads
self.use_integrated_rope = use_integrated_rope
self.vl_naming = vl_naming
self.use_q4 = use_q4
self.q4_block_size = q4_block_size

# Categorize layers
self.conv_indices = [i for i, t in enumerate(config.layer_types) if t == "conv"]
self.attn_indices = [i for i, t in enumerate(config.layer_types) if t == "full_attention"]

# === Q4 Quantization Methods ===

def _quantize_for_matmul_nbits(
self, weight: np.ndarray, name: str
) -> tuple[str, str, str, int, int]:
"""Quantize weight for MatMulNBits operator.

Args:
weight: FP32 weight tensor of shape [K, N] (already transposed for MatMul)
name: Base name for initializers

Returns:
Tuple of (quant_name, scales_name, zp_name, K, N)
"""
K, N = weight.shape
block_size = self.q4_block_size

weight_t = weight.T # [N, K]
quant, scales, zp = quantize_int4_block(weight_t, block_size)

n_blocks = (K + block_size - 1) // block_size
quant_3d = quant.reshape(N, n_blocks, block_size // 2)

quant_name = f"{name}_quant"
scales_name = f"{name}_scales"
zp_name = f"{name}_zp"

self.add_initializer(quant_name, quant_3d, dtype=np.uint8)
self.add_initializer(scales_name, scales)
self.add_initializer(zp_name, zp, dtype=np.uint8)

return quant_name, scales_name, zp_name, K, N

def make_matmul_nbits(
self, input_name: str, weight: np.ndarray, name: str, output_name: str
) -> str:
"""Create MatMulNBits node for INT4 quantized linear layer.

Args:
input_name: Input tensor name
weight: Weight matrix [K, N] (already transposed for MatMul)
name: Base name for the operation
output_name: Output tensor name
"""
quant_name, scales_name, zp_name, K, N = self._quantize_for_matmul_nbits(weight, name)

return self.make_node(
"MatMulNBits",
[input_name, quant_name, scales_name, zp_name],
[output_name],
domain="com.microsoft",
K=K,
N=N,
bits=4,
block_size=self.q4_block_size,
)

def make_gather_block_quantized(
self, weight: np.ndarray, indices_name: str, name: str, output_name: str
) -> str:
"""Create GatherBlockQuantized node for INT4 quantized embedding lookup.

Args:
weight: Embedding weight [vocab_size, hidden_size]
indices_name: Input token IDs tensor name
name: Base name for initializers
output_name: Output tensor name
"""
block_size = self.q4_block_size

quant, scales, zp = quantize_int4_block(weight, block_size)

quant_name = f"{name}_quant"
scales_name = f"{name}_scales"
zp_name = f"{name}_zp"

self.add_initializer(quant_name, quant, dtype=np.uint8)
self.add_initializer(scales_name, scales)
self.add_initializer(zp_name, zp, dtype=np.uint8)

return self.make_node(
"GatherBlockQuantized",
[quant_name, indices_name, scales_name, zp_name],
[output_name],
domain="com.microsoft",
bits=4,
block_size=block_size,
gather_axis=0,
quantize_axis=1,
)

def make_simple_layernorm(
self, input_name: str, weight_name: str, path: str, name: str = None
) -> str:
Expand Down Expand Up @@ -318,7 +480,17 @@ def build_outputs(self):
)

def build_embedding(self) -> str:
self.add_initializer("model.embed_tokens.weight", self.weights["model.embed_tokens.weight"])
embed_weight = self.weights["model.embed_tokens.weight"]

if self.use_q4:
return self.make_gather_block_quantized(
embed_weight,
"input_ids",
"model_embed_tokens_weight",
"/model/embed_tokens/GatherBlockQuantized/output_0",
)

self.add_initializer("model.embed_tokens.weight", embed_weight)
return self.make_node(
"Gather",
["model.embed_tokens.weight", "input_ids"],
Expand Down Expand Up @@ -830,8 +1002,46 @@ def build_lm_head(self, hidden_state: str) -> str:
name=f"/model/layers.{num_layers}/final_norm_layernorm/SkipLayerNorm",
)

# LM head with tied embeddings (community approach)
# Transpose embed_tokens at runtime instead of storing a copy (saves 256MB)
if self.use_q4:
# Q4: Use MatMulNBits for lm_head with shared embedding weights
embed_quant_name = "model_embed_tokens_weight_quant"
embed_quant = None
for init in self.initializers:
if init.name == embed_quant_name:
embed_quant = onnx.numpy_helper.to_array(init)
break

if embed_quant is None:
raise ValueError("Embedding quant not found - build_embedding must be called first")

vocab_size = embed_quant.shape[0]
K = self.config.hidden_size
n_blocks = (K + self.q4_block_size - 1) // self.q4_block_size

# Reshape to 3D for MatMulNBits: [N, n_blocks, block_size/2]
embed_quant_matmul = embed_quant.reshape(vocab_size, n_blocks, self.q4_block_size // 2)
self.add_initializer(
"model_embed_tokens_weight_quant_matmul", embed_quant_matmul, dtype=np.uint8
)

# Reuse scales and zero points from embedding
return self.make_node(
"MatMulNBits",
[
normed,
"model_embed_tokens_weight_quant_matmul",
"model_embed_tokens_weight_scales",
"model_embed_tokens_weight_zp",
],
["logits"],
domain="com.microsoft",
K=K,
N=vocab_size,
bits=4,
block_size=self.q4_block_size,
)

# FP32: Transpose embed_tokens at runtime instead of storing a copy
# embed_tokens.weight [vocab, hidden] → [hidden, vocab]
lm_head_weight = self.make_node(
"Transpose",
Expand Down Expand Up @@ -876,11 +1086,11 @@ def build_value_info(self):
self.add_value_info(f"{mask_prefix}/{gather_name}/Cast/output_0", TensorProto.INT32, [])

# === Embedding output ===
self.add_value_info(
"/model/embed_tokens/Gather/output_0",
TensorProto.FLOAT,
["batch_size", "sequence_length", H],
)
if self.use_q4:
embed_output = "/model/embed_tokens/GatherBlockQuantized/output_0"
else:
embed_output = "/model/embed_tokens/Gather/output_0"
self.add_value_info(embed_output, TensorProto.FLOAT, ["batch_size", "sequence_length", H])

# === Per-layer outputs ===
for layer_idx in range(num_layers):
Expand Down Expand Up @@ -1083,9 +1293,10 @@ def build_value_info(self):
TensorProto.FLOAT,
["batch_size", "sequence_length", H],
)
self.add_value_info(
"/lm_head/Transpose/output_0", TensorProto.FLOAT, [H, self.config.vocab_size]
)
if not self.use_q4:
self.add_value_info(
"/lm_head/Transpose/output_0", TensorProto.FLOAT, [H, self.config.vocab_size]
)

def load_weights(self, model_path: str):
"""Load weights from HuggingFace model."""
Expand Down
Loading
Loading