Skip to content

Commit c51b8b5

Browse files
committed
adds low-precision (FP8/FP4) training support
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
1 parent d2d962a commit c51b8b5

File tree

6 files changed

+280
-64
lines changed

6 files changed

+280
-64
lines changed

bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,19 @@ dataset:
1717
wandb_init_args:
1818
name: ???
1919

20+
# TransformerEngine FP8 config
21+
fp8_config:
22+
enabled: false
23+
fp8_recipe: transformer_engine.common.recipe.DelayedScaling
24+
fp8_format: "HYBRID"
25+
fp8_recipe_kwargs: {}
26+
27+
fp4_config:
28+
enabled: false
29+
fp4_recipe: transformer_engine.common.recipe.NVFP4BlockScaling
30+
fp4_format: "E2M1"
31+
fp4_recipe_kwargs: {}
32+
2033
# Optimizer config
2134
adamw_kwargs:
2235
lr: 4e-4
@@ -40,3 +53,14 @@ checkpoint:
4053

4154
logger:
4255
frequency: 100
56+
57+
quant_stats_config:
58+
enabled: false
59+
quant_stats_file: ./fp8_debugging_stats.yaml
60+
quant_log_dir: ./log_quant_stats
61+
log_to_wandb: false
62+
63+
# Note: Layers are specified 1-indexed here and converted to 0-indexed at runtime.
64+
fp8_layers: null
65+
fp4_layers: null
66+
use_fp32_master_weights: null

bionemo-recipes/recipes/codonfm_native_te/modeling_codonfm_te.py

Lines changed: 154 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,14 @@
1919

2020
import json
2121
import math
22+
import warnings
23+
from contextlib import nullcontext
2224
from dataclasses import asdict, dataclass
23-
from typing import Optional
25+
from typing import ContextManager, Optional
2426

2527
import torch
2628
import torch.nn as nn
29+
import transformer_engine.common.recipe
2730
import transformer_engine.pytorch
2831
from torch.nn import CrossEntropyLoss
2932
from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
@@ -50,6 +53,9 @@ class CodonFMConfig:
5053
# TE-specific options
5154
qkv_weight_interleaved: bool = True
5255
fuse_qkv_params: bool = True
56+
# Layer-wise precision options
57+
layer_precision: list[str | None] | None = None
58+
use_quantized_model_init: bool = False
5359

5460
def __post_init__(self):
5561
"""Validate configuration."""
@@ -60,6 +66,15 @@ def __post_init__(self):
6066
)
6167
if self.hidden_act not in ("gelu", "relu", "silu"):
6268
raise ValueError(f"hidden_act must be one of: gelu, relu, silu, got {self.hidden_act}")
69+
if self.layer_precision is not None:
70+
if len(self.layer_precision) != self.num_hidden_layers:
71+
raise ValueError(
72+
f"layer_precision must be a list of length {self.num_hidden_layers}, "
73+
f"got {len(self.layer_precision)}"
74+
)
75+
for precision in self.layer_precision:
76+
if precision not in {"fp8", "fp4", None}:
77+
raise ValueError(f'layer_precision element must be "fp8", "fp4", or None, got {precision!r}')
6378

6479
def save_json(self, path: str):
6580
"""Save config as JSON."""
@@ -142,44 +157,111 @@ def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
142157
class CodonFMEncoder(nn.Module):
143158
"""CodonFM encoder using standard TransformerEngine TransformerLayer."""
144159

145-
def __init__(self, config: CodonFMConfig):
160+
def __init__(
161+
self,
162+
config: CodonFMConfig,
163+
fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
164+
fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
165+
):
146166
"""Initialize the encoder.
147167
148168
Args:
149169
config: Model configuration.
170+
fp8_recipe: The FP8 recipe for the encoder.
171+
fp4_recipe: The FP4 recipe for the encoder.
150172
"""
151173
super().__init__()
152174
self.config = config
175+
self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe
176+
self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe
177+
178+
if self.config.layer_precision is None:
179+
if fp8_recipe is not None and fp4_recipe is not None:
180+
raise RuntimeError("Both FP8 and FP4 recipes provided, but no layer precision provided.")
181+
if fp8_recipe is not None:
182+
warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning)
183+
self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers
184+
elif fp4_recipe is not None:
185+
raise RuntimeError(
186+
"FP4 recipe provided but no layer_precision configured. "
187+
"Set layer_precision explicitly when using FP4."
188+
)
189+
190+
if self.config.layer_precision is not None and "fp4" in self.config.layer_precision and fp4_recipe is None:
191+
raise RuntimeError("layer_precision contains 'fp4' entries but no fp4_recipe was provided.")
153192

154193
device = "meta" if torch.get_default_device() == torch.device("meta") else "cpu"
155194

156-
self.layers = nn.ModuleList(
157-
[
158-
transformer_engine.pytorch.TransformerLayer(
159-
hidden_size=config.hidden_size,
160-
ffn_hidden_size=config.intermediate_size,
161-
num_attention_heads=config.num_attention_heads,
162-
layernorm_epsilon=config.layer_norm_eps,
163-
hidden_dropout=config.hidden_dropout_prob,
164-
attention_dropout=config.attention_probs_dropout_prob,
165-
qkv_weight_interleaved=config.qkv_weight_interleaved,
166-
layer_number=i + 1,
167-
layer_type="encoder",
168-
self_attn_mask_type="padding",
169-
activation=config.hidden_act,
170-
attn_input_format=config.attn_input_format,
171-
seq_length=config.max_position_embeddings,
172-
num_gqa_groups=config.num_attention_heads,
173-
fuse_qkv_params=config.fuse_qkv_params,
174-
window_size=(-1, -1),
175-
device=device,
195+
layers: list[transformer_engine.pytorch.TransformerLayer] = []
196+
for i in range(config.num_hidden_layers):
197+
with self.get_autocast_context(i, init=True):
198+
layers.append(
199+
transformer_engine.pytorch.TransformerLayer(
200+
hidden_size=config.hidden_size,
201+
ffn_hidden_size=config.intermediate_size,
202+
num_attention_heads=config.num_attention_heads,
203+
layernorm_epsilon=config.layer_norm_eps,
204+
hidden_dropout=config.hidden_dropout_prob,
205+
attention_dropout=config.attention_probs_dropout_prob,
206+
qkv_weight_interleaved=config.qkv_weight_interleaved,
207+
layer_number=i + 1,
208+
layer_type="encoder",
209+
self_attn_mask_type="padding",
210+
activation=config.hidden_act,
211+
attn_input_format=config.attn_input_format,
212+
seq_length=config.max_position_embeddings,
213+
num_gqa_groups=config.num_attention_heads,
214+
fuse_qkv_params=config.fuse_qkv_params,
215+
window_size=(-1, -1),
216+
device=device,
217+
)
176218
)
177-
for i in range(config.num_hidden_layers)
178-
]
179-
)
180219

220+
self.layers = nn.ModuleList(layers)
181221
self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
182222

223+
def get_autocast_context(
224+
self, layer_number: int | None, init: bool = False, outer: bool = False
225+
) -> ContextManager:
226+
"""Return the appropriate TE autocast context manager for a given layer.
227+
228+
Handles both the quantized_model_init during layer creation and the te.autocast() during forward.
229+
230+
Args:
231+
layer_number: The 0-indexed layer number.
232+
init: Whether to return a ``quantized_model_init`` context for layer initialization.
233+
outer: Whether to return a global te.autocast() context to wrap the entire encoder stack.
234+
"""
235+
if self.config.layer_precision is None:
236+
return nullcontext()
237+
238+
if outer:
239+
if "fp8" not in self.config.layer_precision:
240+
return nullcontext()
241+
if self._fp8_recipe is None:
242+
warnings.warn("No FP8 recipe provided, using default recipe.", UserWarning)
243+
return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp8_recipe)
244+
245+
precision = self.config.layer_precision[layer_number]
246+
recipe = {"fp8": self._fp8_recipe, "fp4": self._fp4_recipe}.get(precision)
247+
248+
if init and self.config.use_quantized_model_init:
249+
if precision == "fp4" and recipe is None:
250+
raise RuntimeError("No FP4 recipe provided, but layer precision is set to FP4.")
251+
if precision in ("fp8", "fp4"):
252+
return transformer_engine.pytorch.quantized_model_init(recipe=recipe)
253+
return nullcontext()
254+
255+
if precision == "fp8":
256+
if recipe is None:
257+
warnings.warn("No FP8 recipe provided, using default recipe.", UserWarning)
258+
return transformer_engine.pytorch.autocast(enabled=True, recipe=recipe)
259+
if precision == "fp4":
260+
if recipe is None:
261+
raise RuntimeError("No FP4 recipe provided, but layer precision is set to FP4.")
262+
return transformer_engine.pytorch.autocast(enabled=True, recipe=recipe)
263+
return transformer_engine.pytorch.autocast(enabled=False)
264+
183265
def forward(
184266
self,
185267
hidden_states: torch.Tensor,
@@ -203,23 +285,25 @@ def forward(
203285
te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings)
204286
te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True)
205287

206-
for layer_module in self.layers:
207-
if self.config.attn_input_format == "bshd":
208-
hidden_states = layer_module(
209-
hidden_states,
210-
attention_mask=attention_mask,
211-
rotary_pos_emb=te_rope_emb,
212-
)
213-
else:
214-
hidden_states = layer_module(
215-
hidden_states,
216-
attention_mask=None,
217-
rotary_pos_emb=te_rope_emb,
218-
cu_seqlens_q=kwargs.get("cu_seq_lens_q"),
219-
cu_seqlens_kv=kwargs.get("cu_seq_lens_k"),
220-
max_seqlen_q=kwargs.get("max_length_q"),
221-
max_seqlen_kv=kwargs.get("max_length_k"),
222-
)
288+
with self.get_autocast_context(None, outer=True):
289+
for layer_idx, layer_module in enumerate(self.layers):
290+
with self.get_autocast_context(layer_idx):
291+
if self.config.attn_input_format == "bshd":
292+
hidden_states = layer_module(
293+
hidden_states,
294+
attention_mask=attention_mask,
295+
rotary_pos_emb=te_rope_emb,
296+
)
297+
else:
298+
hidden_states = layer_module(
299+
hidden_states,
300+
attention_mask=None,
301+
rotary_pos_emb=te_rope_emb,
302+
cu_seqlens_q=kwargs.get("cu_seq_lens_q"),
303+
cu_seqlens_kv=kwargs.get("cu_seq_lens_k"),
304+
max_seqlen_q=kwargs.get("max_length_q"),
305+
max_seqlen_kv=kwargs.get("max_length_k"),
306+
)
223307

224308
return hidden_states
225309

@@ -236,18 +320,20 @@ def __init__(self, config: CodonFMConfig):
236320
super().__init__()
237321
device = "meta" if torch.get_default_device() == torch.device("meta") else "cpu"
238322

239-
self.dense = transformer_engine.pytorch.Linear(
240-
config.hidden_size,
241-
config.hidden_size,
242-
device=device,
243-
)
244-
self.layer_norm_linear = transformer_engine.pytorch.LayerNormLinear(
245-
config.hidden_size,
246-
config.vocab_size,
247-
bias=True,
248-
eps=config.layer_norm_eps,
249-
device=device,
250-
)
323+
# Disable quantization for the LM head to avoid numerical instability.
324+
with transformer_engine.pytorch.quantized_model_init(enabled=False):
325+
self.dense = transformer_engine.pytorch.Linear(
326+
config.hidden_size,
327+
config.hidden_size,
328+
device=device,
329+
)
330+
self.layer_norm_linear = transformer_engine.pytorch.LayerNormLinear(
331+
config.hidden_size,
332+
config.vocab_size,
333+
bias=True,
334+
eps=config.layer_norm_eps,
335+
device=device,
336+
)
251337

252338
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
253339
"""Forward pass.
@@ -258,25 +344,34 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
258344
Returns:
259345
Logits of shape [..., vocab_size].
260346
"""
261-
x = self.dense(hidden_states)
262-
x = torch.nn.functional.gelu(x)
263-
x = self.layer_norm_linear(x)
347+
# Keep the LM head in higher precision to avoid numerical instability.
348+
with transformer_engine.pytorch.autocast(enabled=False):
349+
x = self.dense(hidden_states)
350+
x = torch.nn.functional.gelu(x)
351+
x = self.layer_norm_linear(x)
264352
return x
265353

266354

267355
class CodonFMForMaskedLM(nn.Module):
268356
"""CodonFM model for masked language modeling with TransformerEngine layers."""
269357

270-
def __init__(self, config: CodonFMConfig):
358+
def __init__(
359+
self,
360+
config: CodonFMConfig,
361+
fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
362+
fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
363+
):
271364
"""Initialize the model.
272365
273366
Args:
274367
config: Model configuration.
368+
fp8_recipe: The FP8 recipe for the encoder.
369+
fp4_recipe: The FP4 recipe for the encoder.
275370
"""
276371
super().__init__()
277372
self.config = config
278373
self.embeddings = CodonEmbedding(config)
279-
self.encoder = CodonFMEncoder(config)
374+
self.encoder = CodonFMEncoder(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
280375
self.lm_head = CodonFMLMHead(config)
281376
self._init_weights()
282377

bionemo-recipes/recipes/codonfm_native_te/perf_logger.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import logging
1919
import time
2020

21+
import nvdlfw_inspect.api as debug_api
2122
import torch
2223
import torchmetrics
2324
import torchmetrics.text
@@ -74,6 +75,9 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig):
7475
wandb.init(**args.wandb_init_args, config=self._run_config)
7576
self._progress_bar = tqdm(total=args.num_train_steps, desc="Training")
7677

78+
# Whether to call debug_api.step() after each training step
79+
self.quant_stats_config = args.quant_stats_config.enabled
80+
7781
def log_step(
7882
self,
7983
step: int,
@@ -95,6 +99,9 @@ def log_step(
9599
if isinstance(grad_norm, DTensor):
96100
grad_norm = grad_norm.to_local()
97101

102+
if self.quant_stats_config:
103+
debug_api.step()
104+
98105
if step % self.logging_frequency == 0 and step > 0:
99106
num_tokens = batch["input_ids"].numel()
100107
num_unpadded_tokens = batch["input_ids"][batch["input_ids"] != PAD_TOKEN_ID].numel()
@@ -142,6 +149,9 @@ def log_step(
142149

143150
def finish(self):
144151
"""Finish the logger."""
152+
if self.quant_stats_config:
153+
debug_api.end_debug()
154+
145155
if not self._dist_config.is_main_process():
146156
return
147157
wandb.finish()

bionemo-recipes/recipes/codonfm_native_te/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
hydra-core
2+
nvdlfw_inspect
23
pandas
34
pyarrow
45
pytest
6+
pyyaml
57
safetensors
68
torch
79
torchmetrics

0 commit comments

Comments
 (0)