Skip to content

Commit 2bd275f

Browse files
committed
initial commit for esm-c model code
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent c66bed5 commit 2bd275f

File tree

12 files changed

+3205
-0
lines changed

12 files changed

+3205
-0
lines changed
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Weight conversion between EvolutionaryScale ESMC and NVEsmc (TransformerEngine) formats.
17+
18+
The ESMC reference model uses:
19+
- QKV as a Sequential(LayerNorm, Linear) producing [Q||K||V] concatenated
20+
- QK LayerNorm over full d_model dimension (960)
21+
- Residue scaling: divides attn output and FFN output by sqrt(n_layers/36)
22+
- FFN as Sequential(LayerNorm, Linear, SwiGLU, Linear)
23+
24+
TE TransformerLayer uses:
25+
- Fused LayerNormLinear for QKV with interleaved weights [h1_q, h1_k, h1_v, h2_q, ...]
26+
- Per-head QK LayerNorm (head_dim=64)
27+
- No native residue scaling (absorbed into projection weights)
28+
- Fused LayerNormMLP
29+
"""
30+
31+
import math
32+
33+
import torch
34+
from modeling_esmc_te import NVEsmcConfig, NVEsmcForMaskedLM
35+
36+
37+
# Direct 1:1 weight mappings (no transforms needed).
# Every entry is an identity mapping (source key == target key), with "*"
# standing in for a per-layer index.
# NOTE(review): this table is not referenced by the conversion functions in
# this module — presumably it is consumed by an external checkpoint-mapping
# utility; confirm before removing or renaming.
mapping = {
    "esmc.embed_tokens.weight": "esmc.embed_tokens.weight",
    # Per-layer attention LayerNorm
    "esmc.layers.*.self_attention.layernorm_qkv.layer_norm_weight": "esmc.layers.*.self_attention.layernorm_qkv.layer_norm_weight",
    "esmc.layers.*.self_attention.layernorm_qkv.layer_norm_bias": "esmc.layers.*.self_attention.layernorm_qkv.layer_norm_bias",
    # Per-layer MLP LayerNorm
    "esmc.layers.*.layernorm_mlp.layer_norm_weight": "esmc.layers.*.layernorm_mlp.layer_norm_weight",
    "esmc.layers.*.layernorm_mlp.layer_norm_bias": "esmc.layers.*.layernorm_mlp.layer_norm_bias",
    # Per-layer QKV weight
    "esmc.layers.*.self_attention.layernorm_qkv.weight": "esmc.layers.*.self_attention.layernorm_qkv.weight",
    # Per-layer attention output projection
    "esmc.layers.*.self_attention.proj.weight": "esmc.layers.*.self_attention.proj.weight",
    # Per-layer MLP weights
    "esmc.layers.*.layernorm_mlp.fc1_weight": "esmc.layers.*.layernorm_mlp.fc1_weight",
    "esmc.layers.*.layernorm_mlp.fc2_weight": "esmc.layers.*.layernorm_mlp.fc2_weight",
    # Per-layer QK norm
    "esmc.layers.*.self_attention.q_norm.weight": "esmc.layers.*.self_attention.q_norm.weight",
    "esmc.layers.*.self_attention.q_norm.bias": "esmc.layers.*.self_attention.q_norm.bias",
    "esmc.layers.*.self_attention.k_norm.weight": "esmc.layers.*.self_attention.k_norm.weight",
    "esmc.layers.*.self_attention.k_norm.bias": "esmc.layers.*.self_attention.k_norm.bias",
    # Final norm
    "esmc.norm.weight": "esmc.norm.weight",
    "esmc.norm.bias": "esmc.norm.bias",
    # Sequence head
    "sequence_head.dense.weight": "sequence_head.dense.weight",
    "sequence_head.dense.bias": "sequence_head.dense.bias",
    "sequence_head.decoder.layer_norm_weight": "sequence_head.decoder.layer_norm_weight",
    "sequence_head.decoder.layer_norm_bias": "sequence_head.decoder.layer_norm_bias",
    "sequence_head.decoder.weight": "sequence_head.decoder.weight",
    "sequence_head.decoder.bias": "sequence_head.decoder.bias",
}
69+
70+
71+
def _reinterleave_qkv(weight, num_heads, head_dim):
72+
"""Reinterleave QKV weight from [Q||K||V] to TE's interleaved format.
73+
74+
Input: [3*num_heads*head_dim, hidden_size] arranged as [Q, K, V]
75+
Output: [3*num_heads*head_dim, hidden_size] arranged as [h1_q, h1_k, h1_v, h2_q, ...]
76+
"""
77+
# Reshape to [3, num_heads, head_dim, hidden_size]
78+
qkv = weight.reshape(3, num_heads, head_dim, -1)
79+
# Transpose to [num_heads, 3, head_dim, hidden_size]
80+
qkv = qkv.permute(1, 0, 2, 3)
81+
# Flatten back to [3*num_heads*head_dim, hidden_size]
82+
return qkv.reshape(-1, weight.shape[-1])
83+
84+
85+
def _deinterleave_qkv(weight, num_heads, head_dim):
86+
"""Reverse of _reinterleave_qkv: from TE interleaved to [Q||K||V] concatenated."""
87+
# Reshape to [num_heads, 3, head_dim, hidden_size]
88+
qkv = weight.reshape(num_heads, 3, head_dim, -1)
89+
# Transpose to [3, num_heads, head_dim, hidden_size]
90+
qkv = qkv.permute(1, 0, 2, 3)
91+
# Flatten back to [3*num_heads*head_dim, hidden_size]
92+
return qkv.reshape(-1, weight.shape[-1])
93+
94+
95+
def convert_esmc_to_te(ref_state_dict: dict[str, torch.Tensor], config: NVEsmcConfig) -> NVEsmcForMaskedLM:
    """Convert EvolutionaryScale ESMC weights to NVEsmc (TransformerEngine) format.

    This performs:
    1. Key remapping from ESMC ref format to TE format
    2. QKV weight reinterleaving for TE's fused attention
    3. QK norm weight reshaping from [d_model] to per-head [head_dim]
    4. Residue scaling absorption into output projection and fc2 weights

    Args:
        ref_state_dict: State dict from the EvolutionaryScale ESMC model (.pth file).
        config: NVEsmcConfig for the target TE model.

    Returns:
        NVEsmcForMaskedLM with converted weights.
    """
    num_heads = config.num_attention_heads
    head_dim = config.hidden_size // num_heads
    num_layers = config.num_hidden_layers
    hidden_size = config.hidden_size
    # ESMC divides attn/FFN outputs by sqrt(n_layers/36); this factor is folded
    # into proj and fc2 weights below so TE needs no runtime scaling.
    scale_factor = math.sqrt(num_layers / 36)

    te_state_dict = {}

    # Embedding
    te_state_dict["esmc.embed_tokens.weight"] = ref_state_dict["embed.weight"]

    for layer_idx in range(num_layers):
        ref_prefix = f"transformer.blocks.{layer_idx}"
        te_prefix = f"esmc.layers.{layer_idx}"

        # Attention LayerNorm (pre-QKV): ref stores it as Sequential index 0.
        te_state_dict[f"{te_prefix}.self_attention.layernorm_qkv.layer_norm_weight"] = ref_state_dict[
            f"{ref_prefix}.attn.layernorm_qkv.0.weight"
        ]
        te_state_dict[f"{te_prefix}.self_attention.layernorm_qkv.layer_norm_bias"] = ref_state_dict[
            f"{ref_prefix}.attn.layernorm_qkv.0.bias"
        ]

        # QKV weight: reinterleave from [Q||K||V] to TE's interleaved format
        qkv_weight = ref_state_dict[f"{ref_prefix}.attn.layernorm_qkv.1.weight"]
        te_state_dict[f"{te_prefix}.self_attention.layernorm_qkv.weight"] = _reinterleave_qkv(
            qkv_weight, num_heads, head_dim
        )

        # QK norm: reshape from full d_model [960] to per-head [64]
        # ESMC applies LayerNorm(d_model) before reshape to heads.
        # TE applies per-head LayerNorm(head_dim). We take each head's portion.
        q_ln_weight = ref_state_dict[f"{ref_prefix}.attn.q_ln.weight"]
        k_ln_weight = ref_state_dict[f"{ref_prefix}.attn.k_ln.weight"]
        # Take the first head's portion as representative (all heads share same init)
        # NOTE(review): this is lossy — if a trained checkpoint has per-head
        # differences in q_ln/k_ln, slices beyond head 0 are discarded, and the
        # per-head LayerNorm also normalizes over head_dim rather than d_model.
        # Confirm numerical equivalence against the reference model.
        te_state_dict[f"{te_prefix}.self_attention.q_norm.weight"] = q_ln_weight[:head_dim]
        # Ref q_ln/k_ln expose no bias key here; TE's norm expects one, so zeros.
        te_state_dict[f"{te_prefix}.self_attention.q_norm.bias"] = torch.zeros(head_dim, dtype=q_ln_weight.dtype)
        te_state_dict[f"{te_prefix}.self_attention.k_norm.weight"] = k_ln_weight[:head_dim]
        te_state_dict[f"{te_prefix}.self_attention.k_norm.bias"] = torch.zeros(head_dim, dtype=k_ln_weight.dtype)

        # Attention output projection: absorb residue scaling (divide once here
        # instead of dividing the activation at every forward pass).
        out_proj_weight = ref_state_dict[f"{ref_prefix}.attn.out_proj.weight"]
        te_state_dict[f"{te_prefix}.self_attention.proj.weight"] = out_proj_weight / scale_factor

        # FFN LayerNorm (pre-MLP): ref FFN is Sequential(LayerNorm, Linear, SwiGLU, Linear).
        te_state_dict[f"{te_prefix}.layernorm_mlp.layer_norm_weight"] = ref_state_dict[f"{ref_prefix}.ffn.0.weight"]
        te_state_dict[f"{te_prefix}.layernorm_mlp.layer_norm_bias"] = ref_state_dict[f"{ref_prefix}.ffn.0.bias"]

        # FFN fc1 (gate + up proj concatenated for SwiGLU)
        te_state_dict[f"{te_prefix}.layernorm_mlp.fc1_weight"] = ref_state_dict[f"{ref_prefix}.ffn.1.weight"]

        # FFN fc2 (down proj): absorb residue scaling
        fc2_weight = ref_state_dict[f"{ref_prefix}.ffn.3.weight"]
        te_state_dict[f"{te_prefix}.layernorm_mlp.fc2_weight"] = fc2_weight / scale_factor

    # Final LayerNorm
    te_state_dict["esmc.norm.weight"] = ref_state_dict["transformer.norm.weight"]
    # ESMC final norm has bias=False, but TE LayerNorm always has bias. Set to zeros.
    te_state_dict["esmc.norm.bias"] = torch.zeros(hidden_size, dtype=ref_state_dict["transformer.norm.weight"].dtype)

    # Sequence head (RegressionHead): Linear -> GELU -> LayerNorm -> Linear
    # ref: sequence_head.0 = Linear(960, 960)
    te_state_dict["sequence_head.dense.weight"] = ref_state_dict["sequence_head.0.weight"]
    te_state_dict["sequence_head.dense.bias"] = ref_state_dict["sequence_head.0.bias"]
    # ref: sequence_head.2 = LayerNorm(960), sequence_head.3 = Linear(960, 64)
    # TE LayerNormLinear fuses both
    te_state_dict["sequence_head.decoder.layer_norm_weight"] = ref_state_dict["sequence_head.2.weight"]
    te_state_dict["sequence_head.decoder.layer_norm_bias"] = ref_state_dict["sequence_head.2.bias"]
    te_state_dict["sequence_head.decoder.weight"] = ref_state_dict["sequence_head.3.weight"]
    te_state_dict["sequence_head.decoder.bias"] = ref_state_dict["sequence_head.3.bias"]

    # Build the TE model on the meta device (no real allocation); tensors are
    # materialized by the assign=True load below.
    with torch.device("meta"):
        model_te = NVEsmcForMaskedLM(config)

    target_state = model_te.state_dict()

    # Directly load the pre-transformed state dict, skipping TE's serialized
    # "_extra_state" entries (kept as the model's own values).
    for key in list(target_state.keys()):
        if key.endswith("_extra_state"):
            continue
        if key in te_state_dict:
            target_state[key] = te_state_dict[key]

    # Load into model.
    # NOTE(review): strict=False hides any keys that failed to map; consider
    # checking load_state_dict's returned missing/unexpected keys.
    model_te.load_state_dict(target_state, strict=False, assign=True)
    model_te.tie_weights()

    return model_te
200+
201+
202+
def convert_esmc_te_to_ref(model_te: NVEsmcForMaskedLM) -> dict[str, torch.Tensor]:
    """Convert NVEsmc (TransformerEngine) weights back to EvolutionaryScale ESMC format.

    This reverses the transformations from convert_esmc_to_te:
    1. QKV weight deinterleaving
    2. QK norm weight expansion from per-head [head_dim] to [d_model]
    3. Residue scaling removal from projection weights

    Args:
        model_te: NVEsmcForMaskedLM model with TE weights.

    Returns:
        State dict in EvolutionaryScale ESMC format.
    """
    config = model_te.config
    num_heads = config.num_attention_heads
    head_dim = config.hidden_size // num_heads
    num_layers = config.num_hidden_layers
    # Must match the factor used in convert_esmc_to_te so the scaling cancels.
    scale_factor = math.sqrt(num_layers / 36)

    te_sd = model_te.state_dict()
    ref_state_dict = {}

    # Embedding
    ref_state_dict["embed.weight"] = te_sd["esmc.embed_tokens.weight"]

    for layer_idx in range(num_layers):
        te_prefix = f"esmc.layers.{layer_idx}"
        ref_prefix = f"transformer.blocks.{layer_idx}"

        # Attention LayerNorm
        ref_state_dict[f"{ref_prefix}.attn.layernorm_qkv.0.weight"] = te_sd[
            f"{te_prefix}.self_attention.layernorm_qkv.layer_norm_weight"
        ]
        ref_state_dict[f"{ref_prefix}.attn.layernorm_qkv.0.bias"] = te_sd[
            f"{te_prefix}.self_attention.layernorm_qkv.layer_norm_bias"
        ]

        # QKV weight: deinterleave
        qkv_weight = te_sd[f"{te_prefix}.self_attention.layernorm_qkv.weight"]
        ref_state_dict[f"{ref_prefix}.attn.layernorm_qkv.1.weight"] = _deinterleave_qkv(
            qkv_weight, num_heads, head_dim
        )

        # QK norm: expand from per-head [64] to full d_model [960]
        # NOTE(review): tiling the single per-head vector round-trips exactly
        # only if all heads shared identical norm weights in the original
        # checkpoint (the assumption made by convert_esmc_to_te); per-head
        # biases from the TE model are dropped here.
        q_norm_weight = te_sd[f"{te_prefix}.self_attention.q_norm.weight"]
        k_norm_weight = te_sd[f"{te_prefix}.self_attention.k_norm.weight"]
        ref_state_dict[f"{ref_prefix}.attn.q_ln.weight"] = q_norm_weight.repeat(num_heads)
        ref_state_dict[f"{ref_prefix}.attn.k_ln.weight"] = k_norm_weight.repeat(num_heads)

        # Attention output projection: reverse scaling
        ref_state_dict[f"{ref_prefix}.attn.out_proj.weight"] = (
            te_sd[f"{te_prefix}.self_attention.proj.weight"] * scale_factor
        )

        # FFN LayerNorm
        ref_state_dict[f"{ref_prefix}.ffn.0.weight"] = te_sd[f"{te_prefix}.layernorm_mlp.layer_norm_weight"]
        ref_state_dict[f"{ref_prefix}.ffn.0.bias"] = te_sd[f"{te_prefix}.layernorm_mlp.layer_norm_bias"]

        # FFN fc1
        ref_state_dict[f"{ref_prefix}.ffn.1.weight"] = te_sd[f"{te_prefix}.layernorm_mlp.fc1_weight"]

        # FFN fc2: reverse scaling
        ref_state_dict[f"{ref_prefix}.ffn.3.weight"] = te_sd[f"{te_prefix}.layernorm_mlp.fc2_weight"] * scale_factor

    # Final LayerNorm (no bias in ref — the TE model's zero bias is dropped)
    ref_state_dict["transformer.norm.weight"] = te_sd["esmc.norm.weight"]

    # Sequence head: unfuse TE's LayerNormLinear back into Sequential indices 2/3.
    ref_state_dict["sequence_head.0.weight"] = te_sd["sequence_head.dense.weight"]
    ref_state_dict["sequence_head.0.bias"] = te_sd["sequence_head.dense.bias"]
    ref_state_dict["sequence_head.2.weight"] = te_sd["sequence_head.decoder.layer_norm_weight"]
    ref_state_dict["sequence_head.2.bias"] = te_sd["sequence_head.decoder.layer_norm_bias"]
    ref_state_dict["sequence_head.3.weight"] = te_sd["sequence_head.decoder.weight"]
    ref_state_dict["sequence_head.3.bias"] = te_sd["sequence_head.decoder.bias"]

    return ref_state_dict
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Export ESMC checkpoint to HuggingFace-compatible format with TransformerEngine layers.
17+
18+
This script:
19+
1. Loads the EvolutionaryScale ESMC-300M pretrained weights
20+
2. Converts them to TransformerEngine format
21+
3. Saves the converted model for use with HuggingFace's `AutoModel.from_pretrained()`
22+
"""
23+
24+
import json
25+
import shutil
26+
from pathlib import Path
27+
28+
import convert
29+
from modeling_esmc_te import AUTO_MAP, NVEsmcConfig
30+
31+
32+
def export_esmc_checkpoint(export_path: Path):
    """Export the ESMC-300M model to a TE checkpoint.

    Args:
        export_path: Directory to save the exported checkpoint.
    """
    from esm.pretrained import ESMC_300M_202412

    # Pull the reference weights on CPU (saves GPU memory) and free the module
    # immediately — only the state dict is needed.
    reference = ESMC_300M_202412(device="cpu", use_flash_attn=False)
    reference_weights = reference.state_dict()
    del reference

    # Architecture hyperparameters matching the ESMC-300M checkpoint.
    config = NVEsmcConfig(
        vocab_size=64,
        hidden_size=960,
        num_hidden_layers=30,
        num_attention_heads=15,
        intermediate_size=2560,
    )

    # Convert to the TE layout and write the HuggingFace-style checkpoint.
    model_te = convert.convert_esmc_to_te(reference_weights, config)
    model_te.to("cpu")
    model_te.save_pretrained(export_path)

    # Inject auto_map into the saved config so AutoModel.from_pretrained
    # resolves the custom model classes.
    config_file = export_path / "config.json"
    config_json = json.loads(config_file.read_text())
    config_json["auto_map"] = AUTO_MAP
    config_file.write_text(json.dumps(config_json, indent=2, sort_keys=True))

    # Ship the modeling code alongside the weights for standalone loading.
    shutil.copy("modeling_esmc_te.py", export_path / "modeling_esmc_te.py")

    # Save the tokenizer next to the model.
    from esm.tokenization import EsmSequenceTokenizer

    EsmSequenceTokenizer().save_pretrained(export_path)
76+
77+
78+
if __name__ == "__main__":
    # Default export destination, relative to the current working directory.
    destination = Path("checkpoint_export")
    export_esmc_checkpoint(destination)

0 commit comments

Comments
 (0)