Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,26 @@ products:
thd_enabled: false
wandb_name: "${config}__${now:%Y%m%d-%H%M%S}__${gitsha:}"
job_name: "${sanitize:${recipe_subdir}-${config}}-fsdp2"
# TE bshd perf, FSDP2, FP8
# L1_3B launched through train_fsdp2 with thd_enabled off and FP8 on, using the
# Transformer Engine Float8BlockScaling recipe with the E4M3 format.
# NOTE(review): run_script passes fp8_config.fp8_recipe=${fp8_recipe} and
# fp8_config.fp8_format=${fp8_format} for every product; confirm that entries
# which do not set these keys still resolve them to a default.
- config: L1_3B
task_cmd: train_fsdp2
parallelism_strategy: fsdp2
thd_enabled: false
fp8_enabled: true
fp8_recipe: transformer_engine.common.recipe.Float8BlockScaling
fp8_format: E4M3
# The "fsdp2__fp8" / "-fsdp2-fp8" tags label the W&B run and the job with the
# parallelism strategy and precision used by this entry.
wandb_name: "${config}__fsdp2__fp8__${now:%Y%m%d-%H%M%S}__${gitsha:}"
job_name: "${sanitize:${recipe_subdir}-${config}}-fsdp2-fp8"
# TE thd perf, FSDP2, FP8
# L1_3B launched through train_fsdp2 with thd_enabled on and FP8 on, using the
# Transformer Engine Float8BlockScaling recipe with the E4M3 format.
- config: L1_3B
task_cmd: train_fsdp2
parallelism_strategy: fsdp2
thd_enabled: true
fp8_enabled: true
fp8_recipe: transformer_engine.common.recipe.Float8BlockScaling
fp8_format: E4M3
# The "fsdp2__thd__fp8" / "-fsdp2-thd-fp8" tags label the W&B run and the job
# with the parallelism strategy, THD mode and precision used by this entry.
wandb_name: "${config}__fsdp2__thd__fp8__${now:%Y%m%d-%H%M%S}__${gitsha:}"
job_name: "${sanitize:${recipe_subdir}-${config}}-fsdp2-thd-fp8"
- config: L1_3B
task_cmd: train_mfsdp
parallelism_strategy: mfsdp
Expand Down Expand Up @@ -139,4 +159,6 @@ run_script: |
checkpoint.resume_from_checkpoint=${resume_from_checkpoint} \
+checkpoint.save_checkpoints=${save_checkpoints} \
+checkpoint.use_distributed_checkpoint_fsdp2=${use_distributed_checkpoint_fsdp2} \
fp8_config.enabled=${fp8_enabled}
fp8_config.enabled=${fp8_enabled} \
fp8_config.fp8_recipe=${fp8_recipe} \
fp8_config.fp8_format=${fp8_format}
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,18 @@ products:
micro_batch_size: 48
wandb_name: "esm2_native_650m__fsdp2__thd__${now:%Y%m%d-%H%M%S}__${gitsha:}"
job_name: "esm2-native-650m-fsdp2-thd"
# TE thd perf, FSDP2, FP8
# L1_650M launched through train_fsdp2 with thd_enabled on and FP8 on
# (num_nodes: 2, num_devices: 8, micro_batch_size: 48), using the Transformer
# Engine Float8BlockScaling recipe with the E4M3 format.
- config: L1_650M
num_nodes: 2
num_devices: 8
task_cmd: train_fsdp2
parallelism_strategy: fsdp2
thd_enabled: true
micro_batch_size: 48
fp8_enabled: true
fp8_recipe: transformer_engine.common.recipe.Float8BlockScaling
fp8_format: E4M3
wandb_name: "esm2_native_650m__fsdp2__thd__fp8__${now:%Y%m%d-%H%M%S}__${gitsha:}"
# The "-fp8" suffix keeps this job from sharing a name with the non-FP8
# "esm2-native-650m-fsdp2-thd" product, and mirrors the "__fp8" tag in wandb_name.
job_name: "esm2-native-650m-fsdp2-thd-fp8"
# OSS Convergence Baseline
# - config: L1_650M
# model_tag: facebook/esm2_t33_650M_UR50D
Expand Down Expand Up @@ -137,4 +149,6 @@ run_script: |
checkpoint.resume_from_checkpoint=${resume_from_checkpoint} \
+checkpoint.save_checkpoints=${save_checkpoints} \
+checkpoint.use_distributed_checkpoint_fsdp2=${use_distributed_checkpoint_fsdp2} \
fp8_config.enabled=${fp8_enabled}
fp8_config.enabled=${fp8_enabled} \
fp8_config.fp8_recipe=${fp8_recipe} \
fp8_config.fp8_format=${fp8_format}