1+ import re
12from collections import defaultdict
3+
24from compressed_tensors .quantization import QuantizationScheme , QuantizationStrategy
35
46from llmcompressor .entrypoints .model_free .helpers import (
57 MatchedNamesSet ,
68 match_names_set_eager ,
79)
810
9-
10- def match_name (name : str , pattern : str ) -> bool :
11- """Pattern matching for tensor names. Handles 're:' prefix for regex patterns."""
12- import re
13- if pattern .startswith ('re:' ):
14- # Regex pattern - strip 're:' prefix and match
15- regex = pattern [3 :]
16- return re .match (regex , name ) is not None
17- else :
18- # Glob-style pattern
19- import fnmatch
20- return fnmatch .fnmatch (name , pattern )
21-
22-
23-
24- def build_inverse_weights_map (
25- shard_name : str ,
26- weight_map : dict [str , str ],
27- model_files : dict [str , str ],
28- ) -> dict [str , list [str ]]:
29- """
30- For a given output shard, precompute exactly which tensors need to be
31- loaded from which source files — including fused partner tensors that
32- live in other shards.
33-
34- This moves fused partner discovery out of the per-process runtime and
35- into the job-building phase, avoiding redundant re-discovery and enabling
36- cleaner process function signatures.
37-
38- For example, given:
39- shard0: [q_proj.weight, ...]
40- shard1: [k_proj.weight, v_proj.weight, ...]
41-
42- The inverse_weights_map for shard0's job would be:
43- {
44- "/path/to/shard0.safetensors": ["q_proj.weight", ...],
45- "/path/to/shard1.safetensors": ["k_proj.weight", "v_proj.weight"],
46- }
47-
48- :param shard_name: the shard filename this job will process and save
49- :param weight_map: mapping of tensor name -> shard filename (from index.json)
50- :param model_files: mapping of shard filename -> resolved absolute path
51- :return: dict mapping resolved source file path -> list of tensor names to load
52- """
53- # These are now module-level since function is in microscale.py
54- # DEFAULT_FUSED_MAPPINGS and get_fused_names are available at module scope
55-
56- # Tensors natively belonging to this shard
57- native_tensors = [t for t , s in weight_map .items () if s == shard_name ]
58-
59- # Check if all fused sets are already complete within this shard
60- _ , unmatched_sets = get_fused_names (native_tensors )
61-
62- # Start with native tensors grouped by their source file
63- result : dict [str , list [str ]] = defaultdict (list )
64- own_resolved = model_files [shard_name ]
65- result [own_resolved ] = list (native_tensors )
66-
67- if not unmatched_sets :
68- return dict (result )
69-
70- # For each unmatched fused set, find partner tensors in other shards
71- all_patterns = [p for mapping in DEFAULT_FUSED_MAPPINGS for p in mapping ]
72-
73- for unmatched in unmatched_sets :
74- present_names = {v for v in unmatched .values () if v is not None }
75- layer_prefixes = {name .rsplit ("." , 2 )[0 ] for name in present_names }
76-
77- for tensor_name , tensor_shard in weight_map .items ():
78- if tensor_shard == shard_name :
79- continue # already in native tensors
80- resolved = model_files .get (tensor_shard )
81- if resolved is None :
82- continue
83- candidate_prefix = tensor_name .rsplit ("." , 2 )[0 ]
84- if candidate_prefix not in layer_prefixes :
85- continue
86- if any (match_name (tensor_name , p ) for p in all_patterns ):
87- if tensor_name not in result [resolved ]:
88- result [resolved ].append (tensor_name )
89-
90- return dict (result )
91-
92-
93-
9411__all__ = [
95- ' build_inverse_weights_map' ,
96- ' is_microscale_scheme' ,
97- ' get_fused_names' ,
98- ' DEFAULT_FUSED_MAPPINGS' ,
12+ " build_inverse_weights_map" ,
13+ " is_microscale_scheme" ,
14+ " get_fused_names" ,
15+ " DEFAULT_FUSED_MAPPINGS" ,
9916]
10017
18+ # Mapping of primary weight pattern -> list of partner weight patterns.
19+ # The shard owning the primary tensor is responsible for fetching its partners.
20+ # This prevents double reads: each fused set is fetched exactly once, by the
21+ # shard that owns the primary (e.g. q_proj fetches k_proj + v_proj).
22+ #
23+ # Patterns use a named group (?P<prefix>...) so partner names can be
24+ # constructed by substituting the matched prefix via:
25+ # partner.format(prefix=match.group("prefix"))
26+ DEFAULT_FUSED_MAPPINGS : dict [str , list [str ]] = {
27+ # Attention q/k/v fusion: q_proj is primary
28+ r"^(?P<prefix>.+?)\.(?P<attn>attn|attention|self_attn|self_attention)"
29+ r"\.q_proj\.weight$" : [
30+ r"{prefix}.{attn}.k_proj.weight" ,
31+ r"{prefix}.{attn}.v_proj.weight" ,
32+ ],
33+ # MLA attention fusion: wq_a is primary
34+ r"^(?P<prefix>.+?)\.(?P<attn>attn|attention|self_attn)\.wq_a\.weight$" : [
35+ r"{prefix}.{attn}.wkv_a_with_mqa.weight" ,
36+ ],
37+ # MLP gate/up fusion: gate_proj is primary
38+ r"^(?P<prefix>.+?)\.(?P<mlp>mlp|feed_forward)\.gate_proj\.weight$" : [
39+ r"{prefix}.{mlp}.up_proj.weight" ,
40+ ],
41+ # MoE w1/w3 fusion: w1 is primary
42+ r"^(?P<prefix>.+?)\.w1\.weight$" : [
43+ r"{prefix}.w3.weight" ,
44+ ],
45+ }
10146
102- DEFAULT_FUSED_MAPPINGS = [
47+ # List-of-lists format used by get_fused_names and validate.py
48+ _DEFAULT_FUSED_MAPPINGS_LIST = [
10349 [
10450 r"re:.*(attn|attention)\.q_proj\.weight$" ,
10551 r"re:.*(attn|attention)\.k_proj\.weight$" ,
@@ -124,11 +70,65 @@ def get_fused_names(
12470) -> tuple [list [MatchedNamesSet ], list [MatchedNamesSet ]]:
12571 matched = []
12672 unmatched = []
127- for mapping in DEFAULT_FUSED_MAPPINGS :
73+ for mapping in _DEFAULT_FUSED_MAPPINGS_LIST :
12874 _matched , _unmatched = match_names_set_eager (tensor_names , mapping )
129-
13075 matched .extend (_matched )
13176 if _unmatched is not None :
13277 unmatched .append (_unmatched )
133-
13478 return matched , unmatched
79+
80+
81+ def build_inverse_weights_map (
82+ shard_name : str ,
83+ weight_map : dict [str , str ],
84+ model_files : dict [str , str ],
85+ ) -> dict [str , list [str ]]:
86+ """
87+ For a given output shard, precompute exactly which tensors to load from
88+ which source files — including fused partner tensors from other shards.
89+
90+ Uses DEFAULT_FUSED_MAPPINGS with primary->partners structure to ensure
91+ only the shard owning the primary tensor fetches its partners, preventing
92+ double reads when fused weights span multiple shards.
93+
94+ Example — given:
95+ shard0: [q_proj.weight, ...] <- primary owner
96+ shard1: [k_proj.weight, v_proj.weight, ...] <- partners
97+
98+ Only shard0's inverse_weights_map will include shard1's tensors.
99+ Shard1's job loads only its own native tensors.
100+
101+ :param shard_name: the shard filename this job will process and save
102+ :param weight_map: tensor name -> shard filename (from safetensors.index.json)
103+ :param model_files: shard filename -> resolved absolute path
104+ :return: {resolved_file_path: [tensor_names_to_load]}
105+ """
106+ own_resolved = model_files [shard_name ]
107+ native_tensors = [t for t , s in weight_map .items () if s == shard_name ]
108+
109+ inverse_weights_map : dict [str , list [str ]] = defaultdict (list )
110+ inverse_weights_map [own_resolved ] = list (native_tensors )
111+
112+ # For each native tensor that matches a primary pattern, fetch its partners
113+ for name in native_tensors :
114+ for primary_pattern , partner_templates in DEFAULT_FUSED_MAPPINGS .items ():
115+ match = re .match (primary_pattern , name )
116+ if match is None :
117+ continue
118+
119+ # Build partner names using named groups from the match
120+ for partner_template in partner_templates :
121+ partner_name = partner_template .format (** match .groupdict ())
122+
123+ partner_shard = weight_map .get (partner_name )
124+ if partner_shard is None or partner_shard == shard_name :
125+ continue # same shard or not found
126+
127+ partner_resolved = model_files .get (partner_shard )
128+ if partner_resolved is None :
129+ continue
130+
131+ if partner_name not in inverse_weights_map [partner_resolved ]:
132+ inverse_weights_map [partner_resolved ].append (partner_name )
133+
134+ return dict (inverse_weights_map )
0 commit comments