feat(zipformer): Add multi-node DDP training support via torchrun/SLURM #2067
Houss3m wants to merge 2 commits into k2-fsa:master
Conversation
This commit enables multi-node distributed training for Zipformer using torchrun and SLURM. The changes are backward-compatible with existing single-node training workflows.

Key changes:
- train.py: Detect torchrun launch via RANK/WORLD_SIZE env vars
- train.py: Use LOCAL_RANK for correct GPU device mapping across nodes
- train.py: Pass use_ddp_launch flag to setup_dist for proper init
- Add slurm_multinode_ddp.sh: Example SLURM script for multi-node training

Usage:

# Single-node (unchanged):
./zipformer/train.py --world-size 4 ...

# Multi-node via SLURM:
sbatch zipformer/slurm_multinode_ddp.sh
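For orientation, the env-var detection listed under "Key changes" boils down to a dispatch pattern like the sketch below. This is an illustrative reconstruction, not the PR's actual train.py: `run_worker` is a hypothetical stand-in for icefall's `run()`, and the argument parsing is trimmed to the single flag that matters here.

```python
import argparse
import os

import torch.multiprocessing as mp


def run_worker(rank: int, world_size: int, args: argparse.Namespace) -> None:
    # Hypothetical stand-in for the real training entrypoint (icefall's run());
    # included only so the dispatch logic below is self-contained.
    print(f"worker started: rank={rank} world_size={world_size}")


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--world-size", type=int, default=1)
    args = parser.parse_args()

    # torchrun (and the SLURM script in this PR) export RANK and WORLD_SIZE;
    # a plain single-node launch does not, so their absence selects mp.spawn.
    env_rank = int(os.environ.get("RANK", -1))
    env_world_size = int(os.environ.get("WORLD_SIZE", -1))

    if env_rank != -1 and env_world_size != -1:
        # torchrun already started one process per GPU: call the worker directly.
        run_worker(env_rank, env_world_size, args)
    elif args.world_size > 1:
        # Classic single-node path: spawn one process per local GPU.
        mp.spawn(run_worker, args=(args.world_size, args), nprocs=args.world_size)
    else:
        run_worker(0, 1, args)


if __name__ == "__main__":
    main()
```

Launched via torchrun, each process already sees its own RANK, so no mp.spawn is needed; launched the old way, nothing changes.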
Summary of Changes

Hello @Houss3m, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the Zipformer training infrastructure by introducing robust support for multi-node distributed data parallel (DDP) training.
📝 Walkthrough

Adds a SLURM orchestration script for multi-node DDP training and updates the Zipformer training entrypoint to detect and correctly initialize distributed runs launched via torchrun/SLURM (uses environment variables like RANK/LOCAL_RANK for device and world-size handling).
Sequence Diagram

sequenceDiagram
participant SLURM as SLURM Job Scheduler
participant srun as srun (per-node)
participant torchrun as torchrun
participant train as train.py
participant DDP as PyTorch DDP
SLURM->>SLURM: Read SLURM_NNODES, GPUS_PER_NODE\nCompute WORLD_SIZE, MASTER_ADDR, MASTER_PORT
SLURM->>srun: Launch per-node srun block
srun->>srun: Activate conda env\nSet PYTHONPATH, env vars
srun->>torchrun: Invoke torchrun with rendezvous config
torchrun->>train: Start processes, set RANK/LOCAL_RANK
train->>train: Detect RANK/LOCAL_RANK\nSet CUDA device (LOCAL_RANK)
train->>DDP: Call setup_dist(rank, world_size, master_port)
DDP->>DDP: Initialize backend (NCCL)
train->>train: Run distributed training loop
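The train.py lane of the diagram maps roughly onto the following sketch. It is a minimal illustration that assumes torchrun's standard environment variables (RANK, LOCAL_RANK, MASTER_ADDR, MASTER_PORT) and calls torch.distributed.init_process_group directly; the PR itself goes through icefall's setup_dist helper instead.

```python
import os

import torch
import torch.distributed as dist


def init_distributed(rank: int, world_size: int) -> torch.device:
    """Per-process setup mirroring the diagram: pick the GPU, then init NCCL/gloo."""
    if torch.cuda.is_available():
        # Under torchrun, LOCAL_RANK is the GPU index on this node;
        # fall back to rank % num_gpus for a single-node mp.spawn launch.
        local_rank = int(os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))
        torch.cuda.set_device(local_rank)  # implicit allocations now target this GPU
        device = torch.device("cuda", local_rank)
        backend = "nccl"
    else:
        device = torch.device("cpu")
        backend = "gloo"

    if world_size > 1:
        # torchrun exports MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE, so the
        # default env:// rendezvous needs no extra arguments here.
        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)

    return device
```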
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ❌ 1 failed check (1 warning), ✅ 2 passed checks
Code Review
This pull request adds support for multi-node DDP training using torchrun and SLURM. The changes in train.py correctly detect the launch environment and adjust the DDP setup and device mapping, while maintaining backward compatibility. The new slurm_multinode_ddp.sh script is a good example for multi-node execution. I've provided a few suggestions to improve the robustness and correctness of the SLURM script.
| if [ "$SLURM_PROCID" -eq 0 ]; then | ||
| RDZV_IS_HOST=1 | ||
| else | ||
| RDZV_IS_HOST=0 | ||
| # Small delay to ensure master is ready | ||
| sleep 5 | ||
| fi |
This block for manually determining the rendezvous host and using sleep 5 is fragile and introduces a potential race condition. If the master node takes longer than 5 seconds to initialize, the job will fail. torchrun is designed to manage the rendezvous process automatically when an rdzv_endpoint is provided. It's recommended to remove this block and the corresponding --rdzv_conf argument on line 133 to rely on torchrun's more robust, built-in synchronization mechanism.
```bash
# management within each node.
#
# Usage:
#   sbatch run_multinode_ddp.sh
```
```bash
export MASTER_ADDR MASTER_PORT

# Calculate world size
GPUS_PER_NODE=8
```
The number of GPUs per node is hardcoded. This could lead to mismatches if the #SBATCH --gpus-per-node directive is changed. It's more robust to use the SLURM_GPUS_PER_NODE environment variable, which SLURM sets based on the directive.
```diff
-GPUS_PER_NODE=8
+GPUS_PER_NODE=${SLURM_GPUS_PER_NODE:-8}
```
Actionable comments posted: 5
🤖 Fix all issues with AI agents
In `@egs/librispeech/ASR/zipformer/slurm_multinode_ddp.sh`:
- Around line 24-25: GPUS_PER_NODE is hardcoded and duplicated with the SBATCH
--gpus-per-node directive, risking drift; change the script to derive
GPUS_PER_NODE from the SLURM environment (e.g., use SLURM_GPUS_PER_NODE or a
sensible fallback) and compute WORLD_SIZE from that value so the SBATCH
directive remains the single source of truth; specifically update the assignment
to GPUS_PER_NODE (replace the hardcoded "GPUS_PER_NODE=8") to read GPUS_PER_NODE
from SLURM env and then recalc WORLD_SIZE using that variable (keep references
to GPUS_PER_NODE and WORLD_SIZE in the script unchanged so locations like the
WORLD_SIZE calculation use the new derived value).
- Line 10: Update the usage comment in the script: replace the incorrect example
"sbatch run_multinode_ddp.sh" (the comment line shown) with the correct filename
"sbatch slurm_multinode_ddp.sh" so the usage matches the actual script name
(refer to the comment line containing "sbatch run_multinode_ddp.sh" in
slurm_multinode_ddp.sh).
In `@egs/librispeech/ASR/zipformer/train.py`:
- Around line 1281-1285: The current CUDA device is computed into local_rank and
a torch.device assigned, but the process never calls
torch.cuda.set_device(local_rank), so CUDA defaults can land on GPU 0; after
computing local_rank (from os.environ.get("LOCAL_RANK", rank %
torch.cuda.device_count())) and before or immediately after creating the
torch.device("cuda", local_rank), call torch.cuda.set_device(local_rank) to set
the process default CUDA device so all implicit allocations use the correct GPU;
update the block that references local_rank and device accordingly.
- Around line 1347-1349: The code references local_rank when wrapping the model
in DDP (model = DDP(model, device_ids=[local_rank], ...)) but local_rank is only
defined inside the torch.cuda.is_available() block, which can lead to a
NameError; initialize local_rank before that conditional (e.g., set local_rank =
0 or None) and then if CUDA is available overwrite it inside the
torch.cuda.is_available() branch so that the DDP wrapping (in the world_size > 1
branch) can safely refer to local_rank; ensure DDP call still uses device_ids
appropriately when CUDA is unavailable (pass None or an empty list) to avoid
invalid device references.
- Around line 1595-1610: The code sets params from CLI via
params.update(vars(args)) which leaves params.world_size as args.world_size
(default 1) even when torchrun passes env_world_size to run(); this breaks
get_adjusted_batch_count() scaling. Fix by overriding the world-size after
params are loaded: inside run() (the function that takes rank and world_size)
set params.world_size = world_size immediately after params.update(vars(args));
alternatively, if you prefer the change in main(), set args.world_size =
env_world_size before calling run() when env_rank != -1 so params.update
receives the correct world size.
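Taken together, the three train.py prompts above describe a fix along the lines of the sketch below. It is schematic rather than the repository's actual run(): `params` is a plain namespace, the model is a placeholder module, and the function assumes the default process group has already been initialized (e.g., by setup_dist).

```python
import os
from types import SimpleNamespace

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_model_for_rank(rank: int, world_size: int, params: SimpleNamespace) -> nn.Module:
    # Define local_rank up front so it exists even on CPU-only hosts
    # (avoids the NameError flagged above).
    local_rank = None
    device = torch.device("cpu")

    if torch.cuda.is_available():
        local_rank = int(os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))
        torch.cuda.set_device(local_rank)  # keep implicit allocations off GPU 0
        device = torch.device("cuda", local_rank)

    # Keep schedule math (e.g., get_adjusted_batch_count) consistent with the
    # real world size supplied by torchrun, not the CLI default of --world-size.
    params.world_size = world_size

    model = nn.Linear(80, 512)  # placeholder for the Zipformer model
    model.to(device)

    if world_size > 1:
        # device_ids must be None on CPU; [local_rank] pins each replica to its GPU.
        device_ids = [local_rank] if local_rank is not None else None
        model = DDP(model, device_ids=device_ids)

    return model
```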
🧹 Nitpick comments (4)
egs/librispeech/ASR/zipformer/slurm_multinode_ddp.sh (4)
61-61: Use `mapfile` or `read -a` instead of array assignment from command substitution. Per ShellCheck SC2207, word-splitting into an array via `HOSTS=($(…))` is fragile. While hostnames are unlikely to contain spaces, it's better practice to use `mapfile`.

Proposed fix:

```diff
-HOSTS=($(scontrol show hostnames "${SLURM_JOB_NODELIST}"))
+mapfile -t HOSTS < <(scontrol show hostnames "${SLURM_JOB_NODELIST}")
```
66-67: `MASTER_ADDR` resolution uses `srun`, which may be slow or fail; consider `hostname -i` or `getent`. Using `srun -N1 -n1 -w` to resolve the master's IP adds a SLURM scheduling step that can delay startup. A simpler and more robust alternative:

```diff
-MASTER_ADDR=$(srun -N1 -n1 -w "${MASTER_NODE}" bash -lc \
-    "ip -o -4 addr show scope global | awk '{print \$4}' | cut -d/ -f1 | head -n1")
+MASTER_ADDR=$(getent hosts "${MASTER_NODE}" | awk '{print $1}' | head -n1)
```

Also, there's no validation that `MASTER_ADDR` was resolved successfully. If it's empty, the training will fail with a confusing error later.

Proposed fix with validation:

```diff
-MASTER_ADDR=$(srun -N1 -n1 -w "${MASTER_NODE}" bash -lc \
-    "ip -o -4 addr show scope global | awk '{print \$4}' | cut -d/ -f1 | head -n1")
+MASTER_ADDR=$(getent hosts "${MASTER_NODE}" | awk '{print $1}' | head -n1)
+if [ -z "${MASTER_ADDR}" ]; then
+    echo "ERROR: Could not resolve MASTER_ADDR for ${MASTER_NODE}" >&2
+    exit 1
+fi
```
136-136: `--world-size` passed to `train.py` may conflict with the environment-based world size. In `train.py`'s `main()`, when `RANK` is set, `env_world_size` from the `WORLD_SIZE` env var is used (set by `torchrun`). The `--world-size` arg passed here on line 136 goes into `args.world_size` but is never used in the torchrun path. This is harmless but potentially confusing: a user might change this value expecting it to have effect.

Consider adding a comment clarifying that `--world-size` is ignored when launched via torchrun, or remove it from the torchrun invocation.
106-146: The `sleep 5` delay may be insufficient in practice, but the explicit `is_host` configuration should be retained. The `sleep 5` (line 123) is a reasonable attempt at synchronization, but 5 seconds may not be reliable on slower clusters. However, PyTorch's c10d rendezvous backend handles connection failures with a `read_timeout` (default 60s), which gives retrying nodes time to connect without strict timing.

Regarding `--rdzv_conf is_host=`: this is a standard and documented parameter for the c10d rendezvous backend. While PyTorch can infer the host heuristically from the rendezvous endpoint, explicitly setting `is_host` is a recommended practice to avoid hostname/IP mismatch issues, especially in multi-node SLURM environments where node names and IPs may not align predictably. The current approach, using `SLURM_PROCID` to determine which node hosts the store, is sound.

If you want to improve robustness, consider increasing the delay or replacing it with a retry loop in the training script that waits for the rendezvous endpoint to become available before calling `torchrun`. Alternatively, rely on torchrun's built-in timeout and let worker nodes handle connection failures gracefully.
SLURM script fixes:
- Fix typo in usage comment (run_multinode_ddp.sh -> slurm_multinode_ddp.sh)
- Use SLURM_GPUS_PER_NODE env var instead of hardcoded value
- Remove fragile sleep/rdzv_conf block, rely on torchrun's built-in sync

train.py fixes:
- Add torch.cuda.set_device(device) to prevent GPU 0 allocation issues
- Initialize local_rank before conditional to avoid potential NameError
- Override params.world_size with actual world_size for correct scheduling
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@egs/librispeech/ASR/zipformer/slurm_multinode_ddp.sh`:
- Around line 75-76: SLURM_GPUS_PER_NODE can contain a type prefix like
"a100:8", which breaks the arithmetic for WORLD_SIZE; update the GPUS_PER_NODE
assignment to extract the numeric count (e.g., use shell parameter expansion to
take the suffix after ':' or fall back to the original value) before computing
WORLD_SIZE so WORLD_SIZE=$(( SLURM_NNODES * GPUS_PER_NODE )) always uses a plain
integer; modify the code that sets GPUS_PER_NODE (and leave WORLD_SIZE
calculation unchanged) to strip any "type:" prefix from SLURM_GPUS_PER_NODE
(reference variables: SLURM_GPUS_PER_NODE, GPUS_PER_NODE, WORLD_SIZE).
🧹 Nitpick comments (4)
egs/librispeech/ASR/zipformer/slurm_multinode_ddp.sh (3)
61-61: Use `mapfile` instead of array assignment from command substitution. Per ShellCheck SC2207, word-splitting-based array assignment can break on entries with spaces (unlikely here, but good hygiene).

Proposed fix:

```diff
-HOSTS=($(scontrol show hostnames "${SLURM_JOB_NODELIST}"))
+mapfile -t HOSTS < <(scontrol show hostnames "${SLURM_JOB_NODELIST}")
```
66-67: Spawning a full `srun` job just to resolve the master node IP adds latency and can fail. This `srun` call allocates a task step solely to run `ip addr` on the master node. A simpler and more robust approach is to use `getent hosts` or `hostname -I` without `srun`:

Proposed fix:

```diff
-MASTER_ADDR=$(srun -N1 -n1 -w "${MASTER_NODE}" bash -lc \
-    "ip -o -4 addr show scope global | awk '{print \$4}' | cut -d/ -f1 | head -n1")
+MASTER_ADDR=$(getent ahostsv4 "${MASTER_NODE}" | awk 'NR==1{print $1}')
```
106-135: The quoting pattern for variable injection into the `srun` inline script is correct but fragile. The `'"$VAR"'` idiom (break out of single-quote, insert double-quoted expansion, re-enter single-quote) works, but it's easy to get wrong on edits. Consider using a heredoc or exporting variables and referencing them directly inside the `srun` block, which would simplify maintenance:

```bash
export SLURM_NNODES GPUS_PER_NODE WORLD_SIZE MASTER_ADDR MASTER_PORT
export EXP_DIR NUM_EPOCHS MAX_DURATION CAUSAL CHUNK_SIZE LEFT_CONTEXT_FRAMES BPE_MODEL

srun --ntasks=${SLURM_NNODES} --ntasks-per-node=1 --kill-on-bad-exit=1 --export=ALL bash -lc '
  ...
  torchrun \
    --nnodes="$SLURM_NNODES" \
    --nproc_per_node="$GPUS_PER_NODE" \
    ...
'
```

Since `--export=ALL` is already specified, all exported variables are available inside the `srun` block without the quoting gymnastics.

egs/librispeech/ASR/zipformer/train.py (1)
1600-1615: Consider guarding against `RANK` set without `WORLD_SIZE`. If `RANK` is set (line 1602) but `WORLD_SIZE` is not (or vice versa), `env_world_size` defaults to `-1` and gets passed through to `run()`, which would cause incorrect behavior (e.g., `setup_dist` called with `world_size=-1`, negative `params.world_size`). While torchrun always sets both, a defensive check would prevent confusing failures if the environment is partially configured.

Proposed fix:

```diff
 env_rank = int(os.environ.get("RANK", -1))
 env_world_size = int(os.environ.get("WORLD_SIZE", -1))
-if env_rank != -1:
+if env_rank != -1 and env_world_size != -1:
     # Multi-node/torchrun mode: bypass mp.spawn
     # We use world_size from environment, not from args
     run(rank=env_rank, world_size=env_world_size, args=args)
+elif env_rank != -1 or env_world_size != -1:
+    raise RuntimeError(
+        f"Partial distributed env detected: RANK={env_rank}, WORLD_SIZE={env_world_size}. "
+        "Both must be set when using torchrun/SLURM."
+    )
 else:
```
```bash
GPUS_PER_NODE=${SLURM_GPUS_PER_NODE:-8}
WORLD_SIZE=$(( SLURM_NNODES * GPUS_PER_NODE ))
```
SLURM_GPUS_PER_NODE may contain a type prefix (e.g., a100:8), breaking arithmetic.
When --gpus-per-node is specified with a GPU type (e.g., --gpus-per-node=a100:8), SLURM sets SLURM_GPUS_PER_NODE=a100:8. The arithmetic expansion on line 76 would then fail with a syntax error. Consider stripping the type prefix:
Proposed fix:

```diff
-GPUS_PER_NODE=${SLURM_GPUS_PER_NODE:-8}
+# Strip optional GPU type prefix (e.g., "a100:8" -> "8")
+_slurm_gpn="${SLURM_GPUS_PER_NODE:-8}"
+GPUS_PER_NODE="${_slurm_gpn##*:}"
```
🤖 Prompt for AI Agents
In `@egs/librispeech/ASR/zipformer/slurm_multinode_ddp.sh` around lines 75 - 76,
SLURM_GPUS_PER_NODE can contain a type prefix like "a100:8", which breaks the
arithmetic for WORLD_SIZE; update the GPUS_PER_NODE assignment to extract the
numeric count (e.g., use shell parameter expansion to take the suffix after ':'
or fall back to the original value) before computing WORLD_SIZE so
WORLD_SIZE=$(( SLURM_NNODES * GPUS_PER_NODE )) always uses a plain integer;
modify the code that sets GPUS_PER_NODE (and leave WORLD_SIZE calculation
unchanged) to strip any "type:" prefix from SLURM_GPUS_PER_NODE (reference
variables: SLURM_GPUS_PER_NODE, GPUS_PER_NODE, WORLD_SIZE).