From 7a597daec9c0460efc4d3017e78e4d3ee9e15b4d Mon Sep 17 00:00:00 2001 From: jwilber Date: Wed, 4 Feb 2026 15:29:57 -0800 Subject: [PATCH] Add fp8 support Signed-off-by: jwilber --- .../configs/recipes/esm2_native_te_3b.yaml | 24 ++++++++++++++++++- .../configs/recipes/esm2_native_te_650m.yaml | 16 ++++++++++++- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/ci/lepton/model_convergence/configs/recipes/esm2_native_te_3b.yaml b/ci/lepton/model_convergence/configs/recipes/esm2_native_te_3b.yaml index edc94fa52..2aa59226d 100644 --- a/ci/lepton/model_convergence/configs/recipes/esm2_native_te_3b.yaml +++ b/ci/lepton/model_convergence/configs/recipes/esm2_native_te_3b.yaml @@ -91,6 +91,26 @@ products: thd_enabled: false wandb_name: "${config}__${now:%Y%m%d-%H%M%S}__${gitsha:}" job_name: "${sanitize:${recipe_subdir}-${config}}-fsdp2" + # TE bshd perf, FSDP2, FP8 + - config: L1_3B + task_cmd: train_fsdp2 + parallelism_strategy: fsdp2 + thd_enabled: false + fp8_enabled: true + fp8_recipe: transformer_engine.common.recipe.Float8BlockScaling + fp8_format: E4M3 + wandb_name: "${config}__fsdp2__fp8__${now:%Y%m%d-%H%M%S}__${gitsha:}" + job_name: "${sanitize:${recipe_subdir}-${config}}-fsdp2-fp8" + # TE thd perf, FSDP2, FP8 + - config: L1_3B + task_cmd: train_fsdp2 + parallelism_strategy: fsdp2 + thd_enabled: true + fp8_enabled: true + fp8_recipe: transformer_engine.common.recipe.Float8BlockScaling + fp8_format: E4M3 + wandb_name: "${config}__fsdp2__thd__fp8__${now:%Y%m%d-%H%M%S}__${gitsha:}" + job_name: "${sanitize:${recipe_subdir}-${config}}-fsdp2-thd-fp8" - config: L1_3B task_cmd: train_mfsdp parallelism_strategy: mfsdp @@ -139,4 +159,6 @@ run_script: | checkpoint.resume_from_checkpoint=${resume_from_checkpoint} \ +checkpoint.save_checkpoints=${save_checkpoints} \ +checkpoint.use_distributed_checkpoint_fsdp2=${use_distributed_checkpoint_fsdp2} \ - fp8_config.enabled=${fp8_enabled} + fp8_config.enabled=${fp8_enabled} \ + fp8_config.fp8_recipe=${fp8_recipe} \ + 
fp8_config.fp8_format=${fp8_format} diff --git a/ci/lepton/model_convergence/configs/recipes/esm2_native_te_650m.yaml b/ci/lepton/model_convergence/configs/recipes/esm2_native_te_650m.yaml index e4cb90e4c..7e7b50123 100644 --- a/ci/lepton/model_convergence/configs/recipes/esm2_native_te_650m.yaml +++ b/ci/lepton/model_convergence/configs/recipes/esm2_native_te_650m.yaml @@ -90,6 +90,18 @@ products: micro_batch_size: 48 wandb_name: "esm2_native_650m__fsdp2__thd__${now:%Y%m%d-%H%M%S}__${gitsha:}" job_name: "esm2-native-650m-fsdp2-thd" + - config: L1_650M + num_nodes: 2 + num_devices: 8 + task_cmd: train_fsdp2 + parallelism_strategy: fsdp2 + thd_enabled: true + micro_batch_size: 48 + fp8_enabled: true + fp8_recipe: transformer_engine.common.recipe.Float8BlockScaling + fp8_format: E4M3 + wandb_name: "esm2_native_650m__fsdp2__thd__fp8__${now:%Y%m%d-%H%M%S}__${gitsha:}" + job_name: "esm2-native-650m-fsdp2-thd-fp8" # OSS Convergence Baseline # - config: L1_650M # model_tag: facebook/esm2_t33_650M_UR50D @@ -137,4 +149,6 @@ run_script: | checkpoint.resume_from_checkpoint=${resume_from_checkpoint} \ +checkpoint.save_checkpoints=${save_checkpoints} \ +checkpoint.use_distributed_checkpoint_fsdp2=${use_distributed_checkpoint_fsdp2} \ - fp8_config.enabled=${fp8_enabled} + fp8_config.enabled=${fp8_enabled} \ + fp8_config.fp8_recipe=${fp8_recipe} \ + fp8_config.fp8_format=${fp8_format}