Megatron-LM style 3D parallelism in ~300 lines. Educational — every abstraction is explicit.
Vanilla DDP only parallelizes the data. When a single model layer doesn't fit on one GPU, you need to split the model itself. Megatron-LM does this two ways simultaneously:
World = TP × PP × DP
TP (Tensor Parallel) — splits weight matrices within a layer across GPUs
PP (Pipeline Parallel) — splits layers into sequential stages across GPUs
DP (Data Parallel) — splits the batch, like standard DDP
```shell
# 8 GPUs — full 3D parallelism
torchrun --nproc_per_node=8 megatron_demo.py --tp 2 --pp 2 --dp 2

# 4 GPUs
torchrun --nproc_per_node=4 megatron_demo.py --tp 2 --pp 2 --dp 1

# 2 GPUs, CPU
torchrun --nproc_per_node=2 megatron_demo.py --tp 2 --pp 1 --dp 1 --cpu
```

`world_size` must equal `tp × pp × dp`.
```shell
pip install torch
```

Splits individual weight matrices across GPUs. For a `Linear(H, 4H)` layer with TP=2:
- Rank 0 holds columns 0 .. 2H-1
- Rank 1 holds columns 2H .. 4H-1
The splits are arranged in column-parallel / row-parallel pairs so that each block needs only one AllReduce instead of two:
Input → ColumnParallelLinear (no comm) → GELU → RowParallelLinear (AllReduce)
For attention, heads are split across TP ranks. Each rank computes its local heads, then RowParallelLinear reduces the output projection.
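The column/row pairing above can be checked with plain tensors. This is a minimal sketch of the TP=2 math for the MLP block (hypothetical shapes, no real process groups): the column-parallel matmul needs no communication, and the row-parallel matmul needs one AllReduce, shown here as a plain sum over shards.

```python
import torch

torch.manual_seed(0)
H = 4
x = torch.randn(3, H)          # (batch, H)
W1 = torch.randn(H, 4 * H)     # ColumnParallelLinear weight
W2 = torch.randn(4 * H, H)     # RowParallelLinear weight

# Reference: unsharded forward
ref = torch.nn.functional.gelu(x @ W1) @ W2

# TP=2: rank r holds a column shard of W1 and the matching row shard of W2
shards = []
for r in range(2):
    W1_r = W1[:, r * 2 * H:(r + 1) * 2 * H]   # column shard (no comm needed)
    W2_r = W2[r * 2 * H:(r + 1) * 2 * H, :]   # row shard
    shards.append(torch.nn.functional.gelu(x @ W1_r) @ W2_r)

out = shards[0] + shards[1]    # the AllReduce in a real run
assert torch.allclose(out, ref, atol=1e-5)
```

The sum works because GELU is elementwise: each rank can apply it to its own shard of the first matmul's output without communicating.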
Splits transformer blocks across GPUs. With PP=2 and 4 layers:
- Stage 0 (GPU 0): layers 0–1 + embeddings
- Stage 1 (GPU 1): layers 2–3 + LM head
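The layer-to-stage assignment above amounts to carving `num_layers` into contiguous chunks. A sketch with a hypothetical helper (not the demo's actual code), matching the PP=2, 4-layer example:

```python
def stage_layers(num_layers: int, pp: int, stage: int) -> list[int]:
    # Contiguous, equal-sized chunks of layer indices per pipeline stage
    per_stage = num_layers // pp
    return list(range(stage * per_stage, (stage + 1) * per_stage))

print(stage_layers(4, 2, 0))  # [0, 1]  (+ embeddings on the first stage)
print(stage_layers(4, 2, 1))  # [2, 3]  (+ LM head on the last stage)
```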
The forward pass sends activations between stages via `dist.send` / `dist.recv`; the backward pass flows in reverse.
This demo uses GPipe scheduling (full forward, then full backward). Production Megatron uses 1F1B (one-forward-one-backward) to keep all stages busy and reduce the number of activations held in memory.
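The difference between the two schedules can be sketched as schedule generators for a single stage (illustrative only, not the demo's actual loop; in real 1F1B the `warmup` count is typically `pipeline_depth - 1 - stage`):

```python
def gpipe_schedule(num_microbatches: int):
    # All forwards first, then all backwards: activations for every
    # microbatch are alive at once.
    return ([("F", i) for i in range(num_microbatches)] +
            [("B", i) for i in range(num_microbatches)])

def one_f_one_b_schedule(num_microbatches: int, warmup: int):
    # Warmup forwards, then alternate one forward / one backward, then
    # drain the remaining backwards. At most warmup + 1 activations live.
    sched = [("F", i) for i in range(warmup)]
    f, b = warmup, 0
    while f < num_microbatches:
        sched += [("F", f), ("B", b)]
        f, b = f + 1, b + 1
    sched += [("B", i) for i in range(b, num_microbatches)]
    return sched

print(one_f_one_b_schedule(4, 1))
# [('F', 0), ('F', 1), ('B', 0), ('F', 2), ('B', 1), ('F', 3), ('B', 2), ('B', 3)]
```

With GPipe, all 4 activations are held at once; with 1F1B and `warmup=1`, at most 2 are.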
Same as DDP — each DP replica sees a different micro-batch. After backward(), gradients are AllReduced across DP replicas. TP and PP ranks within the same DP group always process the same data.
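What the DP AllReduce has to accomplish can be simulated with plain tensors instead of a real process group: each replica computes gradients on its own micro-batch, then gradients are averaged so every replica takes an identical optimizer step. The loss below is a toy stand-in.

```python
import torch

torch.manual_seed(0)
w = torch.randn(4)

def local_grad(w, batch):
    # Toy per-replica loss: mean squared distance to the micro-batch
    wr = w.clone().requires_grad_(True)
    ((wr - batch) ** 2).mean().backward()
    return wr.grad

batches = [torch.randn(4) for _ in range(2)]   # one micro-batch per DP replica
grads = [local_grad(w, b) for b in batches]

# The AllReduce(avg) a real DP group performs:
avg = sum(grads) / len(grads)

# Equivalent to the gradient of the mean loss over the full batch
wr = w.clone().requires_grad_(True)
torch.stack([((wr - b) ** 2).mean() for b in batches]).mean().backward()
assert torch.allclose(avg, wr.grad, atol=1e-6)
```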
With 8 GPUs (TP=2, PP=2, DP=2), rank assignments look like this:
DP replica 0:
Stage 0: GPU 0 (tp=0), GPU 1 (tp=1)
Stage 1: GPU 2 (tp=0), GPU 3 (tp=1)
DP replica 1:
Stage 0: GPU 4 (tp=0), GPU 5 (tp=1)
Stage 1: GPU 6 (tp=0), GPU 7 (tp=1)
TP groups: [0,1], [2,3], [4,5], [6,7] ← AllReduce every forward
PP groups: [0,2], [1,3], [4,6], [5,7] ← send/recv activations
DP groups: [0,4], [1,5], [2,6], [3,7] ← AllReduce after backward
Each GPU participates in three separate communication groups simultaneously.
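The group tables above fall out of a simple rank decomposition. A sketch, assuming the common Megatron-style ordering `rank = dp * (PP * TP) + pp * TP + tp` (the demo's actual ordering may differ):

```python
TP, PP, DP = 2, 2, 2

def coords(rank: int):
    tp = rank % TP
    pp = (rank // TP) % PP
    dp = rank // (TP * PP)
    return tp, pp, dp

tp_groups, pp_groups, dp_groups = {}, {}, {}
for r in range(TP * PP * DP):
    tp, pp, dp = coords(r)
    tp_groups.setdefault((pp, dp), []).append(r)  # vary tp, fix stage + replica
    pp_groups.setdefault((tp, dp), []).append(r)  # vary pp, fix tp + replica
    dp_groups.setdefault((tp, pp), []).append(r)  # vary dp, fix tp + stage

print(sorted(tp_groups.values()))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(sorted(pp_groups.values()))  # [[0, 2], [1, 3], [4, 6], [5, 7]]
print(sorted(dp_groups.values()))  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

In a real run each of these lists would be passed to `dist.new_group(...)`, and every rank joins exactly one group of each kind.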
| Strategy | Communication per step |
|---|---|
| DDP only | 1× AllReduce (gradients) |
| + Tensor Parallel | + 1× AllReduce per block (activations, forward+backward) |
| + Pipeline Parallel | + send/recv between stages (activations) |
TP adds communication in the forward pass (fast, NVLink-friendly).
PP adds communication between stages (smaller tensors, but serialized).
DP adds communication in the backward pass (large, but only once per step).
| This demo | Real Megatron-LM |
|---|---|
| GPipe schedule | 1F1B interleaved schedule |
| Explicit `sync_dp_gradients()` | Gradient hooks (overlapped with backward) |
| Random data | Efficient data pipeline with DistributedSampler |
| Single file | Full training framework with checkpointing, activation recomputation, FP16/BF16 |
| No sequence parallelism | Sequence parallel LayerNorm + Dropout to reduce activation memory |