-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprocess_groups.py
More file actions
77 lines (65 loc) · 3.25 KB
/
process_groups.py
File metadata and controls
77 lines (65 loc) · 3.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
process_groups.py
-----------------
Builds the three overlapping communication groups Megatron needs.
Every GPU belongs to exactly one group of each type simultaneously.
With 8 GPUs (TP=2, PP=2, DP=2), the rank layout is:
DP replica 0:
Stage 0: GPU 0 (tp=0), GPU 1 (tp=1)
Stage 1: GPU 2 (tp=0), GPU 3 (tp=1)
DP replica 1:
Stage 0: GPU 4 (tp=0), GPU 5 (tp=1)
Stage 1: GPU 6 (tp=0), GPU 7 (tp=1)
TP groups: [0,1], [2,3], [4,5], [6,7] ← AllReduce every forward pass
PP groups: [0,2], [1,3], [4,6], [5,7] ← send/recv activations
DP groups: [0,4], [1,5], [2,6], [3,7] ← AllReduce after backward
"""
import torch.distributed as dist
class ProcessGroups:
    """Builds the three overlapping communication groups Megatron needs.

    Every rank belongs to exactly one TP, one PP, and one DP group
    simultaneously.  The rank layout is:

        rank = dp_r * pp * tp + pp_r * tp + tp_r
    """

    def __init__(self, tp: int, pp: int, dp: int):
        """Create TP/PP/DP process groups and record this rank's position.

        Args:
            tp: tensor-parallel size (ranks sharing weight shards).
            pp: pipeline-parallel size (number of pipeline stages).
            dp: data-parallel size (number of model replicas).

        Raises:
            RuntimeError: if torch.distributed has not been initialized.
            ValueError: if world_size != tp * pp * dp.
        """
        # NOTE: plain `assert` was used here originally; asserts are stripped
        # under `python -O`, so validation must use real exceptions.
        if not dist.is_initialized():
            raise RuntimeError(
                "torch.distributed must be initialized before building process groups"
            )
        world = dist.get_world_size()
        if world != tp * pp * dp:
            raise ValueError(
                f"world_size {world} must equal tp({tp}) * pp({pp}) * dp({dp})"
            )
        self.tp_size = tp
        self.pp_size = pp
        self.dp_size = dp
        rank = dist.get_rank()

        # ── Tensor Parallel groups ──────────────────────────────────────────
        # Same PP stage + DP replica, different TP slice.
        # These ranks share weight shards and AllReduce activations each forward.
        self.tp_group, self.tp_rank = self._build_groups(
            rank,
            [
                [dp_r * pp * tp + pp_r * tp + tp_r for tp_r in range(tp)]
                for dp_r in range(dp)
                for pp_r in range(pp)
            ],
        )

        # ── Pipeline Parallel groups ────────────────────────────────────────
        # Same DP replica + TP slice, different PP stage.
        # These ranks pass activations forward and gradients backward.
        self.pp_group, self.pp_rank = self._build_groups(
            rank,
            [
                [dp_r * pp * tp + pp_r * tp + tp_r for pp_r in range(pp)]
                for dp_r in range(dp)
                for tp_r in range(tp)
            ],
        )

        # ── Data Parallel groups ────────────────────────────────────────────
        # Same PP stage + TP slice, different DP replica.
        # These ranks hold identical model shards and sync gradients after backward.
        self.dp_group, self.dp_rank = self._build_groups(
            rank,
            [
                [dp_r * pp * tp + pp_r * tp + tp_r for dp_r in range(dp)]
                for pp_r in range(pp)
                for tp_r in range(tp)
            ],
        )

        self.is_first_stage = (self.pp_rank == 0)
        self.is_last_stage = (self.pp_rank == pp - 1)

    @staticmethod
    def _build_groups(rank: int, rank_lists):
        """Create one dist group per rank list; return (group, local index) for `rank`.

        dist.new_group is a collective call: EVERY rank must create EVERY
        group, in the same order — which is why we loop over all lists and
        only remember the one containing our own rank.
        """
        my_group, my_rank = None, None
        for ranks in rank_lists:
            grp = dist.new_group(ranks)
            if rank in ranks:
                my_group = grp
                my_rank = ranks.index(rank)
        return my_group, my_rank

    def log(self, rank: int):
        """Print the parallelism layout once (from global rank 0 only)."""
        if rank == 0:
            print(f" Process groups | TP={self.tp_size} PP={self.pp_size} DP={self.dp_size}")