1919
2020import json
2121import math
22+ import warnings
23+ from contextlib import nullcontext
2224from dataclasses import asdict , dataclass
23- from typing import Optional
25+ from typing import ContextManager , Optional
2426
2527import torch
2628import torch .nn as nn
29+ import transformer_engine .common .recipe
2730import transformer_engine .pytorch
2831from torch .nn import CrossEntropyLoss
2932from transformer_engine .pytorch .attention .rope import RotaryPositionEmbedding
@@ -50,6 +53,9 @@ class CodonFMConfig:
5053 # TE-specific options
5154 qkv_weight_interleaved : bool = True
5255 fuse_qkv_params : bool = True
56+ # Layer-wise precision options
57+ layer_precision : list [str | None ] | None = None
58+ use_quantized_model_init : bool = False
5359
5460 def __post_init__ (self ):
5561 """Validate configuration."""
@@ -60,6 +66,15 @@ def __post_init__(self):
6066 )
6167 if self .hidden_act not in ("gelu" , "relu" , "silu" ):
6268 raise ValueError (f"hidden_act must be one of: gelu, relu, silu, got { self .hidden_act } " )
69+ if self .layer_precision is not None :
70+ if len (self .layer_precision ) != self .num_hidden_layers :
71+ raise ValueError (
72+ f"layer_precision must be a list of length { self .num_hidden_layers } , "
73+ f"got { len (self .layer_precision )} "
74+ )
75+ for precision in self .layer_precision :
76+ if precision not in {"fp8" , "fp4" , None }:
77+ raise ValueError (f'layer_precision element must be "fp8", "fp4", or None, got { precision !r} ' )
6378
6479 def save_json (self , path : str ):
6580 """Save config as JSON."""
@@ -142,44 +157,111 @@ def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
142157class CodonFMEncoder (nn .Module ):
143158 """CodonFM encoder using standard TransformerEngine TransformerLayer."""
144159
def __init__(
    self,
    config: CodonFMConfig,
    fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
    fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
):
    """Build the encoder stack of TE TransformerLayers.

    Args:
        config: Model configuration.
        fp8_recipe: The FP8 recipe for the encoder.
        fp4_recipe: The FP4 recipe for the encoder.

    Raises:
        RuntimeError: If the provided recipes are inconsistent with
            ``config.layer_precision``.
    """
    super().__init__()
    self.config = config
    self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe
    self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe

    if self.config.layer_precision is None:
        # No explicit per-layer precision: derive a default from whichever
        # recipe was handed in, or reject ambiguous combinations.
        if fp8_recipe is not None and fp4_recipe is not None:
            raise RuntimeError("Both FP8 and FP4 recipes provided, but no layer precision provided.")
        if fp8_recipe is not None:
            warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning)
            self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers
        elif fp4_recipe is not None:
            raise RuntimeError(
                "FP4 recipe provided but no layer_precision configured. "
                "Set layer_precision explicitly when using FP4."
            )

    # Any layer tagged "fp4" requires an FP4 recipe up front; there is no
    # default-recipe fallback for FP4 (unlike FP8).
    if self.config.layer_precision is not None and "fp4" in self.config.layer_precision and fp4_recipe is None:
        raise RuntimeError("layer_precision contains 'fp4' entries but no fp4_recipe was provided.")

    # Honor a caller-set meta default device (deferred init); otherwise CPU.
    device = "meta" if torch.get_default_device() == torch.device("meta") else "cpu"

    stack: list[transformer_engine.pytorch.TransformerLayer] = []
    for layer_idx in range(config.num_hidden_layers):
        # Optionally create this layer's weights directly in quantized form.
        with self.get_autocast_context(layer_idx, init=True):
            stack.append(
                transformer_engine.pytorch.TransformerLayer(
                    hidden_size=config.hidden_size,
                    ffn_hidden_size=config.intermediate_size,
                    num_attention_heads=config.num_attention_heads,
                    layernorm_epsilon=config.layer_norm_eps,
                    hidden_dropout=config.hidden_dropout_prob,
                    attention_dropout=config.attention_probs_dropout_prob,
                    qkv_weight_interleaved=config.qkv_weight_interleaved,
                    layer_number=layer_idx + 1,
                    layer_type="encoder",
                    self_attn_mask_type="padding",
                    activation=config.hidden_act,
                    attn_input_format=config.attn_input_format,
                    seq_length=config.max_position_embeddings,
                    num_gqa_groups=config.num_attention_heads,
                    fuse_qkv_params=config.fuse_qkv_params,
                    window_size=(-1, -1),
                    device=device,
                )
            )

    self.layers = nn.ModuleList(stack)
    self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
182222
223+ def get_autocast_context (
224+ self , layer_number : int | None , init : bool = False , outer : bool = False
225+ ) -> ContextManager :
226+ """Return the appropriate TE autocast context manager for a given layer.
227+
228+ Handles both the quantized_model_init during layer creation and the te.autocast() during forward.
229+
230+ Args:
231+ layer_number: The 0-indexed layer number.
232+ init: Whether to return a ``quantized_model_init`` context for layer initialization.
233+ outer: Whether to return a global te.autocast() context to wrap the entire encoder stack.
234+ """
235+ if self .config .layer_precision is None :
236+ return nullcontext ()
237+
238+ if outer :
239+ if "fp8" not in self .config .layer_precision :
240+ return nullcontext ()
241+ if self ._fp8_recipe is None :
242+ warnings .warn ("No FP8 recipe provided, using default recipe." , UserWarning )
243+ return transformer_engine .pytorch .autocast (enabled = True , recipe = self ._fp8_recipe )
244+
245+ precision = self .config .layer_precision [layer_number ]
246+ recipe = {"fp8" : self ._fp8_recipe , "fp4" : self ._fp4_recipe }.get (precision )
247+
248+ if init and self .config .use_quantized_model_init :
249+ if precision == "fp4" and recipe is None :
250+ raise RuntimeError ("No FP4 recipe provided, but layer precision is set to FP4." )
251+ if precision in ("fp8" , "fp4" ):
252+ return transformer_engine .pytorch .quantized_model_init (recipe = recipe )
253+ return nullcontext ()
254+
255+ if precision == "fp8" :
256+ if recipe is None :
257+ warnings .warn ("No FP8 recipe provided, using default recipe." , UserWarning )
258+ return transformer_engine .pytorch .autocast (enabled = True , recipe = recipe )
259+ if precision == "fp4" :
260+ if recipe is None :
261+ raise RuntimeError ("No FP4 recipe provided, but layer precision is set to FP4." )
262+ return transformer_engine .pytorch .autocast (enabled = True , recipe = recipe )
263+ return transformer_engine .pytorch .autocast (enabled = False )
264+
183265 def forward (
184266 self ,
185267 hidden_states : torch .Tensor ,
@@ -203,23 +285,25 @@ def forward(
203285 te_rope_emb = self .rotary_embeddings (max_seq_len = self .config .max_position_embeddings )
204286 te_rope_emb = te_rope_emb .to (hidden_states .device , non_blocking = True )
205287
206- for layer_module in self .layers :
207- if self .config .attn_input_format == "bshd" :
208- hidden_states = layer_module (
209- hidden_states ,
210- attention_mask = attention_mask ,
211- rotary_pos_emb = te_rope_emb ,
212- )
213- else :
214- hidden_states = layer_module (
215- hidden_states ,
216- attention_mask = None ,
217- rotary_pos_emb = te_rope_emb ,
218- cu_seqlens_q = kwargs .get ("cu_seq_lens_q" ),
219- cu_seqlens_kv = kwargs .get ("cu_seq_lens_k" ),
220- max_seqlen_q = kwargs .get ("max_length_q" ),
221- max_seqlen_kv = kwargs .get ("max_length_k" ),
222- )
288+ with self .get_autocast_context (None , outer = True ):
289+ for layer_idx , layer_module in enumerate (self .layers ):
290+ with self .get_autocast_context (layer_idx ):
291+ if self .config .attn_input_format == "bshd" :
292+ hidden_states = layer_module (
293+ hidden_states ,
294+ attention_mask = attention_mask ,
295+ rotary_pos_emb = te_rope_emb ,
296+ )
297+ else :
298+ hidden_states = layer_module (
299+ hidden_states ,
300+ attention_mask = None ,
301+ rotary_pos_emb = te_rope_emb ,
302+ cu_seqlens_q = kwargs .get ("cu_seq_lens_q" ),
303+ cu_seqlens_kv = kwargs .get ("cu_seq_lens_k" ),
304+ max_seqlen_q = kwargs .get ("max_length_q" ),
305+ max_seqlen_kv = kwargs .get ("max_length_k" ),
306+ )
223307
224308 return hidden_states
225309
@@ -236,18 +320,20 @@ def __init__(self, config: CodonFMConfig):
236320 super ().__init__ ()
237321 device = "meta" if torch .get_default_device () == torch .device ("meta" ) else "cpu"
238322
239- self .dense = transformer_engine .pytorch .Linear (
240- config .hidden_size ,
241- config .hidden_size ,
242- device = device ,
243- )
244- self .layer_norm_linear = transformer_engine .pytorch .LayerNormLinear (
245- config .hidden_size ,
246- config .vocab_size ,
247- bias = True ,
248- eps = config .layer_norm_eps ,
249- device = device ,
250- )
323+ # Disable quantization for the LM head to avoid numerical instability.
324+ with transformer_engine .pytorch .quantized_model_init (enabled = False ):
325+ self .dense = transformer_engine .pytorch .Linear (
326+ config .hidden_size ,
327+ config .hidden_size ,
328+ device = device ,
329+ )
330+ self .layer_norm_linear = transformer_engine .pytorch .LayerNormLinear (
331+ config .hidden_size ,
332+ config .vocab_size ,
333+ bias = True ,
334+ eps = config .layer_norm_eps ,
335+ device = device ,
336+ )
251337
252338 def forward (self , hidden_states : torch .Tensor ) -> torch .Tensor :
253339 """Forward pass.
@@ -258,25 +344,34 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
258344 Returns:
259345 Logits of shape [..., vocab_size].
260346 """
261- x = self .dense (hidden_states )
262- x = torch .nn .functional .gelu (x )
263- x = self .layer_norm_linear (x )
347+ # Keep the LM head in higher precision to avoid numerical instability.
348+ with transformer_engine .pytorch .autocast (enabled = False ):
349+ x = self .dense (hidden_states )
350+ x = torch .nn .functional .gelu (x )
351+ x = self .layer_norm_linear (x )
264352 return x
265353
266354
267355class CodonFMForMaskedLM (nn .Module ):
268356 """CodonFM model for masked language modeling with TransformerEngine layers."""
269357
def __init__(
    self,
    config: CodonFMConfig,
    fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
    fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
):
    """Initialize the model.

    Args:
        config: Model configuration.
        fp8_recipe: The FP8 recipe for the encoder.
        fp4_recipe: The FP4 recipe for the encoder.
    """
    super().__init__()
    self.config = config
    # Embeddings -> encoder -> LM head; the recipes affect only the encoder.
    self.embeddings = CodonEmbedding(config)
    self.encoder = CodonFMEncoder(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
    self.lm_head = CodonFMLMHead(config)
    self._init_weights()
0 commit comments