|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: LicenseRef-Apache2 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +"""Weight conversion between EvolutionaryScale ESMC and NVEsmc (TransformerEngine) formats. |
| 17 | +
|
| 18 | +The ESMC reference model uses: |
| 19 | +- QKV as a Sequential(LayerNorm, Linear) producing [Q||K||V] concatenated |
| 20 | +- QK LayerNorm over full d_model dimension (960) |
| 21 | +- Residue scaling: divides attn output and FFN output by sqrt(n_layers/36) |
| 22 | +- FFN as Sequential(LayerNorm, Linear, SwiGLU, Linear) |
| 23 | +
|
| 24 | +TE TransformerLayer uses: |
| 25 | +- Fused LayerNormLinear for QKV with interleaved weights [h1_q, h1_k, h1_v, h2_q, ...] |
| 26 | +- Per-head QK LayerNorm (head_dim=64) |
| 27 | +- No native residue scaling (absorbed into projection weights) |
| 28 | +- Fused LayerNormMLP |
| 29 | +""" |
| 30 | + |
| 31 | +import math |
| 32 | + |
| 33 | +import torch |
| 34 | +from modeling_esmc_te import NVEsmcConfig, NVEsmcForMaskedLM |
| 35 | + |
| 36 | + |
# Direct 1:1 weight mappings (no transforms needed).
# Every entry maps a key to itself, so the table is generated from one flat
# list of shared key patterns; "*" is a per-layer wildcard expanded by the
# consumer of this table.
mapping = {
    key: key
    for key in (
        "esmc.embed_tokens.weight",
        # Per-layer attention LayerNorm
        "esmc.layers.*.self_attention.layernorm_qkv.layer_norm_weight",
        "esmc.layers.*.self_attention.layernorm_qkv.layer_norm_bias",
        # Per-layer MLP LayerNorm
        "esmc.layers.*.layernorm_mlp.layer_norm_weight",
        "esmc.layers.*.layernorm_mlp.layer_norm_bias",
        # Per-layer QKV weight
        "esmc.layers.*.self_attention.layernorm_qkv.weight",
        # Per-layer attention output projection
        "esmc.layers.*.self_attention.proj.weight",
        # Per-layer MLP weights
        "esmc.layers.*.layernorm_mlp.fc1_weight",
        "esmc.layers.*.layernorm_mlp.fc2_weight",
        # Per-layer QK norm
        "esmc.layers.*.self_attention.q_norm.weight",
        "esmc.layers.*.self_attention.q_norm.bias",
        "esmc.layers.*.self_attention.k_norm.weight",
        "esmc.layers.*.self_attention.k_norm.bias",
        # Final norm
        "esmc.norm.weight",
        "esmc.norm.bias",
        # Sequence head
        "sequence_head.dense.weight",
        "sequence_head.dense.bias",
        "sequence_head.decoder.layer_norm_weight",
        "sequence_head.decoder.layer_norm_bias",
        "sequence_head.decoder.weight",
        "sequence_head.decoder.bias",
    )
}
| 69 | + |
| 70 | + |
| 71 | +def _reinterleave_qkv(weight, num_heads, head_dim): |
| 72 | + """Reinterleave QKV weight from [Q||K||V] to TE's interleaved format. |
| 73 | +
|
| 74 | + Input: [3*num_heads*head_dim, hidden_size] arranged as [Q, K, V] |
| 75 | + Output: [3*num_heads*head_dim, hidden_size] arranged as [h1_q, h1_k, h1_v, h2_q, ...] |
| 76 | + """ |
| 77 | + # Reshape to [3, num_heads, head_dim, hidden_size] |
| 78 | + qkv = weight.reshape(3, num_heads, head_dim, -1) |
| 79 | + # Transpose to [num_heads, 3, head_dim, hidden_size] |
| 80 | + qkv = qkv.permute(1, 0, 2, 3) |
| 81 | + # Flatten back to [3*num_heads*head_dim, hidden_size] |
| 82 | + return qkv.reshape(-1, weight.shape[-1]) |
| 83 | + |
| 84 | + |
| 85 | +def _deinterleave_qkv(weight, num_heads, head_dim): |
| 86 | + """Reverse of _reinterleave_qkv: from TE interleaved to [Q||K||V] concatenated.""" |
| 87 | + # Reshape to [num_heads, 3, head_dim, hidden_size] |
| 88 | + qkv = weight.reshape(num_heads, 3, head_dim, -1) |
| 89 | + # Transpose to [3, num_heads, head_dim, hidden_size] |
| 90 | + qkv = qkv.permute(1, 0, 2, 3) |
| 91 | + # Flatten back to [3*num_heads*head_dim, hidden_size] |
| 92 | + return qkv.reshape(-1, weight.shape[-1]) |
| 93 | + |
| 94 | + |
def convert_esmc_to_te(ref_state_dict: dict[str, torch.Tensor], config: NVEsmcConfig) -> NVEsmcForMaskedLM:
    """Convert EvolutionaryScale ESMC weights to NVEsmc (TransformerEngine) format.

    This performs:
    1. Key remapping from ESMC ref format to TE format
    2. QKV weight reinterleaving for TE's fused attention
    3. QK norm weight reshaping from [d_model] to per-head [head_dim]
    4. Residue scaling absorption into output projection and fc2 weights

    Args:
        ref_state_dict: State dict from the EvolutionaryScale ESMC model (.pth file).
        config: NVEsmcConfig for the target TE model.

    Returns:
        NVEsmcForMaskedLM with converted weights.
    """
    num_heads = config.num_attention_heads
    head_dim = config.hidden_size // num_heads
    num_layers = config.num_hidden_layers
    hidden_size = config.hidden_size
    # ESMC divides each residual-branch output by sqrt(n_layers / 36); since
    # proj and fc2 are linear, dividing their weights by the same factor is
    # mathematically equivalent and lets TE run without a native scaling hook.
    scale_factor = math.sqrt(num_layers / 36)

    te_state_dict = {}

    # Embedding
    te_state_dict["esmc.embed_tokens.weight"] = ref_state_dict["embed.weight"]

    for layer_idx in range(num_layers):
        ref_prefix = f"transformer.blocks.{layer_idx}"
        te_prefix = f"esmc.layers.{layer_idx}"

        # Attention LayerNorm (pre-QKV): ref stores it as element 0 of a
        # Sequential(LayerNorm, Linear); TE fuses it into layernorm_qkv.
        te_state_dict[f"{te_prefix}.self_attention.layernorm_qkv.layer_norm_weight"] = ref_state_dict[
            f"{ref_prefix}.attn.layernorm_qkv.0.weight"
        ]
        te_state_dict[f"{te_prefix}.self_attention.layernorm_qkv.layer_norm_bias"] = ref_state_dict[
            f"{ref_prefix}.attn.layernorm_qkv.0.bias"
        ]

        # QKV weight: reinterleave from [Q||K||V] to TE's interleaved format
        qkv_weight = ref_state_dict[f"{ref_prefix}.attn.layernorm_qkv.1.weight"]
        te_state_dict[f"{te_prefix}.self_attention.layernorm_qkv.weight"] = _reinterleave_qkv(
            qkv_weight, num_heads, head_dim
        )

        # QK norm: reshape from full d_model [960] to per-head [64]
        # ESMC applies LayerNorm(d_model) before reshape to heads.
        # TE applies per-head LayerNorm(head_dim). We take each head's portion.
        q_ln_weight = ref_state_dict[f"{ref_prefix}.attn.q_ln.weight"]
        k_ln_weight = ref_state_dict[f"{ref_prefix}.attn.k_ln.weight"]
        # Take the first head's portion as representative (all heads share same init)
        # NOTE(review): this is exact only if every head's slice of the full-width
        # LN weight is identical; for a trained checkpoint with head-dependent
        # q_ln/k_ln values the conversion is lossy — confirm against the checkpoint.
        # NOTE(review): ref q_ln/k_ln biases are never read here; zeros assume the
        # reference LayerNorms have bias=False — verify.
        te_state_dict[f"{te_prefix}.self_attention.q_norm.weight"] = q_ln_weight[:head_dim]
        te_state_dict[f"{te_prefix}.self_attention.q_norm.bias"] = torch.zeros(head_dim, dtype=q_ln_weight.dtype)
        te_state_dict[f"{te_prefix}.self_attention.k_norm.weight"] = k_ln_weight[:head_dim]
        te_state_dict[f"{te_prefix}.self_attention.k_norm.bias"] = torch.zeros(head_dim, dtype=k_ln_weight.dtype)

        # Attention output projection: absorb residue scaling
        out_proj_weight = ref_state_dict[f"{ref_prefix}.attn.out_proj.weight"]
        te_state_dict[f"{te_prefix}.self_attention.proj.weight"] = out_proj_weight / scale_factor

        # FFN LayerNorm (pre-MLP)
        te_state_dict[f"{te_prefix}.layernorm_mlp.layer_norm_weight"] = ref_state_dict[f"{ref_prefix}.ffn.0.weight"]
        te_state_dict[f"{te_prefix}.layernorm_mlp.layer_norm_bias"] = ref_state_dict[f"{ref_prefix}.ffn.0.bias"]

        # FFN fc1 (gate + up proj concatenated for SwiGLU)
        te_state_dict[f"{te_prefix}.layernorm_mlp.fc1_weight"] = ref_state_dict[f"{ref_prefix}.ffn.1.weight"]

        # FFN fc2 (down proj): absorb residue scaling
        fc2_weight = ref_state_dict[f"{ref_prefix}.ffn.3.weight"]
        te_state_dict[f"{te_prefix}.layernorm_mlp.fc2_weight"] = fc2_weight / scale_factor

    # Final LayerNorm
    te_state_dict["esmc.norm.weight"] = ref_state_dict["transformer.norm.weight"]
    # ESMC final norm has bias=False, but TE LayerNorm always has bias. Set to zeros.
    te_state_dict["esmc.norm.bias"] = torch.zeros(hidden_size, dtype=ref_state_dict["transformer.norm.weight"].dtype)

    # Sequence head (RegressionHead): Linear -> GELU -> LayerNorm -> Linear
    # ref: sequence_head.0 = Linear(960, 960)
    te_state_dict["sequence_head.dense.weight"] = ref_state_dict["sequence_head.0.weight"]
    te_state_dict["sequence_head.dense.bias"] = ref_state_dict["sequence_head.0.bias"]
    # ref: sequence_head.2 = LayerNorm(960), sequence_head.3 = Linear(960, 64)
    # TE LayerNormLinear fuses both
    te_state_dict["sequence_head.decoder.layer_norm_weight"] = ref_state_dict["sequence_head.2.weight"]
    te_state_dict["sequence_head.decoder.layer_norm_bias"] = ref_state_dict["sequence_head.2.bias"]
    te_state_dict["sequence_head.decoder.weight"] = ref_state_dict["sequence_head.3.weight"]
    te_state_dict["sequence_head.decoder.bias"] = ref_state_dict["sequence_head.3.bias"]

    # Build the TE model and load state dict
    # Meta device avoids allocating parameter memory; real tensors are attached
    # below via load_state_dict(assign=True).
    with torch.device("meta"):
        model_te = NVEsmcForMaskedLM(config)

    target_state = model_te.state_dict()

    # Directly load the pre-transformed state dict
    # (_extra_state entries are TE-internal FP8 metadata and must be left alone).
    for key in list(target_state.keys()):
        if key.endswith("_extra_state"):
            continue
        if key in te_state_dict:
            target_state[key] = te_state_dict[key]

    # Load into model
    # NOTE(review): strict=False silently skips any converted key that does not
    # exist on the TE model — a typo in a key above would go unnoticed; consider
    # checking the returned IncompatibleKeys.
    model_te.load_state_dict(target_state, strict=False, assign=True)
    model_te.tie_weights()

    return model_te
| 200 | + |
| 201 | + |
def convert_esmc_te_to_ref(model_te: NVEsmcForMaskedLM) -> dict[str, torch.Tensor]:
    """Convert NVEsmc (TransformerEngine) weights back to EvolutionaryScale ESMC format.

    This reverses the transformations from convert_esmc_to_te:
    1. QKV weight deinterleaving (TE interleaved -> [Q||K||V] concatenated)
    2. QK norm weight expansion from per-head [head_dim] to [d_model] by tiling
       the per-head vector across heads
    3. Residue scaling removal from projection weights

    Args:
        model_te: NVEsmcForMaskedLM model with TE weights.

    Returns:
        State dict in EvolutionaryScale ESMC format.
    """
    cfg = model_te.config
    heads = cfg.num_attention_heads
    dim_per_head = cfg.hidden_size // heads
    n_layers = cfg.num_hidden_layers
    residue_scale = math.sqrt(n_layers / 36)

    source = model_te.state_dict()

    # Embedding
    out = {"embed.weight": source["esmc.embed_tokens.weight"]}

    for i in range(n_layers):
        te = f"esmc.layers.{i}"
        ref = f"transformer.blocks.{i}"

        out.update(
            {
                # Pre-QKV LayerNorm (element 0 of ref's Sequential).
                f"{ref}.attn.layernorm_qkv.0.weight": source[f"{te}.self_attention.layernorm_qkv.layer_norm_weight"],
                f"{ref}.attn.layernorm_qkv.0.bias": source[f"{te}.self_attention.layernorm_qkv.layer_norm_bias"],
                # QKV weight back to [Q||K||V] row order.
                f"{ref}.attn.layernorm_qkv.1.weight": _deinterleave_qkv(
                    source[f"{te}.self_attention.layernorm_qkv.weight"], heads, dim_per_head
                ),
                # Per-head QK norm tiled back out to the full model width.
                f"{ref}.attn.q_ln.weight": source[f"{te}.self_attention.q_norm.weight"].repeat(heads),
                f"{ref}.attn.k_ln.weight": source[f"{te}.self_attention.k_norm.weight"].repeat(heads),
                # Undo the residue scaling that was absorbed into the projection.
                f"{ref}.attn.out_proj.weight": source[f"{te}.self_attention.proj.weight"] * residue_scale,
                # Pre-MLP LayerNorm.
                f"{ref}.ffn.0.weight": source[f"{te}.layernorm_mlp.layer_norm_weight"],
                f"{ref}.ffn.0.bias": source[f"{te}.layernorm_mlp.layer_norm_bias"],
                # SwiGLU gate + up projection.
                f"{ref}.ffn.1.weight": source[f"{te}.layernorm_mlp.fc1_weight"],
                # Down projection with residue scaling undone.
                f"{ref}.ffn.3.weight": source[f"{te}.layernorm_mlp.fc2_weight"] * residue_scale,
            }
        )

    # Final LayerNorm (the reference model has no bias here, so none is emitted).
    out["transformer.norm.weight"] = source["esmc.norm.weight"]

    # Sequence head: Linear -> GELU -> LayerNorm -> Linear.
    out["sequence_head.0.weight"] = source["sequence_head.dense.weight"]
    out["sequence_head.0.bias"] = source["sequence_head.dense.bias"]
    out["sequence_head.2.weight"] = source["sequence_head.decoder.layer_norm_weight"]
    out["sequence_head.2.bias"] = source["sequence_head.decoder.layer_norm_bias"]
    out["sequence_head.3.weight"] = source["sequence_head.decoder.weight"]
    out["sequence_head.3.bias"] = source["sequence_head.decoder.bias"]

    return out
0 commit comments