Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bionemo-recipes/recipes/esm2_native_te/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# syntax=docker/dockerfile:1.4
FROM nvcr.io/nvidia/pytorch:25.12-py3
FROM nvcr.io/nvidia/pytorch:25.11-py3

RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,source=requirements.txt,target=/requirements.txt \
Expand Down
22 changes: 22 additions & 0 deletions bionemo-recipes/recipes/esm2_native_te/Dockerfile.te_testing
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# syntax=docker/dockerfile:1.4
FROM nvcr.io/nvidia/pytorch:25.12-py3

# Install sccache for faster builds
RUN --mount=type=cache,target=/var/lib/apt/lists,sharing=locked \
--mount=type=cache,target=/var/cache/apt,sharing=locked \
apt-get update \
&& apt-get install -y sccache \
&& rm -rf /var/lib/apt/lists/*

# Uninstall pre-installed Transformer Engine and install from source
RUN pip uninstall -y transformer-engine && \
NVTE_USE_CCACHE=1 NVTE_CCACHE_BIN=sccache NVTE_FRAMEWORK=pytorch NVTE_BUILD_DEBUG=1 \
pip install -v --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@main

# Install BioNeMo requirements
WORKDIR /workspace/bionemo
RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,source=requirements.txt,target=/requirements.txt \
PIP_CONSTRAINT= pip install -r /requirements.txt

COPY . .
7 changes: 5 additions & 2 deletions bionemo-recipes/recipes/esm2_native_te/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def create_tokenized_dataset(
max_seq_length: int = 1024,
buffer_size: int = 10_000,
use_lazy_tokenization: bool = True,
tokenizer_revision: str | None = None,
):
"""Create a tokenized dataset."""
logger.info(f"Loading dataset with kwargs: {load_dataset_kwargs}")
Expand All @@ -56,7 +57,7 @@ def create_tokenized_dataset(
)
dataset = dataset.shuffle(seed=42, buffer_size=buffer_size)

tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, revision=tokenizer_revision if tokenizer_revision else None)

def tokenize_function(examples):
"""Tokenize the protein sequences."""
Expand Down Expand Up @@ -167,6 +168,7 @@ def create_thd_dataloader(
use_stateful_dataloader: bool = False,
mlm_probability: float = 0.15,
pad_sequences_to_be_divisible_by: int | None = None,
tokenizer_revision: str | None = None,
):
"""Create a dataloader that packs up to the maximum number of tokens per batch.

Expand All @@ -186,7 +188,7 @@ def create_thd_dataloader(
mlm_probability: The probability of masking tokens for MLM (default 0.15). Set to 0 for no masking.
pad_sequences_to_be_divisible_by: If provided, sequences will be padded to be divisible by this value.
This is useful for context parallelism. Defaults to None.

tokenizer_revision: The revision of the tokenizer to use. Defaults to None.

Returns:
A dataloader that can be used for training.
"""
Expand All @@ -196,6 +198,7 @@ def create_thd_dataloader(
load_dataset_kwargs=load_dataset_kwargs,
max_seq_length=max_seq_length,
buffer_size=buffer_size,
tokenizer_revision=tokenizer_revision,
)

assert isinstance(tokenized_dataset, datasets.IterableDataset), "THD token packing requires a streaming dataset."
Expand Down
33 changes: 33 additions & 0 deletions bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
example_fp4_tensor_stat_collection:
enabled: True
layers:
# Use regex to select layers 0-4 (1-indexed as layers.1 through layers.5 in the naming)
# This matches: model.esm.encoder.layers.[1-5].*.(layernorm_qkv|proj|fc1|fc2)
layer_name_regex_pattern: 'model\.esm\.encoder\.layers\.[1-5]\..*(layernorm_qkv|proj|fc1|fc2)'
transformer_engine:
LogNvfp4TensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [underflows%, mse]
freq: 100
- tensor: gradient
stats: [underflows%, mse]
freq: 100

example_fp8_tensor_stat_collection:
enabled: True
layers:
# Use regex to select layers 5-9 (1-indexed as layers.6 through layers.10 in the naming)
# This matches: model.esm.encoder.layers.[6-10].*.(layernorm_qkv|proj|fc1|fc2)
layer_name_regex_pattern: 'model\.esm\.encoder\.layers\.([6-9]|10)\..*(layernorm_qkv|proj|fc1|fc2)'
transformer_engine:
LogFp8TensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse]
freq: 100
- tensor: gradient
stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse]
freq: 100
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ example_fp8_tensor_stat_collection:
enabled: True
layers:
# Match the actual linear layers within attention that support FP8 stats
layer_types: [layernorm_qkv]
layer_types: [layernorm_qkv, proj, fc1, fc2]
transformer_engine:
LogFp8TensorStats:
enabled: True
Expand All @@ -16,3 +16,8 @@ example_fp8_tensor_stat_collection:
- tensor: weight
stats: [underflows%, scale_inv_min, scale_inv_max, mse]
freq: 10
LogTensorStats:
enabled: True
stats: [max, min, mean, std, l1_norm]
tensors: [dgrad, wgrad, fprop]
freq: 1
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ num_train_steps: 500

dataset:
micro_batch_size: 12
tokenizer_revision: "f29e20d2b10d0aba2036831df65cdca1befe926f"

# WandB config
wandb_init_args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ num_train_steps: 10_000

dataset:
micro_batch_size: 16
tokenizer_revision: "86a86f18e6bb1eb4bcf91c594e1c0ad446d8eec6"

# WandB config
wandb_init_args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ num_train_steps: 200

dataset:
micro_batch_size: 4

tokenizer_revision: "d81c2e5aec37b5e794d0482e3996fb045a137792"

# WandB config
wandb_init_args:
name: "esm2_t33_650M_UR50D"
Expand All @@ -17,3 +17,43 @@ wandb_init_args:

checkpoint:
ckpt_dir: "checkpoints/esm2_t33_650M_UR50D_sanity"

# Note: The layers are going to come in 1 indexed and we convert them to be 0 indexed at runtime.
fp8_layers:
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 27
- 28
- 29
- 30
- 31
- 32
- 33

fp4_layers:
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26

use_fp32_optimizer_weights: true
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ cp_size: 1
use_sequence_packing: false
dataset:
tokenizer_name: ${model_tag}
tokenizer_revision: null
micro_batch_size: ???
num_workers: 1
max_seq_length: 1024
Expand Down Expand Up @@ -51,6 +52,14 @@ fp8_config:
fp8_model_init_kwargs:
enabled: false # If this is set to true, fp8_config.enabled must also be set to true.

fp4_config:
enabled: false
fp4_recipe: transformer_engine.common.recipe.NVFP4BlockScaling
fp4_format: "E2M1"
fp4_recipe_kwargs: {}
fp4_model_init_kwargs:
enabled: false # If this is set to true, fp4_config.enabled must also be set to true.

# Optimizer config
adamw_kwargs:
lr: 4e-4
Expand All @@ -76,7 +85,13 @@ checkpoint:
logger:
frequency: 100

fp8_stats_config:

quant_stats_config:
enabled: false
fp8_stats_file: ./fp8_debugging_stats.yaml
fp8_log_dir: ./log_fp8_stats
quant_stats_file: ./fp8_debugging_stats.yaml
quant_log_dir: ./log_quant_stats

# Note: The layers are going to come in 1 indexed and we convert them to be 0 indexed at runtime.
fp8_layers: null
fp4_layers: null
use_fp32_master_weights: null
Loading