Skip to content

Commit 728d84a

Browse files
committed
feat: implement DEFAULT_FUSED_MAPPINGS refactor and fix post-rebase issues
Per @kylesayrs review:
- DEFAULT_FUSED_MAPPINGS refactored to {primary_pattern: [partner_templates]} so only the primary-owning shard fetches its partners, preventing double reads for cross-shard fused weight sets
- build_inverse_weights_map uses re.match on primary patterns with named group substitution to construct partner names exactly as Kyle suggested
- process_file_microscale_scheme: remove assert on unmatched fused sets — non-primary shards legitimately have k/v without q since only the primary shard fetches partners

Post-rebase fixes:
- __init__.py / save_utils.py: fix import path for local dev compatibility (compressed_tensors.utils.safetensors_load instead of compressed_tensors.entrypoints.convert.file_utils)
- microscale.py: fix line too long in DEFAULT_FUSED_MAPPINGS

Signed-off-by: David Zheng <dqzheng1996@gmail.com>
1 parent 1792bb7 commit 728d84a

File tree

4 files changed

+97
-100
lines changed

4 files changed

+97
-100
lines changed

src/llmcompressor/entrypoints/model_free/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,13 @@
99
Converter,
1010
exec_jobs,
1111
)
12-
from compressed_tensors.entrypoints.convert.file_utils import (
12+
from compressed_tensors.quantization import QuantizationScheme
13+
from compressed_tensors.utils.safetensors_load import (
1314
get_checkpoint_files,
1415
is_weights_file,
1516
)
16-
from compressed_tensors.quantization import QuantizationScheme
1717
from loguru import logger
1818

19-
20-
2119
from llmcompressor.entrypoints.model_free.helpers import (
2220
find_safetensors_index_file,
2321
gpu_if_available,
Lines changed: 93 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,105 +1,51 @@
1+
import re
12
from collections import defaultdict
3+
24
from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy
35

46
from llmcompressor.entrypoints.model_free.helpers import (
57
MatchedNamesSet,
68
match_names_set_eager,
79
)
810

9-
10-
def match_name(name: str, pattern: str) -> bool:
11-
"""Pattern matching for tensor names. Handles 're:' prefix for regex patterns."""
12-
import re
13-
if pattern.startswith('re:'):
14-
# Regex pattern - strip 're:' prefix and match
15-
regex = pattern[3:]
16-
return re.match(regex, name) is not None
17-
else:
18-
# Glob-style pattern
19-
import fnmatch
20-
return fnmatch.fnmatch(name, pattern)
21-
22-
23-
24-
def build_inverse_weights_map(
25-
shard_name: str,
26-
weight_map: dict[str, str],
27-
model_files: dict[str, str],
28-
) -> dict[str, list[str]]:
29-
"""
30-
For a given output shard, precompute exactly which tensors need to be
31-
loaded from which source files — including fused partner tensors that
32-
live in other shards.
33-
34-
This moves fused partner discovery out of the per-process runtime and
35-
into the job-building phase, avoiding redundant re-discovery and enabling
36-
cleaner process function signatures.
37-
38-
For example, given:
39-
shard0: [q_proj.weight, ...]
40-
shard1: [k_proj.weight, v_proj.weight, ...]
41-
42-
The inverse_weights_map for shard0's job would be:
43-
{
44-
"/path/to/shard0.safetensors": ["q_proj.weight", ...],
45-
"/path/to/shard1.safetensors": ["k_proj.weight", "v_proj.weight"],
46-
}
47-
48-
:param shard_name: the shard filename this job will process and save
49-
:param weight_map: mapping of tensor name -> shard filename (from index.json)
50-
:param model_files: mapping of shard filename -> resolved absolute path
51-
:return: dict mapping resolved source file path -> list of tensor names to load
52-
"""
53-
# These are now module-level since function is in microscale.py
54-
# DEFAULT_FUSED_MAPPINGS and get_fused_names are available at module scope
55-
56-
# Tensors natively belonging to this shard
57-
native_tensors = [t for t, s in weight_map.items() if s == shard_name]
58-
59-
# Check if all fused sets are already complete within this shard
60-
_, unmatched_sets = get_fused_names(native_tensors)
61-
62-
# Start with native tensors grouped by their source file
63-
result: dict[str, list[str]] = defaultdict(list)
64-
own_resolved = model_files[shard_name]
65-
result[own_resolved] = list(native_tensors)
66-
67-
if not unmatched_sets:
68-
return dict(result)
69-
70-
# For each unmatched fused set, find partner tensors in other shards
71-
all_patterns = [p for mapping in DEFAULT_FUSED_MAPPINGS for p in mapping]
72-
73-
for unmatched in unmatched_sets:
74-
present_names = {v for v in unmatched.values() if v is not None}
75-
layer_prefixes = {name.rsplit(".", 2)[0] for name in present_names}
76-
77-
for tensor_name, tensor_shard in weight_map.items():
78-
if tensor_shard == shard_name:
79-
continue # already in native tensors
80-
resolved = model_files.get(tensor_shard)
81-
if resolved is None:
82-
continue
83-
candidate_prefix = tensor_name.rsplit(".", 2)[0]
84-
if candidate_prefix not in layer_prefixes:
85-
continue
86-
if any(match_name(tensor_name, p) for p in all_patterns):
87-
if tensor_name not in result[resolved]:
88-
result[resolved].append(tensor_name)
89-
90-
return dict(result)
91-
92-
93-
9411
__all__ = [
95-
'build_inverse_weights_map',
96-
'is_microscale_scheme',
97-
'get_fused_names',
98-
'DEFAULT_FUSED_MAPPINGS',
12+
"build_inverse_weights_map",
13+
"is_microscale_scheme",
14+
"get_fused_names",
15+
"DEFAULT_FUSED_MAPPINGS",
9916
]
10017

18+
# Mapping of primary weight pattern -> list of partner weight patterns.
# The shard that owns the primary tensor is the one responsible for fetching
# its partners, so each fused set is read exactly once even when it spans
# multiple shards (e.g. q_proj's owner fetches k_proj and v_proj).
#
# Each primary pattern captures named groups such as (?P<prefix>...) and
# (?P<attn>...), so a partner's full tensor name can be rebuilt by
# substituting the matched groups into the template:
#     partner.format(**match.groupdict())
DEFAULT_FUSED_MAPPINGS: dict[str, list[str]] = {
    # Attention q/k/v fusion: q_proj is the primary
    r"^(?P<prefix>.+?)\.(?P<attn>attn|attention|self_attn|self_attention)"
    r"\.q_proj\.weight$": [
        r"{prefix}.{attn}.k_proj.weight",
        r"{prefix}.{attn}.v_proj.weight",
    ],
    # MLA attention fusion: wq_a is the primary
    r"^(?P<prefix>.+?)\.(?P<attn>attn|attention|self_attn)\.wq_a\.weight$": [
        r"{prefix}.{attn}.wkv_a_with_mqa.weight",
    ],
    # MLP gate/up fusion: gate_proj is the primary
    r"^(?P<prefix>.+?)\.(?P<mlp>mlp|feed_forward)\.gate_proj\.weight$": [
        r"{prefix}.{mlp}.up_proj.weight",
    ],
    # MoE w1/w3 fusion: w1 is the primary
    r"^(?P<prefix>.+?)\.w1\.weight$": [
        r"{prefix}.w3.weight",
    ],
}
10146

102-
DEFAULT_FUSED_MAPPINGS = [
47+
# List-of-lists format used by get_fused_names and validate.py
48+
_DEFAULT_FUSED_MAPPINGS_LIST = [
10349
[
10450
r"re:.*(attn|attention)\.q_proj\.weight$",
10551
r"re:.*(attn|attention)\.k_proj\.weight$",
@@ -124,11 +70,65 @@ def get_fused_names(
12470
) -> tuple[list[MatchedNamesSet], list[MatchedNamesSet]]:
12571
matched = []
12672
unmatched = []
127-
for mapping in DEFAULT_FUSED_MAPPINGS:
73+
for mapping in _DEFAULT_FUSED_MAPPINGS_LIST:
12874
_matched, _unmatched = match_names_set_eager(tensor_names, mapping)
129-
13075
matched.extend(_matched)
13176
if _unmatched is not None:
13277
unmatched.append(_unmatched)
133-
13478
return matched, unmatched
79+
80+
81+
def build_inverse_weights_map(
    shard_name: str,
    weight_map: dict[str, str],
    model_files: dict[str, str],
) -> dict[str, list[str]]:
    """
    Precompute, for one output shard, which tensors must be loaded from which
    source files — including fused partner tensors that live in other shards.

    DEFAULT_FUSED_MAPPINGS is keyed by primary pattern, so only the shard
    that owns a fused set's primary tensor pulls in the partners. Partner
    shards load just their own native tensors, which keeps every fused set
    from being read twice when it spans shards.

    Example — given:
        shard0: [q_proj.weight, ...]              <- primary owner
        shard1: [k_proj.weight, v_proj.weight, ...]  <- partners

    Only shard0's map references shard1's file; shard1's map stays native-only.

    :param shard_name: the shard filename this job will process and save
    :param weight_map: tensor name -> shard filename (from safetensors.index.json)
    :param model_files: shard filename -> resolved absolute path
    :return: {resolved_file_path: [tensor_names_to_load]}
    """
    natives = [tensor for tensor, shard in weight_map.items() if shard == shard_name]

    result: dict[str, list[str]] = defaultdict(list)
    result[model_files[shard_name]] = list(natives)

    # Any native tensor matching a primary pattern pulls in its partners.
    for tensor_name in natives:
        for primary, partner_templates in DEFAULT_FUSED_MAPPINGS.items():
            hit = re.match(primary, tensor_name)
            if hit is None:
                continue

            # Rebuild each partner's full name from the matched named groups.
            groups = hit.groupdict()
            for template in partner_templates:
                partner = template.format(**groups)

                partner_shard = weight_map.get(partner)
                # Skip partners that don't exist or already load natively.
                if partner_shard is None or partner_shard == shard_name:
                    continue

                resolved = model_files.get(partner_shard)
                if resolved is None:
                    continue

                bucket = result[resolved]
                if partner not in bucket:
                    bucket.append(partner)

    return dict(result)

src/llmcompressor/entrypoints/model_free/save_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
)
1111
from compressed_tensors.config import CompressionFormat
1212
from compressed_tensors.entrypoints.convert import Converter
13-
from compressed_tensors.entrypoints.convert.file_utils import find_config_path
1413
from compressed_tensors.quantization import (
1514
QuantizationConfig,
1615
QuantizationScheme,
1716
QuantizationStatus,
1817
)
18+
from compressed_tensors.utils.safetensors_load import find_config_path
1919
from loguru import logger
2020
from pydantic import ValidationError
2121

src/llmcompressor/entrypoints/model_free/validate.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ def validate_scheme(scheme: QuantizationScheme) -> tuple[str, QuantizationScheme
3131
output_dynamic = getattr_chain(scheme, "output_activations.dynamic", True)
3232
if input_dynamic is not True or output_dynamic is not True:
3333
raise ValueError(
34-
"Model Free PTQ cannot calibrate activations. "
35-
"Please use `oneshot` instead."
34+
"Model Free PTQ cannot calibrate activations. Please use `oneshot` instead."
3635
)
3736

3837
# override with static observers

0 commit comments

Comments
 (0)