From 8c94f85cdf8908d2c67855861e09aa2c060a4845 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Fri, 23 Jan 2026 12:28:55 -0800 Subject: [PATCH 01/17] tries to get fp4 working Signed-off-by: Jonathan Mitchell --- .../esm2_native_te/hydra_config/defaults.yaml | 8 + .../recipes/esm2_native_te/modeling_esm_te.py | 657 ++++++++++++++++++ .../recipes/esm2_native_te/train_fsdp2.py | 33 +- 3 files changed, 684 insertions(+), 14 deletions(-) create mode 100644 bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml index baace7c80..87c624aca 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml @@ -51,6 +51,14 @@ fp8_config: fp8_model_init_kwargs: enabled: false # If this is set to true, fp8_config.enabled must also be set to true. +fp4_config: + enabled: false + fp4_recipe: transformer_engine.common.recipe.NVFP4BlockScaling + fp4_format: "E2M1" + fp4_recipe_kwargs: {} + fp4_model_init_kwargs: + enabled: false # If this is set to true, fp4_config.enabled must also be set to true. + # Optimizer config adamw_kwargs: lr: 4e-4 diff --git a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py new file mode 100644 index 000000000..a12e3ef32 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py @@ -0,0 +1,657 @@ +# noqa: license-check +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. +# Copyright 2025 NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""TransformerEngine-optimized ESM model. + +Adapted from `modeling_esm.py` in huggingface/transformers. +""" + +from typing import Literal, Optional, Unpack + +# TODO: put import guard around transformer_engine here, with an informative error message around +# installation and the nvidia docker container. +import torch +import transformer_engine.pytorch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + MaskedLMOutput, + TokenClassifierOutput, +) +from transformers.models.esm.configuration_esm import EsmConfig +from transformers.models.esm.modeling_esm import EsmPooler, EsmPreTrainedModel +from transformers.utils import logging +from transformers.utils.generic import TransformersKwargs + + +logger = logging.get_logger(__name__) + +# Dictionary that gets inserted into config.json to map Auto** classes to our TE-optimized model classes defined below. +# These should be prefixed with esm_nv., since we name the file esm_nv.py in our exported checkpoints. 
AUTO_MAP = { + "AutoConfig": "esm_nv.NVEsmConfig", + "AutoModel": "esm_nv.NVEsmModel", + "AutoModelForMaskedLM": "esm_nv.NVEsmForMaskedLM", + "AutoModelForTokenClassification": "esm_nv.NVEsmForTokenClassification", +} + + +class NVEsmConfig(EsmConfig): + """NVEsmConfig is a configuration for the NVEsm model.""" + + model_type: str = "nv_esm" + + def __init__( + self, + qkv_weight_interleaved: bool = True, + encoder_activation: str = "gelu", + attn_input_format: Literal["bshd", "thd"] = "bshd", + fuse_qkv_params: bool = True, + micro_batch_size: Optional[int] = None, + max_seq_length: Optional[int] = None, + padded_vocab_size: Optional[int] = 64, + attn_mask_type: str = "padding", + **kwargs, + ): + """Initialize the NVEsmConfig with additional TE-related config options. + + Args: + qkv_weight_interleaved: Whether to interleave the qkv weights. If set to `False`, the + QKV weight is interpreted as a concatenation of query, key, and value weights along + the `0th` dimension. The default interpretation is that the individual `q`, `k`, and + `v` weights for each attention head are interleaved. This parameter is set to `False` + when using :attr:`fuse_qkv_params=False`. + encoder_activation: The activation function to use in the encoder. + attn_input_format: The input format to use for the attention. This controls + whether the dimensions of the intermediate hidden states is 'batch first' + ('bshd') or 'sequence first' ('sbhd'). `s` stands for the sequence length, + `b` batch size, `h` the number of heads, `d` head size. Note that these + formats are very closely related to the `qkv_format` in the + `MultiHeadAttention` and `DotProductAttention` modules. + fuse_qkv_params: Whether to fuse the qkv parameters. If set to `True`, + `TransformerLayer` module exposes a single fused parameter for query-key-value. + This enables optimizations such as QKV fusion without concatenations/splits and + also enables the argument `fuse_wgrad_accumulation`. 
micro_batch_size: The micro batch size to use for the attention. This is needed for + JIT Warmup, a technique where jit fused functions are warmed up before training to + ensure same kernels are used for forward propagation and activation recompute phase. + max_seq_length: The maximum sequence length to use for the attention. This is needed for + JIT Warmup, a technique where jit fused functions are warmed up before training to + ensure same kernels are used for forward propagation and activation recompute phase. + padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults + to vocab_size. Must be greater than or equal to vocab_size. + attn_mask_type: The type of attention mask to use. + **kwargs: Additional config options to pass to EsmConfig. + """ + super().__init__(**kwargs) + # Additional TE-related config options. + self.qkv_weight_interleaved = qkv_weight_interleaved + self.encoder_activation = encoder_activation + self.attn_input_format = attn_input_format + self.fuse_qkv_params = fuse_qkv_params + self.micro_batch_size = micro_batch_size + self.max_seq_length = max_seq_length + self.attn_mask_type = attn_mask_type + + # Set padded_vocab_size with default fallback to vocab_size + self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size + + # Ensure padded_vocab_size is at least as large as vocab_size + if self.padded_vocab_size is not None and self.vocab_size is not None: + assert self.padded_vocab_size >= self.vocab_size, ( + f"padded_vocab_size ({self.padded_vocab_size}) must be greater than or equal to vocab_size ({self.vocab_size})" + ) + + +class NVEsmEncoder(nn.Module): + """NVEsmEncoder is a TransformerEngine-optimized ESM encoder.""" + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmEncoder. + + Args: + config (NVEsmConfig): The configuration of the model. 
+ """ + super().__init__() + self.config = config + + def _init_method(x): + torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range) + + self.layers = nn.ModuleList( + [ + transformer_engine.pytorch.TransformerLayer( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_attention_heads=config.num_attention_heads, + layernorm_epsilon=config.layer_norm_eps, + hidden_dropout=config.hidden_dropout_prob, + attention_dropout=config.attention_probs_dropout_prob, + qkv_weight_interleaved=config.qkv_weight_interleaved, + layer_number=i + 1, + layer_type="encoder", + self_attn_mask_type=config.attn_mask_type, + activation=config.encoder_activation, + attn_input_format=config.attn_input_format, + seq_length=config.max_seq_length, + micro_batch_size=config.micro_batch_size, + num_gqa_groups=config.num_attention_heads, + fuse_qkv_params=config.fuse_qkv_params, + params_dtype=config.dtype, + window_size=(-1, -1), + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=_init_method, + output_layer_init_method=_init_method, + ) + for i in range(config.num_hidden_layers) + ] + ) + self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + ) + if config.position_embedding_type == "rotary": + self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ): + """Forward pass of the NVEsmEncoder. + + Args: + hidden_states (torch.Tensor): The hidden states. + attention_mask (torch.Tensor): The attention mask. + **kwargs: Additional arguments, see TransformersKwargs for more details. + """ + all_hidden_states: tuple[torch.Tensor, ...] 
= () + + if self.config.attn_input_format == "thd" and hidden_states.dim() == 3 and hidden_states.size(0) == 1: + # For THD, the embedding output is a 3-dimensional tensor with shape [1, total_tokens, hidden_size], but TE + # expects a 2-dimensional tensor with shape [total_tokens, hidden_size]. + hidden_states = hidden_states.squeeze(0) + + # Ensure that rotary embeddings are computed at a higher precision outside the torch autocast context. + with torch.autocast(device_type="cuda", enabled=False): + te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings) + te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True) + + for layer_module in self.layers: + if kwargs.get("output_hidden_states", False): + all_hidden_states = (*all_hidden_states, hidden_states) + + hidden_states = layer_module( + hidden_states, + attention_mask, + rotary_pos_emb=te_rope_emb, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), + max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + pad_between_seqs=kwargs.get("pad_between_seqs", None), + ) + + hidden_states = self.emb_layer_norm_after(hidden_states) + + if kwargs.get("output_hidden_states", False): + all_hidden_states = (*all_hidden_states, hidden_states) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states if all_hidden_states else None, + ) + + +class NVEsmPreTrainedModel(EsmPreTrainedModel): + """An abstract class to handle weights initialization and pretrained model loading.""" + + config_class = NVEsmConfig + base_model_prefix = "esm" + supports_gradient_checkpointing = False + accepts_loss_kwargs = False + _no_split_modules = ( + "TransformerLayer", + "EsmEmbeddings", + ) + + def init_empty_weights(self): + """Handles moving 
the model from the meta device to the cuda device and initializing the weights.""" + # For TE layers, calling `reset_parameters` is sufficient to move them to the cuda device and apply the weight + # initialization we passed them during module creation. + for module in self.modules(): + if hasattr(module, "reset_parameters"): + module.reset_parameters() + + # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use + # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard + # deviation. + self.esm.embeddings.word_embeddings.to_empty(device="cuda") + self.esm.embeddings.apply(self._init_weights) + + # Meta-device init seems to break weight tying, so we re-tie the weights here. + self.tie_weights() + + @classmethod + def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool): + """Override the default get_init_context method to allow for fp8 model initialization.""" + return [] + + +class NVEsmModel(NVEsmPreTrainedModel): + """The ESM Encoder-only protein language model. + + This model uses NVIDIA's TransformerEngine to optimize attention layer training and inference. + """ + + def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True): + """Initialize a NVEsmModel. + + Args: + config (NVEsmConfig): The configuration of the model. + add_pooling_layer (bool): Whether to add a pooling layer. 
+ """ + super().__init__(config) + self.config = config + + # Ensure pad_token_id is set properly, defaulting to 0 if not specified + if not hasattr(config, "pad_token_id") or config.pad_token_id is None: + config.pad_token_id = 0 + self.embeddings = NVEsmEmbeddings(config) + self.encoder = NVEsmEncoder(config) + self.pooler = EsmPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + """Get the input embeddings of the model.""" + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value: torch.Tensor): + """Set the input embeddings of the model. + + Args: + value (torch.Tensor): The input embeddings. + """ + self.embeddings.word_embeddings = value + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + """Forward pass of the NVEsmModel. + + Args: + input_ids (torch.Tensor): The input ids. + attention_mask (torch.Tensor): The attention mask. + position_ids (torch.Tensor): The position ids. + inputs_embeds (torch.Tensor): The input embeddings. + **kwargs: Additional arguments, see TransformersKwargs for more details. + + Returns: + BaseModelOutputWithPooling: The output of the model. 
+ """ + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # TE expects a boolean attention mask, where 1s are masked and 0s are not masked + extended_attention_mask = extended_attention_mask < -1 + + embedding_output = self.embeddings( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + **kwargs, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +class NVEsmForMaskedLM(NVEsmPreTrainedModel): + """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" + + _tied_weights_keys = ("lm_head.decoder.weight",) + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmForMaskedLM. + + Args: + config (NVEsmConfig): The configuration of the model. 
+ """ + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.esm = NVEsmModel(config, add_pooling_layer=False) + self.lm_head = NVEsmLMHead(config) + + self.init_weights() + self.post_init() + + def get_output_embeddings(self): + """Get the output embeddings of the model.""" + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + """Set the output embeddings of the model.""" + self.lm_head.decoder = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MaskedLMOutput: + """Forward pass of the NVEsmForMaskedLM. + + Args: + input_ids (torch.LongTensor): The input ids. + attention_mask (torch.Tensor): The attention mask. + position_ids (torch.LongTensor): The position ids. + inputs_embeds (torch.FloatTensor): The input embeddings. + labels (torch.LongTensor): The labels. + **kwargs: Additional arguments, see TransformersKwargs for more details. + + Returns: + MaskedLMOutput: The output of the model. 
+ """ + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + # Truncate logits back to original vocab_size if padding was used + if self.config.padded_vocab_size != self.config.vocab_size: + prediction_scores = prediction_scores[..., : self.config.vocab_size] + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.to(prediction_scores.device).view(-1), + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + ) + + +class NVEsmLMHead(nn.Module): + """ESM Head for masked language modeling using TransformerEngine.""" + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmLMHead. + + Args: + config (NVEsmConfig): The configuration of the model. + """ + super().__init__() + self.dense = transformer_engine.pytorch.Linear( + config.hidden_size, + config.hidden_size, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + with transformer_engine.pytorch.fp8_model_init(enabled=False): + self.decoder = transformer_engine.pytorch.LayerNormLinear( + config.hidden_size, + config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size, + bias=True, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + def forward(self, features, **kwargs): + """Forward pass of the NVEsmLMHead. + + Args: + features (torch.Tensor): The features. 
**kwargs: Additional arguments. + """ + with transformer_engine.pytorch.autocast(enabled=False): + x = self.dense(features) + x = torch.nn.functional.gelu(x) + x = self.decoder(x) + return x + + +class NVEsmEmbeddings(nn.Module): + """Modified version of EsmEmbeddings to support THD inputs.""" + + def __init__(self, config): + """Initialize a NVEsmEmbeddings.""" + super().__init__() + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id, + dtype=config.dtype, + ) + + self.layer_norm = ( + transformer_engine.pytorch.LayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + ) + if config.emb_layer_norm_before + else None + ) + + if config.position_embedding_type != "rotary": + raise ValueError( + "The TE-accelerated ESM-2 model only supports rotary position embeddings, received " + f"{config.position_embedding_type}" + ) + + self.padding_idx = config.pad_token_id + self.token_dropout = config.token_dropout + self.mask_token_id = config.mask_token_id + + def forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + **kwargs: Unpack[TransformersKwargs], + ): + """Forward pass of the NVEsmEmbeddings.""" + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an + # embedding_scale factor here. + embeddings = inputs_embeds + + if ( + kwargs.get("cu_seq_lens_q") is not None + and kwargs.get("cu_seq_lens_k") is not None + and kwargs.get("max_length_q") is not None + and kwargs.get("max_length_k") is not None + ): + using_thd = True + attention_mask = None + else: + using_thd = False + + # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout + # flag is False then it is handled in the same way as BERT/RoBERTa. 
If it is set to True, however, + # masked tokens are treated as if they were selected for input dropout and zeroed out. + # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by + # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample). + # This is analogous to the way that dropout layers scale down outputs during evaluation when not + # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training). + if self.token_dropout and input_ids is not None: + embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0) + mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs + + if not using_thd: + # BSHD token dropout correction + src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1] + n_masked_per_seq = (input_ids == self.mask_token_id).sum(-1).float() + mask_ratio_observed = n_masked_per_seq / src_lengths + scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) + embeddings = (embeddings * scale_factor[:, None, None]).to(embeddings.dtype) + + else: + src_lengths = torch.diff(kwargs["cu_seq_lens_q"]) + # We need to find the number of masked tokens in each sequence in the padded batch. 
+ is_masked = (input_ids == self.mask_token_id).squeeze(0) + n_masked_per_seq = torch.nested.nested_tensor_from_jagged( + is_masked, offsets=kwargs["cu_seq_lens_q"] + ).sum(1) + mask_ratio_observed = n_masked_per_seq.float() / src_lengths + scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) + reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths, dim=0) + embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype) + + if self.layer_norm is not None: + embeddings = self.layer_norm(embeddings) + + if attention_mask is not None: + embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype) + + return embeddings + + +class NVEsmForTokenClassification(NVEsmPreTrainedModel): + """Adds a token classification head to the model. + + Adapted from EsmForTokenClassification in Hugging Face Transformers `modeling_esm.py`. + """ + + def __init__(self, config): + """Initialize NVEsmForTokenClassification.""" + super().__init__(config) + self.num_labels = config.num_labels + + self.esm = NVEsmModel(config, add_pooling_layer=False) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = transformer_engine.pytorch.Linear( + config.hidden_size, + config.num_labels, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + self.init_weights() + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> TokenClassifierOutput: + """Forward pass for the token classification head. 
+ + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + + labels = labels.to(logits.device) + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 28409e0c1..174cfcdd4 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -29,6 +29,8 @@ from transformer_engine.common.recipe import Format from transformers import AutoConfig, AutoModelForMaskedLM +from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM + # This import seems to be needed with meta device init and AutoModel.from_config from transformers.models.esm.modeling_esm import EsmForMaskedLM # noqa: F401 @@ -57,12 +59,6 @@ def main(args: DictConfig) -> float | None: torch.distributed.init_process_group(backend="nccl", device_id=device) torch.cuda.set_device(dist_config.local_rank) - # TE Debug feature logging - MUST be done BEFORE FSDP wrapping - if args.fp8_stats_config.enabled and not args.fp8_config.enabled: - raise ValueError( - "fp8_stats_config.enabled is true but fp8_config.enabled is false, please set fp8_config.enabled to true in the config if you wish to collect FP8 stats" - ) - if args.fp8_stats_config.enabled: fp8_stats_file = 
args.fp8_stats_config.fp8_stats_file fp8_log_dir = Path(args.fp8_stats_config.fp8_log_dir) / f"rank_{dist_config.rank}" @@ -84,12 +80,18 @@ def main(args: DictConfig) -> float | None: ) # Create an FP8 recipe -- this is only used if FP8 is enabled in the config. - fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( - fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs - ) - + if args.fp8_config.enabled: + fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( + fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs + ) + elif args.fp4_config.enabled: + fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( + fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs + ) + else: + print("No FP8 or FP4 config enabled, using default bfloat16") # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D". - config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True, dtype=torch.bfloat16) + config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.bfloat16) # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument. if args.use_sequence_packing: config.attn_input_format = "thd" @@ -99,9 +101,9 @@ def main(args: DictConfig) -> float | None: # versions of weights are kept. 
with ( torch.device("meta") if args.use_meta_device else nullcontext(), - transformer_engine.pytorch.fp8_model_init(recipe=fp8_recipe, **args.fp8_config.fp8_model_init_kwargs), ): - model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True) + # model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True) + model = NVEsmForMaskedLM(config) logger.info("Initialized Model:\n%s", model) @@ -164,8 +166,11 @@ def main(args: DictConfig) -> float | None: for batch in train_dataloader: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 + fp_context = transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe) if args.fp8_config.enabled else nullcontext() + fp_context = transformer_engine.pytorch.autocast(enabled=True, recipe=fp4_recipe) if args.fp4_config.enabled else fp_context + # Note: For NVFP4 it looks like it's just autocast? https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#FP8-autocasting # Forward pass with mixed precision. - with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe): + with fp_context: outputs = model(**batch) # Backward pass. 
From 8514d7bc5fca9676dfec990666de3f536a211459 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Fri, 23 Jan 2026 12:50:14 -0800 Subject: [PATCH 02/17] refactors fp8 stats logs Signed-off-by: Jonathan Mitchell --- .../esm2_native_te/fp8_debugging_stats.yaml | 7 ++++++- .../esm2_native_te/hydra_config/defaults.yaml | 7 ++++--- .../recipes/esm2_native_te/train_fsdp2.py | 20 +++++++++---------- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml b/bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml index 7544bbedc..9653d8a04 100644 --- a/bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml @@ -2,7 +2,7 @@ example_fp8_tensor_stat_collection: enabled: True layers: # Match the actual linear layers within attention that support FP8 stats - layer_types: [layernorm_qkv] + layer_types: [layernorm_qkv, proj, fc1, fc2] transformer_engine: LogFp8TensorStats: enabled: True @@ -16,3 +16,8 @@ example_fp8_tensor_stat_collection: - tensor: weight stats: [underflows%, scale_inv_min, scale_inv_max, mse] freq: 10 + LogTensorStats: + enabled: True + stats: [max, min, mean, std, l1_norm] + tensors: [dgrad, wgrad, fprop] + freq: 1 diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml index 87c624aca..a8e6ac88a 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml @@ -84,7 +84,8 @@ checkpoint: logger: frequency: 100 -fp8_stats_config: + +quant_stats_config: enabled: false - fp8_stats_file: ./fp8_debugging_stats.yaml - fp8_log_dir: ./log_fp8_stats + quant_stats_file: ./fp8_debugging_stats.yaml + quant_log_dir: ./log_quant_stats diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py 
b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 174cfcdd4..1bb569541 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -59,16 +59,16 @@ def main(args: DictConfig) -> float | None: torch.distributed.init_process_group(backend="nccl", device_id=device) torch.cuda.set_device(dist_config.local_rank) - if args.fp8_stats_config.enabled: - fp8_stats_file = args.fp8_stats_config.fp8_stats_file - fp8_log_dir = Path(args.fp8_stats_config.fp8_log_dir) / f"rank_{dist_config.rank}" - fp8_log_dir.mkdir(parents=True, exist_ok=True) - logger.info(f"Logging FP8 stats to {fp8_log_dir}") + if args.quant_stats_config.enabled: + quant_stats_file = args.quant_stats_config.quant_stats_file + quant_log_dir = Path(args.quant_stats_config.quant_log_dir) / f"rank_{dist_config.rank}" + quant_log_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Logging quant stats to {quant_log_dir}") te_features_dir = str(Path(transformer_engine.__file__).parent / "debug" / "features") debug_api.initialize( - config_file=fp8_stats_file, + config_file=quant_stats_file, feature_dirs=[te_features_dir], - log_dir=fp8_log_dir, + log_dir=quant_log_dir, default_logging_enabled=True, ) @@ -125,7 +125,7 @@ def main(args: DictConfig) -> float | None: model.apply(model._init_weights) # Assign names to layers so debug API can identify them - if args.fp8_stats_config.enabled: + if args.quant_stats_config.enabled: debug_api.infer_and_assign_layer_names(model) # Create optimizer. Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873). 
@@ -184,7 +184,7 @@ def main(args: DictConfig) -> float | None: optimizer.step() scheduler.step() - if args.fp8_stats_config.enabled: + if args.quant_stats_config.enabled: debug_api.step() optimizer.zero_grad() @@ -228,7 +228,7 @@ def main(args: DictConfig) -> float | None: # Clean up distributed training perf_logger.finish() - if args.fp8_stats_config.enabled: + if args.quant_stats_config.enabled: debug_api.end_debug() torch.distributed.destroy_process_group() From 09840d059ed51facbafb16586447cd2b256d7ebc Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Fri, 23 Jan 2026 15:59:09 -0800 Subject: [PATCH 03/17] fp4 debugging stats yaml Signed-off-by: Jonathan Mitchell --- .../esm2_native_te/fp4_debugging_stats.yaml | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml diff --git a/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml new file mode 100644 index 000000000..81d6f4a42 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml @@ -0,0 +1,24 @@ +example_fp8_tensor_stat_collection: + enabled: True + layers: + # Match the actual linear layers within attention that support FP8 stats + layer_types: [layernorm_qkv, proj, fc1, fc2] + transformer_engine: + # Uncomment once https://github.com/NVIDIA/TransformerEngine/pull/2296 is merged. + # LogFp4TensorStats: + # enabled: True + # tensors_struct: + # - tensor: activation + # stats: [underflows%, scale_inv_min, scale_inv_max, mse] + # freq: 10 + # - tensor: gradient + # stats: [underflows%, scale_inv_min, scale_inv_max, mse] + # freq: 10 + # - tensor: weight + # stats: [underflows%, scale_inv_min, scale_inv_max, mse] + # freq: 10 + LogTensorStats: + enabled: True + stats: [max, min, mean, std, l1_norm] #TODO: Can you get underflows% here? 
+ tensors: [dgrad, wgrad, fprop] + freq: 1 From 0f62e9885c4e29b52499484476294d864d43eba3 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Sat, 24 Jan 2026 11:18:23 -0800 Subject: [PATCH 04/17] BF16 last 6 layers Signed-off-by: Jonathan Mitchell --- .../recipes/esm2_native_te/modeling_esm_te.py | 40 +++++++++++++------ 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py index a12e3ef32..0a7ed6e5b 100644 --- a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py +++ b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py @@ -41,7 +41,7 @@ from transformers.models.esm.modeling_esm import EsmPooler, EsmPreTrainedModel from transformers.utils import logging from transformers.utils.generic import TransformersKwargs - +from contextlib import nullcontext logger = logging.get_logger(__name__) @@ -199,22 +199,36 @@ def forward( te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings) te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True) + # Set some layers to BF16. (28-33) (This will be from a config later). 
+ # TODO: Also make sure this is only for FP4, not FP8 + layers_to_bf16 = {self.layers[-1], + self.layers[-2], + self.layers[-3], + self.layers[-4], + self.layers[-5], + self.layers[-6]} for layer_module in self.layers: + if layer_module in layers_to_bf16: + fp_context = transformer_engine.pytorch.autocast(enabled=False) + else: + fp_context = nullcontext() + if kwargs.get("output_hidden_states", False): all_hidden_states = (*all_hidden_states, hidden_states) - hidden_states = layer_module( - hidden_states, - attention_mask, - rotary_pos_emb=te_rope_emb, - cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), - cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), - cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), - cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), - max_seqlen_q=kwargs.get("max_length_q", None), - max_seqlen_kv=kwargs.get("max_length_k", None), - pad_between_seqs=kwargs.get("pad_between_seqs", None), - ) + with fp_context: + hidden_states = layer_module( + hidden_states, + attention_mask, + rotary_pos_emb=te_rope_emb, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), + max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + pad_between_seqs=kwargs.get("pad_between_seqs", None), + ) hidden_states = self.emb_layer_norm_after(hidden_states) From 9c32457e92e6294b51cc8f5c3cf5a0bf4eedaeb7 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Mon, 26 Jan 2026 10:47:05 -0800 Subject: [PATCH 05/17] sets bf16 layers thru cli Signed-off-by: Jonathan Mitchell --- .../esm2_native_te/hydra_config/L1_650M.yaml | 9 +++++++++ .../esm2_native_te/hydra_config/defaults.yaml | 2 ++ .../recipes/esm2_native_te/modeling_esm_te.py | 13 ++++++------- .../recipes/esm2_native_te/train_fsdp2.py | 3 ++- 4 files changed, 19 insertions(+), 8 
deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml index fd027601d..fae42a0b5 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml @@ -17,3 +17,12 @@ wandb_init_args: checkpoint: ckpt_dir: "checkpoints/esm2_t33_650M_UR50D_sanity" + +# Layers explicitly set to BF16 in case of NVFP4 training. +bf16_layers: + - 27 + - 28 + - 29 + - 30 + - 31 + - 32 \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml index a8e6ac88a..6f6cf8534 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml @@ -89,3 +89,5 @@ quant_stats_config: enabled: false quant_stats_file: ./fp8_debugging_stats.yaml quant_log_dir: ./log_quant_stats + +bf16_layers: null \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py index 0a7ed6e5b..f59516ff7 100644 --- a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py +++ b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py @@ -70,6 +70,7 @@ def __init__( max_seq_length: Optional[int] = None, padded_vocab_size: Optional[int] = 64, attn_mask_type: str = "padding", + bf16_layers: Optional[list[int]] = None, **kwargs, ): """Initialize the NVEsmConfig with additional TE-related config options. 
@@ -111,7 +112,7 @@ def __init__( self.micro_batch_size = micro_batch_size self.max_seq_length = max_seq_length self.attn_mask_type = attn_mask_type - + self.bf16_layers = bf16_layers # Set padded_vocab_size with default fallback to vocab_size self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size @@ -201,12 +202,10 @@ def forward( # Set some layers to BF16. (28-33) (This will be from a config later). # TODO: Also make sure this is only for FP4, not FP8 - layers_to_bf16 = {self.layers[-1], - self.layers[-2], - self.layers[-3], - self.layers[-4], - self.layers[-5], - self.layers[-6]} + layers_to_bf16 = set() + if self.config.bf16_layers is not None: + layers_to_bf16 = set(self.layers[layer_idx] for layer_idx in self.config.bf16_layers) + for layer_module in self.layers: if layer_module in layers_to_bf16: fp_context = transformer_engine.pytorch.autocast(enabled=False) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 1bb569541..bbde4f1b6 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -91,7 +91,8 @@ def main(args: DictConfig) -> float | None: else: print("No FP8 or FP4 config enabled, using default bfloat16") # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D". - config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.bfloat16) + bf16_layers = OmegaConf.to_container(args.bf16_layers, resolve=True) if args.bf16_layers is not None and args.fp4_config.enabled else None + config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.bfloat16, bf16_layers=bf16_layers) # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument. 
if args.use_sequence_packing: config.attn_input_format = "thd" From 196d0c7f0a8b064f5b886a7cd55a8b7ef190433b Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Mon, 26 Jan 2026 11:38:20 -0800 Subject: [PATCH 06/17] downgrade te version Signed-off-by: Jonathan Mitchell --- bionemo-recipes/recipes/esm2_native_te/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/requirements.txt b/bionemo-recipes/recipes/esm2_native_te/requirements.txt index b18607fd7..0602ca8a8 100644 --- a/bionemo-recipes/recipes/esm2_native_te/requirements.txt +++ b/bionemo-recipes/recipes/esm2_native_te/requirements.txt @@ -8,6 +8,6 @@ torchdata torchmetrics tqdm transformer_engine[pytorch] -transformers +transformers==4.57.3 wandb nvdlfw_inspect @ git+https://github.com/NVIDIA/nvidia-dlfw-inspect From 265911da3e16f624becba32c4b88bb0a6528e416 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Mon, 2 Feb 2026 14:00:22 -0800 Subject: [PATCH 07/17] layer specific autocast Signed-off-by: Jonathan Mitchell --- .../esm2_native_te/hydra_config/L1_650M.yaml | 12 +++++-- .../esm2_native_te/hydra_config/defaults.yaml | 4 ++- .../recipes/esm2_native_te/modeling_esm_te.py | 20 ++++++------ .../recipes/esm2_native_te/train_ddp.py | 2 +- .../recipes/esm2_native_te/train_fsdp2.py | 32 +++++++++++++------ 5 files changed, 46 insertions(+), 24 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml index fae42a0b5..f6108d573 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml @@ -19,10 +19,18 @@ checkpoint: ckpt_dir: "checkpoints/esm2_t33_650M_UR50D_sanity" # Layers explicitly set to BF16 in case of NVFP4 training.
-bf16_layers: +fp8_layers: - 27 - 28 - 29 - 30 - 31 - - 32 \ No newline at end of file + - 32 + +fp4_layers: + - 0 + - 14 + - 15 + - 16 + +use_fp32_optimizer_weights: true \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml index 6f6cf8534..da6f9f47c 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml @@ -90,4 +90,6 @@ quant_stats_config: quant_stats_file: ./fp8_debugging_stats.yaml quant_log_dir: ./log_quant_stats -bf16_layers: null \ No newline at end of file +fp8_layers: null +fp4_layers: null +use_fp32_optimizer_weights: false \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py index f59516ff7..b2d1a75df 100644 --- a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py +++ b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py @@ -166,6 +166,7 @@ def _init_method(x): for i in range(config.num_hidden_layers) ] ) + self.layer_number_quantized_recipe_map = None self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm( config.hidden_size, eps=config.layer_norm_eps, @@ -200,20 +201,17 @@ def forward( te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings) te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True) - # Set some layers to BF16. (28-33) (This will be from a config later). 
- # TODO: Also make sure this is only for FP4, not FP8 - layers_to_bf16 = set() - if self.config.bf16_layers is not None: - layers_to_bf16 = set(self.layers[layer_idx] for layer_idx in self.config.bf16_layers) - - for layer_module in self.layers: - if layer_module in layers_to_bf16: - fp_context = transformer_engine.pytorch.autocast(enabled=False) - else: - fp_context = nullcontext() + # Utilize the layer number quantized recipe map to determine the context for each layer. + for layer_number, layer_module in enumerate(self.layers): + fp_recipe = self.layer_number_quantized_recipe_map[layer_number] if layer_number in self.layer_number_quantized_recipe_map else None if kwargs.get("output_hidden_states", False): all_hidden_states = (*all_hidden_states, hidden_states) + + if fp_recipe is not None: + fp_context = transformer_engine.pytorch.autocast(enabled=True, recipe=fp_recipe) + else: + fp_context = nullcontext() with fp_context: hidden_states = layer_module( diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py index 1027703f3..002cbb94d 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py @@ -113,7 +113,7 @@ def main(args: DictConfig) -> float | None: device_ids=[dist_config.local_rank], output_device=dist_config.local_rank, device_mesh=device_mesh["ddp"], - ) + ) #TODO: Try BF16 compute weights with FP32 master weights here. # If we're using sequence packing, create a THD dataloader, otherwise create a BSHD dataloader. 
train_dataloader, dataset_or_sampler = ( diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index bbde4f1b6..d8a377c96 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -84,15 +84,16 @@ def main(args: DictConfig) -> float | None: fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs ) - elif args.fp4_config.enabled: + + if args.fp4_config.enabled: fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs ) - else: - print("No FP8 or FP4 config enabled, using default bfloat16") # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D". - bf16_layers = OmegaConf.to_container(args.bf16_layers, resolve=True) if args.bf16_layers is not None and args.fp4_config.enabled else None - config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.bfloat16, bf16_layers=bf16_layers) + fp8_layers = OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None and args.fp8_config.enabled else None + fp4_layers = OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None and args.fp4_config.enabled else None + + config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.bfloat16) # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument. 
if args.use_sequence_packing: config.attn_input_format = "thd" @@ -112,9 +113,21 @@ def main(args: DictConfig) -> float | None: transformer_stack = model.esm.encoder.layers if hasattr(model.esm.encoder, "layers") else model.esm.encoder.layer for layer in transformer_stack: - fully_shard(layer, mesh=device_mesh["dp"]) + fully_shard(layer, mesh=device_mesh["dp"]) # TODO: Update mixed precision policy to set it to FP#2 fully_shard(model, mesh=device_mesh["dp"]) + # Create a layer map for the transformer stack. + layer_number_quantized_recipe_map = {} + for layer_number, layer in enumerate(transformer_stack): + + if layer_number in fp8_layers: + layer_number_quantized_recipe_map[layer_number] = fp8_recipe + elif layer_number in fp4_layers: + layer_number_quantized_recipe_map[layer_number] = fp4_recipe + else: + layer_number_quantized_recipe_map[layer_number] = None + + model.esm.encoder.layer_number_quantized_recipe_map = layer_number_quantized_recipe_map # If we're using meta device, we need to move sharded weights to the cuda device and initialize the parameters. # Note, this should happen before we create the optimizer. if args.use_meta_device: @@ -167,11 +180,12 @@ def main(args: DictConfig) -> float | None: for batch in train_dataloader: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 - fp_context = transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe) if args.fp8_config.enabled else nullcontext() - fp_context = transformer_engine.pytorch.autocast(enabled=True, recipe=fp4_recipe) if args.fp4_config.enabled else fp_context + # Note: FOr NVFP4 it looks like its just autocast? https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#FP8-autocasting # Forward pass with mixed precision. - with fp_context: + # Make the FP context just MXFP8. Then use NVFP4 for certain layers. 
+ # with fp_context: #TODO: I think I can get rid of this, and just do it inside forward. + with transformer_engine.pytorch.autocast(): outputs = model(**batch) # Backward pass. From c5e472bde5aacb6ef84bf19f0c5cf7b3d24ca5fb Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Mon, 2 Feb 2026 15:01:47 -0800 Subject: [PATCH 08/17] enables layer specific fp recipes Signed-off-by: Jonathan Mitchell --- .../esm2_native_te/hydra_config/L1_650M.yaml | 24 ++++++++++++++++++- .../recipes/esm2_native_te/modeling_esm_te.py | 18 ++++++++++++-- .../recipes/esm2_native_te/train_fsdp2.py | 8 ++----- 3 files changed, 41 insertions(+), 9 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml index f6108d573..e39c4b398 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml @@ -20,6 +20,14 @@ checkpoint: # Layers explicitly set to BF16 in case of NVFP4 training. 
fp8_layers: + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 - 27 - 28 - 29 @@ -28,9 +36,23 @@ fp8_layers: - 32 fp4_layers: - - 0 + - 9 + - 10 + - 11 + - 12 + - 13 - 14 - 15 - 16 + - 17 + - 18 + - 19 + - 20 + - 21 + - 22 + - 23 + - 24 + - 25 + - 26 use_fp32_optimizer_weights: true \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py index b2d1a75df..2053cdd4b 100644 --- a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py +++ b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py @@ -37,6 +37,7 @@ MaskedLMOutput, TokenClassifierOutput, ) +import transformer_engine.common.recipe from transformers.models.esm.configuration_esm import EsmConfig from transformers.models.esm.modeling_esm import EsmPooler, EsmPreTrainedModel from transformers.utils import logging @@ -54,6 +55,13 @@ "AutoModelForTokenClassification": "esm_nv.NVEsmForTokenClassification", } +# From https://github.com/NVIDIA/TransformerEngine/blob/3ceb248e01a2c0dc1215fe0f46ebc235f289ba0d/transformer_engine/common/recipe/__init__.py#L86 +FP8_RECIPES = (transformer_engine.common.recipe.MXFP8BlockScaling, + transformer_engine.common.recipe.DelayedScaling, + transformer_engine.common.recipe.Float8CurrentScaling, + transformer_engine.common.recipe.Float8BlockScaling) +FP4_RECIPES = (transformer_engine.common.recipe.NVFP4BlockScaling) + class NVEsmConfig(EsmConfig): """NVEsmConfig is a configuration for the NVEsm model.""" @@ -208,10 +216,16 @@ def forward( if kwargs.get("output_hidden_states", False): all_hidden_states = (*all_hidden_states, hidden_states) - if fp_recipe is not None: + # If BF16 desired --> use autocast(false) so it goes to BF16. + # If FP8 desired --> use nullcontext so it uses upper context manager to FP8. + # If FP4 desired --> use autocast(true, recipe=fp4_recipe) so it uses FP4. 
+ if isinstance(fp_recipe, FP8_RECIPES): + fp_context = nullcontext() + elif isinstance(fp_recipe, FP4_RECIPES): fp_context = transformer_engine.pytorch.autocast(enabled=True, recipe=fp_recipe) else: - fp_context = nullcontext() + fp_context = transformer_engine.pytorch.autocast(enabled=False) + # TODO(@jomitchell): Double check that this works, make a funciton for it then unit test it. with fp_context: hidden_states = layer_module( diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index d8a377c96..0d25cb94c 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -179,13 +179,9 @@ def main(args: DictConfig) -> float | None: while step < args.num_train_steps: for batch in train_dataloader: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 - - # Note: FOr NVFP4 it looks like its just autocast? https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#FP8-autocasting - # Forward pass with mixed precision. - # Make the FP context just MXFP8. Then use NVFP4 for certain layers. - # with fp_context: #TODO: I think I can get rid of this, and just do it inside forward. - with transformer_engine.pytorch.autocast(): + # Use an outer FP8 recipe. + with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe): outputs = model(**batch) # Backward pass. 
From 2e2229cee53a53fccf68c4fa8b398aadbd068f9f Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Mon, 2 Feb 2026 15:35:44 -0800 Subject: [PATCH 09/17] adds fp32 optim weights with bf16 compute weights Signed-off-by: Jonathan Mitchell --- .../recipes/esm2_native_te/train_fsdp2.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 0d25cb94c..3a5a3c4e9 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -22,10 +22,12 @@ import torch import transformer_engine import transformer_engine.pytorch + from omegaconf import DictConfig, OmegaConf from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.fsdp import fully_shard +from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy from torch.optim import AdamW + from transformer_engine.common.recipe import Format from transformers import AutoConfig, AutoModelForMaskedLM @@ -93,7 +95,7 @@ def main(args: DictConfig) -> float | None: fp8_layers = OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None and args.fp8_config.enabled else None fp4_layers = OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None and args.fp4_config.enabled else None - config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.bfloat16) + config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.float32) # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument. if args.use_sequence_packing: config.attn_input_format = "thd" @@ -112,9 +114,14 @@ def main(args: DictConfig) -> float | None: # We call the transformer stack "layers" in our TE models, but it's called "layer" in the original ESM-2 models. 
transformer_stack = model.esm.encoder.layers if hasattr(model.esm.encoder, "layers") else model.esm.encoder.layer + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, # Cast params to BF16 for forward/backward + reduce_dtype=torch.bfloat16, # Gradient reductions in BF16 + output_dtype=torch.bfloat16, # Forward output dtype + ) for layer in transformer_stack: - fully_shard(layer, mesh=device_mesh["dp"]) # TODO: Update mixed precision policy to set it to FP#2 - fully_shard(model, mesh=device_mesh["dp"]) + fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) # TODO: Update mixed precision policy to set it to FP#2 + fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) # Create a layer map for the transformer stack. layer_number_quantized_recipe_map = {} From 8acd2a94eb27c19f112367b1c97218350e7781a4 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Mon, 2 Feb 2026 16:00:25 -0800 Subject: [PATCH 10/17] enables grad reduce in fp32 for better precision Signed-off-by: Jonathan Mitchell --- bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 3a5a3c4e9..c665308eb 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -116,11 +116,11 @@ def main(args: DictConfig) -> float | None: mp_policy = MixedPrecisionPolicy( param_dtype=torch.bfloat16, # Cast params to BF16 for forward/backward - reduce_dtype=torch.bfloat16, # Gradient reductions in BF16 + reduce_dtype=torch.float32, # Gradient reductions in FP32 output_dtype=torch.bfloat16, # Forward output dtype ) for layer in transformer_stack: - fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) # TODO: Update mixed precision policy to set it to FP#2 + fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) 
fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) # Create a layer map for the transformer stack. From 589479d2eac54d8e947be444e9c819230a7a93ee Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Mon, 2 Feb 2026 17:38:15 -0800 Subject: [PATCH 11/17] adds FusedAdam for fun Signed-off-by: Jonathan Mitchell --- bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index c665308eb..861133f8f 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -28,6 +28,7 @@ from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy from torch.optim import AdamW +from transformer_engine.pytorch.optimizers import FusedAdam from transformer_engine.common.recipe import Format from transformers import AutoConfig, AutoModelForMaskedLM @@ -151,6 +152,15 @@ def main(args: DictConfig) -> float | None: # Create optimizer. Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873). optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore + # optimizer = FusedAdam(model.parameters(), + # lr=4e-4, + # betas=(0.9, 0.98), + # eps=1e-8, + # weight_decay=0.01, + # master_weights=True, + # master_weight_dtype=torch.float32, + # ) + # Note: Got an error about mixed torch.Tensor and DTensor here, so using AdamW instead. scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) # If we're using sequence packing, create a THD dataloader, otherwise create a BSHD dataloader. 
From c978c495b4f1fe9ef5ceda1015a263d7e2638dc2 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Tue, 3 Feb 2026 15:58:42 -0800 Subject: [PATCH 12/17] fixes up debugging yaml and adds dockerfile for te tot Signed-off-by: Jonathan Mitchell --- .../recipes/esm2_native_te/Dockerfile | 15 ++++++++++++++- .../esm2_native_te/fp4_debugging_stats.yaml | 17 +++++++++++------ 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/Dockerfile b/bionemo-recipes/recipes/esm2_native_te/Dockerfile index b940874af..c388ddf56 100644 --- a/bionemo-recipes/recipes/esm2_native_te/Dockerfile +++ b/bionemo-recipes/recipes/esm2_native_te/Dockerfile @@ -1,9 +1,22 @@ # syntax=docker/dockerfile:1.4 FROM nvcr.io/nvidia/pytorch:25.12-py3 +# Install sccache for faster builds +RUN --mount=type=cache,target=/var/lib/apt/lists,sharing=locked \ + --mount=type=cache,target=/var/cache/apt,sharing=locked \ + apt-get update \ + && apt-get install -y sccache \ + && rm -rf /var/lib/apt/lists/* + +# Uninstall pre-installed Transformer Engine and install from source +RUN pip uninstall -y transformer-engine && \ + NVTE_USE_CCACHE=1 NVTE_CCACHE_BIN=sccache NVTE_FRAMEWORK=pytorch NVTE_BUILD_DEBUG=1 \ + pip install -v --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@main + +# Install BioNeMo requirements +WORKDIR /workspace/bionemo RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=requirements.txt,target=/requirements.txt \ PIP_CONSTRAINT= pip install -r /requirements.txt -WORKDIR /workspace/bionemo COPY . . 
diff --git a/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml index 81d6f4a42..83de92764 100644 --- a/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml @@ -4,8 +4,7 @@ example_fp8_tensor_stat_collection: # Match the actual linear layers within attention that support FP8 stats layer_types: [layernorm_qkv, proj, fc1, fc2] transformer_engine: - # Uncomment once https://github.com/NVIDIA/TransformerEngine/pull/2296 is merged. - # LogFp4TensorStats: + # LogFp8TensorStats: # enabled: True # tensors_struct: # - tensor: activation @@ -14,11 +13,17 @@ example_fp8_tensor_stat_collection: # - tensor: gradient # stats: [underflows%, scale_inv_min, scale_inv_max, mse] # freq: 10 - # - tensor: weight - # stats: [underflows%, scale_inv_min, scale_inv_max, mse] - # freq: 10 + LogNvfp4TensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [underflows%, mse] + freq: 10 + - tensor: gradient + stats: [underflows%, mse] + freq: 10 LogTensorStats: enabled: True stats: [max, min, mean, std, l1_norm] #TODO: Can you get underflows% here? 
tensors: [dgrad, wgrad, fprop] - freq: 1 + freq: 10 From c353607b9ad43a04da6eb100179b98079f1a2eab Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Tue, 3 Feb 2026 16:31:09 -0800 Subject: [PATCH 13/17] inject layer regex patterns for fp4 fp8 Signed-off-by: Jonathan Mitchell --- .../esm2_native_te/fp4_debugging_stats.yaml | 38 +++++---- .../esm2_native_te/hydra_config/L1_650M.yaml | 3 +- .../esm2_native_te/hydra_config/defaults.yaml | 1 + .../recipes/esm2_native_te/train_fsdp2.py | 80 ++++++++++++++++++- 4 files changed, 101 insertions(+), 21 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml index 83de92764..480be54dd 100644 --- a/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml @@ -1,18 +1,10 @@ -example_fp8_tensor_stat_collection: +example_fp4_tensor_stat_collection: enabled: True layers: - # Match the actual linear layers within attention that support FP8 stats - layer_types: [layernorm_qkv, proj, fc1, fc2] + # Use regex to select layers 0-4 (1-indexed as layers.1 through layers.5 in the naming) + # This matches: model.esm.encoder.layers.[1-5].*.(layernorm_qkv|proj|fc1|fc2) + layer_name_regex_pattern: 'model\.esm\.encoder\.layers\.[1-5]\..*(layernorm_qkv|proj|fc1|fc2)' transformer_engine: - # LogFp8TensorStats: - # enabled: True - # tensors_struct: - # - tensor: activation - # stats: [underflows%, scale_inv_min, scale_inv_max, mse] - # freq: 10 - # - tensor: gradient - # stats: [underflows%, scale_inv_min, scale_inv_max, mse] - # freq: 10 LogNvfp4TensorStats: enabled: True tensors_struct: @@ -22,8 +14,20 @@ example_fp8_tensor_stat_collection: - tensor: gradient stats: [underflows%, mse] freq: 10 - LogTensorStats: - enabled: True - stats: [max, min, mean, std, l1_norm] #TODO: Can you get underflows% here? 
- tensors: [dgrad, wgrad, fprop] - freq: 10 + +example_fp8_tensor_stat_collection: + enabled: True + layers: + # Use regex to select layers 5-9 (1-indexed as layers.6 through layers.10 in the naming) + # This matches: model.esm.encoder.layers.([6-9]|10).*.(layernorm_qkv|proj|fc1|fc2) + layer_name_regex_pattern: 'model\.esm\.encoder\.layers\.([6-9]|10)\..*(layernorm_qkv|proj|fc1|fc2)' + transformer_engine: + LogFp8TensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [underflows%, scale_inv_min, scale_inv_max, mse] + freq: 10 + - tensor: gradient + stats: [underflows%, scale_inv_min, scale_inv_max, mse] + freq: 10 diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml index e39c4b398..d71fc0375 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml @@ -18,7 +18,7 @@ wandb_init_args: checkpoint: ckpt_dir: "checkpoints/esm2_t33_650M_UR50D_sanity" -# Layers explicitly set to BF16 in case of NVFP4 training. +# Note: The layers are going to come in 1 indexed and we convert them to be 0 indexed at runtime. fp8_layers: - 1 - 2 @@ -34,6 +34,7 @@ fp8_layers: - 30 - 31 - 32 + - 33 fp4_layers: - 9 diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml index da6f9f47c..00287ee83 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml @@ -90,6 +90,7 @@ quant_stats_config: quant_stats_file: ./fp8_debugging_stats.yaml quant_log_dir: ./log_quant_stats +# Note: The layers are going to come in 1 indexed and we convert them to be 0 indexed at runtime.
fp8_layers: null fp4_layers: null use_fp32_optimizer_weights: false \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 861133f8f..6746aaaaa 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +import tempfile from contextlib import nullcontext from pathlib import Path @@ -22,6 +23,7 @@ import torch import transformer_engine import transformer_engine.pytorch +import yaml from omegaconf import DictConfig, OmegaConf from torch.distributed.device_mesh import init_device_mesh @@ -48,6 +50,65 @@ logger.setLevel(logging.INFO) +def generate_layer_regex(layer_numbers: list[int]) -> str: + """Generate a regex pattern to match specific layer numbers (1-indexed). + + Args: + layer_numbers: List of layer numbers (1-indexed, as shown in logs). + + Returns: + Regex pattern string for matching those layers' linear sublayers. + """ + if not layer_numbers: + return "" + # Use alternation for arbitrary layer lists: (1|2|3|4|5) + layer_pattern = "|".join(str(n) for n in sorted(layer_numbers)) + return rf"model\.esm\.encoder\.layers\.({layer_pattern})\..*(layernorm_qkv|proj|fc1|fc2)" + + +def update_quant_stats_config( + config_file: str, + fp4_layers: list[int] | None, + fp8_layers: list[int] | None, +) -> str: + """Update the quant stats YAML config with layer-specific regex patterns. + + Args: + config_file: Path to the original YAML config file. + fp4_layers: List of layer numbers for FP4 (1-indexed). + fp8_layers: List of layer numbers for FP8 (1-indexed). + + Returns: + Path to the updated config file (may be a temp file). 
+ """ + with open(config_file, "r") as f: + config = yaml.safe_load(f) + + # Update FP4 section if it exists and fp4_layers is provided + if fp4_layers and "example_fp4_tensor_stat_collection" in config: + fp4_regex = generate_layer_regex(fp4_layers) + config["example_fp4_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp4_regex + logger.info(f"Updated FP4 layer regex to match layers: {fp4_layers}") + + # Update FP8 section if it exists and fp8_layers is provided + if fp8_layers and "example_fp8_tensor_stat_collection" in config: + fp8_regex = generate_layer_regex(fp8_layers) + config["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp8_regex + logger.info(f"Updated FP8 layer regex to match layers: {fp8_layers}") + + # Write to a temp file to avoid modifying the original + temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) + yaml.dump(config, temp_file, default_flow_style=False) + temp_file.close() + + # Log the updated config for visibility + config_str = yaml.dump(config, default_flow_style=False) + logger.info(f"Created updated quant stats config at: {temp_file.name}") + logger.info(f"Updated quant stats config contents:\n{config_str}") + + return temp_file.name + + @hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2") def main(args: DictConfig) -> float | None: """Train ESM-2 with TE layers using fsdp2. 
@@ -62,8 +123,24 @@ def main(args: DictConfig) -> float | None: torch.distributed.init_process_group(backend="nccl", device_id=device) torch.cuda.set_device(dist_config.local_rank) + # Parse layer lists first (1-indexed from args, used for both quant stats and internal recipe mapping) + fp8_layers_1indexed = OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None and args.fp8_config.enabled else None + fp4_layers_1indexed = OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None and args.fp4_config.enabled else None + + # Convert to 0-indexed for internal use + fp8_layers = [layer - 1 for layer in fp8_layers_1indexed] if fp8_layers_1indexed else None + fp4_layers = [layer - 1 for layer in fp4_layers_1indexed] if fp4_layers_1indexed else None + if args.quant_stats_config.enabled: quant_stats_file = args.quant_stats_config.quant_stats_file + + # Update the quant stats config with layer-specific regex patterns (using 1-indexed layer numbers) + quant_stats_file = update_quant_stats_config( + config_file=quant_stats_file, + fp4_layers=fp4_layers_1indexed, + fp8_layers=fp8_layers_1indexed, + ) + quant_log_dir = Path(args.quant_stats_config.quant_log_dir) / f"rank_{dist_config.rank}" quant_log_dir.mkdir(parents=True, exist_ok=True) logger.info(f"Logging quant stats to {quant_log_dir}") @@ -92,9 +169,6 @@ def main(args: DictConfig) -> float | None: fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs ) - # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D". 
- fp8_layers = OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None and args.fp8_config.enabled else None - fp4_layers = OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None and args.fp4_config.enabled else None config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.float32) # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument. From 879c50d42dadae6d81e87ef5cbc6662a8ceba1d8 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Tue, 3 Feb 2026 17:40:27 -0800 Subject: [PATCH 14/17] enables fp4 layer with nothing Signed-off-by: Jonathan Mitchell --- .../esm2_native_te/fp4_debugging_stats.yaml | 4 +- .../recipes/esm2_native_te/train_fsdp2.py | 50 +++++++++++++------ 2 files changed, 38 insertions(+), 16 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml index 480be54dd..d2e9da08e 100644 --- a/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml @@ -26,8 +26,8 @@ example_fp8_tensor_stat_collection: enabled: True tensors_struct: - tensor: activation - stats: [underflows%, scale_inv_min, scale_inv_max, mse] + stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse] freq: 10 - tensor: gradient - stats: [underflows%, scale_inv_min, scale_inv_max, mse] + stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse] freq: 10 diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 6746aaaaa..f089b0db1 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -50,17 +50,19 @@ logger.setLevel(logging.INFO) -def generate_layer_regex(layer_numbers: list[int]) -> str: +def 
generate_layer_regex(layer_numbers: list[int] | None) -> str: """Generate a regex pattern to match specific layer numbers (1-indexed). Args: layer_numbers: List of layer numbers (1-indexed, as shown in logs). + If empty or None, returns a pattern that matches nothing. Returns: Regex pattern string for matching those layers' linear sublayers. """ if not layer_numbers: - return "" + # Return a pattern that matches nothing (non-existent layer) + return r"model\.esm\.encoder\.layers\.DISABLED_NO_LAYERS_SPECIFIED" # Use alternation for arbitrary layer lists: (1|2|3|4|5) layer_pattern = "|".join(str(n) for n in sorted(layer_numbers)) return rf"model\.esm\.encoder\.layers\.({layer_pattern})\..*(layernorm_qkv|proj|fc1|fc2)" @@ -80,21 +82,40 @@ def update_quant_stats_config( Returns: Path to the updated config file (may be a temp file). + + Raises: + ValueError: If fp4_layers and fp8_layers have overlapping layer numbers. """ + # Check for overlapping layers + fp4_set = set(fp4_layers) if fp4_layers else set() + fp8_set = set(fp8_layers) if fp8_layers else set() + overlap = fp4_set & fp8_set + if overlap: + raise ValueError( + f"fp4_layers and fp8_layers cannot have overlapping layer numbers. 
" + f"Found overlap: {sorted(overlap)}" + ) + with open(config_file, "r") as f: config = yaml.safe_load(f) - # Update FP4 section if it exists and fp4_layers is provided - if fp4_layers and "example_fp4_tensor_stat_collection" in config: + # Update FP4 section if it exists (always update, even if empty to disable matching) + if "example_fp4_tensor_stat_collection" in config: fp4_regex = generate_layer_regex(fp4_layers) config["example_fp4_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp4_regex - logger.info(f"Updated FP4 layer regex to match layers: {fp4_layers}") + if fp4_layers: + logger.info(f"Updated FP4 layer regex to match layers: {fp4_layers}") + else: + logger.info("FP4 layers empty - regex set to match nothing") - # Update FP8 section if it exists and fp8_layers is provided - if fp8_layers and "example_fp8_tensor_stat_collection" in config: + # Update FP8 section if it exists (always update, even if empty to disable matching) + if "example_fp8_tensor_stat_collection" in config: fp8_regex = generate_layer_regex(fp8_layers) config["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp8_regex - logger.info(f"Updated FP8 layer regex to match layers: {fp8_layers}") + if fp8_layers: + logger.info(f"Updated FP8 layer regex to match layers: {fp8_layers}") + else: + logger.info("FP8 layers empty - regex set to match nothing") # Write to a temp file to avoid modifying the original temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) @@ -127,9 +148,9 @@ def main(args: DictConfig) -> float | None: fp8_layers_1indexed = OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None and args.fp8_config.enabled else None fp4_layers_1indexed = OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None and args.fp4_config.enabled else None - # Convert to 0-indexed for internal use - fp8_layers = [layer - 1 for layer in fp8_layers_1indexed] if 
fp8_layers_1indexed else None - fp4_layers = [layer - 1 for layer in fp4_layers_1indexed] if fp4_layers_1indexed else None + # Convert to 0-indexed for internal use (use 'is not None' to handle empty lists correctly) + fp8_layers = [layer - 1 for layer in fp8_layers_1indexed] if fp8_layers_1indexed is not None else None + fp4_layers = [layer - 1 for layer in fp4_layers_1indexed] if fp4_layers_1indexed is not None else None if args.quant_stats_config.enabled: quant_stats_file = args.quant_stats_config.quant_stats_file @@ -200,11 +221,12 @@ def main(args: DictConfig) -> float | None: # Create a layer map for the transformer stack. layer_number_quantized_recipe_map = {} + fp8_layers_set = set(fp8_layers) if fp8_layers else set() + fp4_layers_set = set(fp4_layers) if fp4_layers else set() for layer_number, layer in enumerate(transformer_stack): - - if layer_number in fp8_layers: + if layer_number in fp8_layers_set: layer_number_quantized_recipe_map[layer_number] = fp8_recipe - elif layer_number in fp4_layers: + elif layer_number in fp4_layers_set: layer_number_quantized_recipe_map[layer_number] = fp4_recipe else: layer_number_quantized_recipe_map[layer_number] = None From 5154fea683c0c500734ef7b33dbed1a7de25075f Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Wed, 4 Feb 2026 13:37:43 -0800 Subject: [PATCH 15/17] pins autotokenizer to previous revision Signed-off-by: Jonathan Mitchell --- bionemo-recipes/recipes/esm2_native_te/dataset.py | 2 +- .../recipes/esm2_native_te/fp4_debugging_stats.yaml | 8 ++++---- bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py | 9 ++++++++- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/dataset.py b/bionemo-recipes/recipes/esm2_native_te/dataset.py index c915f30ea..8b8c0f06b 100644 --- a/bionemo-recipes/recipes/esm2_native_te/dataset.py +++ b/bionemo-recipes/recipes/esm2_native_te/dataset.py @@ -56,7 +56,7 @@ def create_tokenized_dataset( ) dataset = dataset.shuffle(seed=42, 
buffer_size=buffer_size) - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, revision="d81c2e5aec37b5e794d0482e3996fb045a137792") def tokenize_function(examples): """Tokenize the protein sequences.""" diff --git a/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml index d2e9da08e..d56739a6a 100644 --- a/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml @@ -10,10 +10,10 @@ example_fp4_tensor_stat_collection: tensors_struct: - tensor: activation stats: [underflows%, mse] - freq: 10 + freq: 100 - tensor: gradient stats: [underflows%, mse] - freq: 10 + freq: 100 example_fp8_tensor_stat_collection: enabled: True @@ -27,7 +27,7 @@ example_fp8_tensor_stat_collection: tensors_struct: - tensor: activation stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse] - freq: 10 + freq: 100 - tensor: gradient stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse] - freq: 10 + freq: 100 diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index f089b0db1..be62b7f5d 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -17,6 +17,7 @@ import tempfile from contextlib import nullcontext from pathlib import Path +from torch.profiler import profile, ProfilerActivity import hydra import nvdlfw_inspect.api as debug_api @@ -294,8 +295,14 @@ def main(args: DictConfig) -> float | None: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 # Use an outer FP8 recipe. 
- with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe): + with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe if args.fp8_config.enabled else None): outputs = model(**batch) + + # if step == 5: # Profile step 5 + # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: + # with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe): + # outputs = model(**batch) + # logger.info(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20)) # Backward pass. loss = outputs.loss From d880fc93f09632688e2504f666cd2d1676993f3b Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Wed, 4 Feb 2026 18:32:28 -0800 Subject: [PATCH 16/17] adds Dockerfile.te_testing for TE build from src but bad perf on it Signed-off-by: Jonathan Mitchell --- .../recipes/esm2_native_te/Dockerfile | 17 ++------------ .../esm2_native_te/Dockerfile.te_testing | 22 +++++++++++++++++++ 2 files changed, 24 insertions(+), 15 deletions(-) create mode 100644 bionemo-recipes/recipes/esm2_native_te/Dockerfile.te_testing diff --git a/bionemo-recipes/recipes/esm2_native_te/Dockerfile b/bionemo-recipes/recipes/esm2_native_te/Dockerfile index c388ddf56..71a793b1b 100644 --- a/bionemo-recipes/recipes/esm2_native_te/Dockerfile +++ b/bionemo-recipes/recipes/esm2_native_te/Dockerfile @@ -1,22 +1,9 @@ # syntax=docker/dockerfile:1.4 -FROM nvcr.io/nvidia/pytorch:25.12-py3 +FROM nvcr.io/nvidia/pytorch:25.11-py3 -# Install sccache for faster builds -RUN --mount=type=cache,target=/var/lib/apt/lists,sharing=locked \ - --mount=type=cache,target=/var/cache/apt,sharing=locked \ - apt-get update \ - && apt-get install -y sccache \ - && rm -rf /var/lib/apt/lists/* - -# Uninstall pre-installed Transformer Engine and install from source -RUN pip uninstall -y transformer-engine && \ - NVTE_USE_CCACHE=1 NVTE_CCACHE_BIN=sccache NVTE_FRAMEWORK=pytorch NVTE_BUILD_DEBUG=1 
\ - pip install -v --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@main - -# Install BioNeMo requirements -WORKDIR /workspace/bionemo RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=requirements.txt,target=/requirements.txt \ PIP_CONSTRAINT= pip install -r /requirements.txt +WORKDIR /workspace/bionemo COPY . . diff --git a/bionemo-recipes/recipes/esm2_native_te/Dockerfile.te_testing b/bionemo-recipes/recipes/esm2_native_te/Dockerfile.te_testing new file mode 100644 index 000000000..c388ddf56 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/Dockerfile.te_testing @@ -0,0 +1,22 @@ +# syntax=docker/dockerfile:1.4 +FROM nvcr.io/nvidia/pytorch:25.12-py3 + +# Install sccache for faster builds +RUN --mount=type=cache,target=/var/lib/apt/lists,sharing=locked \ + --mount=type=cache,target=/var/cache/apt,sharing=locked \ + apt-get update \ + && apt-get install -y sccache \ + && rm -rf /var/lib/apt/lists/* + +# Uninstall pre-installed Transformer Engine and install from source +RUN pip uninstall -y transformer-engine && \ + NVTE_USE_CCACHE=1 NVTE_CCACHE_BIN=sccache NVTE_FRAMEWORK=pytorch NVTE_BUILD_DEBUG=1 \ + pip install -v --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@main + +# Install BioNeMo requirements +WORKDIR /workspace/bionemo +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=requirements.txt,target=/requirements.txt \ + PIP_CONSTRAINT= pip install -r /requirements.txt + +COPY . . 
From b6042f17b6c8e203b694d52cc999bbeee8e44054 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Mon, 9 Feb 2026 10:24:34 -0800 Subject: [PATCH 17/17] enables tokenizer revision Signed-off-by: Jonathan Mitchell --- bionemo-recipes/recipes/esm2_native_te/dataset.py | 7 +++++-- .../hydra_config/L1_15B_perf_test.yaml | 1 + .../recipes/esm2_native_te/hydra_config/L1_3B.yaml | 1 + .../esm2_native_te/hydra_config/L1_650M.yaml | 2 +- .../esm2_native_te/hydra_config/defaults.yaml | 3 ++- .../recipes/esm2_native_te/modeling_esm_te.py | 1 + .../recipes/esm2_native_te/train_fsdp2.py | 14 +++++++++----- 7 files changed, 20 insertions(+), 9 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/dataset.py b/bionemo-recipes/recipes/esm2_native_te/dataset.py index 8b8c0f06b..78def1ef9 100644 --- a/bionemo-recipes/recipes/esm2_native_te/dataset.py +++ b/bionemo-recipes/recipes/esm2_native_te/dataset.py @@ -42,6 +42,7 @@ def create_tokenized_dataset( max_seq_length: int = 1024, buffer_size: int = 10_000, use_lazy_tokenization: bool = True, + tokenizer_revision: str | None = None, ): """Create a tokenized dataset.""" logger.info(f"Loading dataset with kwargs: {load_dataset_kwargs}") @@ -56,7 +57,7 @@ def create_tokenized_dataset( ) dataset = dataset.shuffle(seed=42, buffer_size=buffer_size) - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, revision="d81c2e5aec37b5e794d0482e3996fb045a137792") + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, revision=tokenizer_revision if tokenizer_revision else None) def tokenize_function(examples): """Tokenize the protein sequences.""" @@ -167,6 +168,7 @@ def create_thd_dataloader( use_stateful_dataloader: bool = False, mlm_probability: float = 0.15, pad_sequences_to_be_divisible_by: int | None = None, + tokenizer_revision: str | None = None, ): """Create a dataloader that packs up to the maximum number of tokens per batch. 
@@ -186,7 +188,7 @@ def create_thd_dataloader( mlm_probability: The probability of masking tokens for MLM (default 0.15). Set to 0 for no masking. pad_sequences_to_be_divisible_by: If provided, sequences will be padded to be divisible by this value. This is useful for context parallelism. Defaults to None. - + tokenizer_revision: The revision of the tokenizer to use. Defaults to None. Returns: A dataloader that can be used for training. """ @@ -196,6 +198,7 @@ def create_thd_dataloader( load_dataset_kwargs=load_dataset_kwargs, max_seq_length=max_seq_length, buffer_size=buffer_size, + tokenizer_revision=tokenizer_revision, ) assert isinstance(tokenized_dataset, datasets.IterableDataset), "THD token packing requires a streaming dataset." diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_15B_perf_test.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_15B_perf_test.yaml index 0b91c5608..2b6f602e3 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_15B_perf_test.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_15B_perf_test.yaml @@ -8,6 +8,7 @@ num_train_steps: 500 dataset: micro_batch_size: 12 + tokenizer_revision: "f29e20d2b10d0aba2036831df65cdca1befe926f" # WandB config wandb_init_args: diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_3B.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_3B.yaml index e8e47d908..3e055907c 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_3B.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_3B.yaml @@ -8,6 +8,7 @@ num_train_steps: 10_000 dataset: micro_batch_size: 16 + tokenizer_revision: "86a86f18e6bb1eb4bcf91c594e1c0ad446d8eec6" # WandB config wandb_init_args: diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml index d71fc0375..fc1153dc3 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml 
+++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml @@ -8,7 +8,7 @@ num_train_steps: 200 dataset: micro_batch_size: 4 - + tokenizer_revision: "d81c2e5aec37b5e794d0482e3996fb045a137792" # WandB config wandb_init_args: name: "esm2_t33_650M_UR50D" diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml index 00287ee83..0cbc27121 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml @@ -13,6 +13,7 @@ cp_size: 1 use_sequence_packing: false dataset: tokenizer_name: ${model_tag} + tokenizer_revision: null micro_batch_size: ??? num_workers: 1 max_seq_length: 1024 @@ -93,4 +94,4 @@ quant_stats_config: # Note: The layers are going to come in 1 indexed and we convert them to be 0 indexed at runtime. fp8_layers: null fp4_layers: null -use_fp32_optimizer_weights: false \ No newline at end of file +use_fp32_master_weights: null \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py index 2053cdd4b..1932c0f4c 100644 --- a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py +++ b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py @@ -216,6 +216,7 @@ def forward( if kwargs.get("output_hidden_states", False): all_hidden_states = (*all_hidden_states, hidden_states) + # NOTE(review): removed stray `import pdb; pdb.set_trace()` debugging breakpoint # If BF16 desired --> use autocast(false) so it goes to BF16. # If FP8 desired --> use nullcontext so it uses upper context manager to FP8. # If FP4 desired --> use autocast(true, recipe=fp4_recipe) so it uses FP4.
diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index be62b7f5d..cfbe109a4 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -192,7 +192,7 @@ def main(args: DictConfig) -> float | None: fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs ) - config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.float32) + config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.float32 if args.use_fp32_master_weights else torch.bfloat16) # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument. if args.use_sequence_packing: config.attn_input_format = "thd" @@ -216,10 +216,14 @@ def main(args: DictConfig) -> float | None: reduce_dtype=torch.float32, # Gradient reductions in FP32 output_dtype=torch.bfloat16, # Forward output dtype ) - for layer in transformer_stack: - fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) - fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) - + if args.use_fp32_master_weights: + for layer in transformer_stack: + fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) + fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) + else: + for layer in transformer_stack: + fully_shard(layer, mesh=device_mesh["dp"]) + fully_shard(model, mesh=device_mesh["dp"]) # Create a layer map for the transformer stack. layer_number_quantized_recipe_map = {} fp8_layers_set = set(fp8_layers) if fp8_layers else set()