
Commit 7cb1b88

JoshWoo2003, delock, sfc-gh-truwase, and Antlera authored
Add ZenFlow code for Stage 3 (#7516)
This PR completes the ZenFlow integration for DeepSpeed ZeRO Stage 3. Highlights:

- ZenFlowSelectiveAdamW_stage3: Optimizer with importance-aware selective parameter updates for ZeRO Stage 3.
- ZenFlowZeroOptimizer_Stage3: Full Stage 3 optimizer integration with partitioned parameters and CPU offload.
- Configurable via ZenFlowConfig, fully integrated with DeepSpeedZeroConfig for Stage 3.
- Unit tests for Stage 3 cases ensuring correctness and compatibility.

Note: Integration with ZeRO Stage 1&2 was introduced in #7391

---------

Signed-off-by: Yusen Wu <xrn4ub@virginia.edu>
Co-authored-by: Ma, Guokai <guokai.ma@intel.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Tingfeng Lan <erc8gx@virginia.edu>
1 parent b7cd78f commit 7cb1b88
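For orientation before the per-file diffs: ZenFlow options are parsed into ZenFlowConfig, which this PR wires into the Stage 3 path. Below is a minimal sketch of enabling it from a DeepSpeed config, assuming the `zenflow` block nests under `zero_optimization` as in the Stage 1/2 integration (#7391). Only `offload`, `full_warm_up_rounds`, and `pt_reserved_cores_perc` are ZenFlowConfig fields visible in this diff; the values here are illustrative, not recommendations.

```python
import torch
import deepspeed

# Illustrative sketch only: the `zenflow` block is assumed to nest under
# `zero_optimization` (as in the Stage 1/2 integration), values are made up,
# and ZenFlowConfig defines more fields than the three shown here.
ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "zero_optimization": {
        "stage": 3,  # this PR extends ZenFlow from Stage 1/2 to Stage 3
        "offload_optimizer": {"device": "cpu"},
        "zenflow": {
            "offload": True,                # offload the selective optimizer
            "full_warm_up_rounds": 4,       # dense warm-up steps before selective updates
            "pt_reserved_cores_perc": 0.5,  # share of CPU cores kept for PyTorch
        },
    },
}

model = torch.nn.Linear(8, 8)
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)
```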

File tree

10 files changed: +1218 −234 lines


deepspeed/ops/adam/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -6,4 +6,4 @@
 from .cpu_adam import DeepSpeedCPUAdam
 from .fused_adam import FusedAdam
 from .zenflow_cpu_adam import ZenFlowCPUAdam
-from .zenflow_torch_adam import ZenFlowSelectiveAdamW
+from .zenflow_torch_adam import ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3
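With this export in place, the Stage 3 optimizer is importable alongside the original class; its constructor arguments are not shown on this page.

```python
# Both classes are now importable from deepspeed.ops.adam; the Stage 3
# variant's constructor signature is not visible in this diff.
from deepspeed.ops.adam import ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3
```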

deepspeed/ops/adam/zenflow_torch_adam.py

Lines changed: 251 additions & 33 deletions
Large diffs are not rendered by default.
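This file's diff is collapsed, but per the commit message ZenFlowSelectiveAdamW_stage3 performs importance-aware selective parameter updates. As a toy sketch of the general idea only (not this class's actual internals, which are not rendered here): rank gradient entries by magnitude and restrict the update to the top-k fraction. The `topk_ratio` name is hypothetical.

```python
import torch

def selective_update_mask(grad: torch.Tensor, topk_ratio: float) -> torch.Tensor:
    # Toy illustration of importance-aware selection: keep only the
    # largest-|g| fraction of gradient entries. Not the actual
    # ZenFlowSelectiveAdamW_stage3 internals, which this diff does not show.
    k = max(1, int(grad.numel() * topk_ratio))
    threshold = grad.abs().flatten().kthvalue(grad.numel() - k + 1).values
    return grad.abs() >= threshold

grad = torch.randn(4, 4)
mask = selective_update_mask(grad, topk_ratio=0.1)
sparse_grad = grad * mask  # only "important" coordinates are updated this step
```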

deepspeed/runtime/engine.py

Lines changed: 1 addition & 0 deletions
@@ -1868,6 +1868,7 @@ def _configure_zero_optimizer(self, optimizer):
                 overlap_comm=self.zero_overlap_comm(),
                 offload_optimizer_config=self.zero_offload_optimizer(),
                 offload_param_config=self.zero_offload_param(),
+                zenflow_config=self.zenflow_config(),
                 sub_group_size=self.zero_sub_group_size(),
                 offload_ratio=self.zero_partial_offload(),
                 mpu=self.mpu,

deepspeed/runtime/zenflow/engine_stage3.py

Lines changed: 641 additions & 0 deletions
Large diffs are not rendered by default.

deepspeed/runtime/zenflow/zenflow_stage_1_and_2.py

Lines changed: 3 additions & 137 deletions
@@ -3,14 +3,11 @@
 
 # DeepSpeed Team
 
-import os
-import math
-import psutil
 import torch
 from deepspeed import comm as dist
-import torch.multiprocessing as mp
 
 from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
+from deepspeed.runtime.zenflow.zenflow_utils import start_optimizer_process
 from deepspeed.runtime.utils import (see_memory_usage)
 from deepspeed.ops.adam import ZenFlowSelectiveAdamW
 
@@ -97,6 +94,8 @@ def __init__(self,
         self.full_warm_up_rounds = zenflow_config.full_warm_up_rounds
         self.offload_selective_optimizer = zenflow_config.offload
         self.pt_reserved_cores_perc = zenflow_config.pt_reserved_cores_perc
+        self.start_optimizer_process = lambda: start_optimizer_process(self)
+        self.zf_stage3 = False
 
         if self.offload_selective_optimizer:
             assert overlap_comm, "offload selective optimizer should be used with overlap_comm"
@@ -636,64 +635,10 @@ def zenflow_cpu_optimizer_step(self, group_no):
         self.optimizer.step(step_id=self.micro_step + 1)
 
 
-def disable_accelerator():
-    accelerator = get_accelerator()
-    accelerator.is_available = lambda: False
-    accelerator.device_count = lambda: 0
-    accelerator.current_device = lambda: -1
-    # Optionally mark it as initialized if needed
-    if hasattr(accelerator, "_initialized"):
-        accelerator._initialized = True
-
-
-def zenflow_optimizer_process(pipe, curr_rank, total_rank, param_groups, shared_overlap_grad_map,
-                              shared_stale_param_map, zf_affinity):
-    disable_accelerator()
-
-    current_process = psutil.Process()
-    current_process.cpu_affinity(zf_affinity)
-    os.environ['OMP_NUM_THREADS'] = str(len(zf_affinity))
-
-    from deepspeed.ops.adam import ZenFlowCPUAdam
-    optimizer = ZenFlowCPUAdam(param_groups, overlap_step=True)
-
-    pipe.send({"type": "ready"})
-
-    # TODO: replace this with rpc
-
-    while True:
-        cmd = pipe.recv()
-        if cmd["type"] == "step":
-            now_state = cmd["now_state"]
-            micro_step = cmd["micro_step"]
-            group_infos = cmd["group_infos"]
-
-            for group_no, group_info in enumerate(group_infos):
-                original_param_groups = optimizer.param_groups
-                optimizer.param_groups = [original_param_groups[group_no]]
-                group = optimizer.param_groups[0]
-
-                for param_idx, param in enumerate(group["params"]):
-                    key = (group_no, param_idx)
-                    if key in shared_overlap_grad_map:
-                        param.overlap_grad = shared_overlap_grad_map[key]
-                    if key in shared_stale_param_map:
-                        param.stale_param = shared_stale_param_map[key]
-
-                optimizer.step(step_id=micro_step + 1, now_state=now_state, group_info=group_info)
-
-                optimizer.param_groups = original_param_groups
-
-            pipe.send({"type": "done"})
-        elif cmd["type"] == "exit":
-            break
-
-
 class ZenFlowZeroOptimizerParallel(ZenFlowZeroOptimizer):
 
     def __init__(self, *args, **kwargs):
         super(ZenFlowZeroOptimizerParallel, self).__init__(*args, **kwargs)
-        self.process_pool = mp.Pool(1)
         self.process_optimizer_established = False
         self.first_update_round_after_warmup = True
 
@@ -759,85 +704,6 @@ def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param):
         dest_tensor.copy_(src_tensor, non_blocking=True)
         param.grad = None  #offload only
 
-    # check if all tensors in the list are equal to each other
-    def all_tensors_equal(self, tensor_list):
-        first_tensor = tensor_list[0]
-        for tensor in tensor_list[1:]:
-            if not torch.equal(first_tensor, tensor):
-                return False
-        return True
-
-    def start_optimizer_process(self):
-        from multiprocessing import Pipe, get_context, Manager
-
-        ctx = get_context("spawn")
-        self.parent_conn, self.child_conn = Pipe()
-
-        manager = Manager()
-        self.shared_overlap_grad_map = manager.dict()
-        self.shared_stale_param_map = manager.dict()
-
-        for group_no, group in enumerate(self.optimizer.param_groups):
-            for param_idx, param in enumerate(group['params']):
-                param.data.share_memory_()
-                if not hasattr(param, 'stale_param'):
-                    param.stale_param = torch.zeros_like(param.data, dtype=param.dtype, device=param.device)
-                    param.stale_param.data.share_memory_()
-                key = (group_no, param_idx)
-                self.shared_stale_param_map[key] = param.stale_param
-                if param.overlap_grad is not None:
-                    param.overlap_grad[0].data.share_memory_()
-                    param.overlap_grad[1].data.share_memory_()
-                    key = (group_no, param_idx)
-                    self.shared_overlap_grad_map[key] = param.overlap_grad
-
-        param_groups_data = self.optimizer.param_groups
-        curr_rank = dist.get_rank()
-        total_rank = dist.get_world_size()
-
-        current_process = psutil.Process()
-        current_affinity = current_process.cpu_affinity()
-        all_affinities = [
-            torch.zeros(len(current_affinity),
-                        dtype=type(current_affinity[0]),
-                        device=get_accelerator().current_device_name()) for _ in range(total_rank)
-        ]
-        dist.all_gather(
-            all_affinities,
-            torch.tensor(current_affinity,
-                         dtype=type(current_affinity[0]),
-                         device=get_accelerator().current_device_name()))
-        # When affinity across all ranks are the same, the workers are not binded. Do a soft bind here
-        if self.all_tensors_equal(all_affinities):
-            num_phy_cores = psutil.cpu_count(logical=False)
-            available_phy_cores = [i for i in current_affinity if i < num_phy_cores]
-            num_available_phy_cores = len(available_phy_cores)
-            my_rank = curr_rank
-            my_size = total_rank
-            cores_per_rank = num_available_phy_cores // my_size
-            current_affinity = available_phy_cores[my_rank * cores_per_rank:(my_rank + 1) * cores_per_rank]
-        pt_num_cores = math.ceil(self.pt_reserved_cores_perc * len(current_affinity))
-        if pt_num_cores > 0 and pt_num_cores < len(current_affinity):
-            zf_affinity = current_affinity[pt_num_cores:]
-            pt_affinity = current_affinity[:pt_num_cores]
-        else:
-            zf_affinity = current_affinity
-            pt_affinity = current_affinity
-        self.process = ctx.Process(
-            target=zenflow_optimizer_process,
-            args=(self.child_conn, curr_rank, total_rank, param_groups_data, self.shared_overlap_grad_map,
-                  self.shared_stale_param_map, zf_affinity),
-        )
-        self.process.daemon = True
-        self.process.start()
-        current_process.cpu_affinity(pt_affinity)
-        os.environ['OMP_NUM_THREADS'] = str(len(pt_affinity))
-
-        msg = self.parent_conn.recv()
-        assert msg["type"] == "ready", "Optimizer process did not initialize correctly."
-
-        self.process_optimizer_established = True
-
     def wait_last_update_and_copy(self):
 
         if not hasattr(self, 'parent_conn'):
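The net effect of this file's changes: the worker-process management moves out of the Stage 1/2 optimizer into zenflow_utils so the new Stage 3 optimizer can reuse it, with a `zf_stage3` flag selecting the parameter layout and `start_optimizer_process` rebound to the module-level helper. A toy sketch of that sharing pattern, with hypothetical class names:

```python
# Toy illustration of the refactor pattern (hypothetical names): a single
# module-level helper serves two optimizer classes, dispatching on a flag,
# mirroring start_optimizer_process(zf_optimizer) in zenflow_utils.py.
def start_worker(opt):
    if opt.zf_stage3:
        print("wiring Stage 3 partitioned groups")
    else:
        print("wiring Stage 1/2 param groups")

class Stage12Optimizer:
    def __init__(self):
        self.zf_stage3 = False
        self.start_worker = lambda: start_worker(self)

class Stage3Optimizer:
    def __init__(self):
        self.zf_stage3 = True
        self.start_worker = lambda: start_worker(self)

Stage12Optimizer().start_worker()  # -> wiring Stage 1/2 param groups
Stage3Optimizer().start_worker()   # -> wiring Stage 3 partitioned groups
```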

deepspeed/runtime/zenflow/zenflow_utils.py

Lines changed: 149 additions & 0 deletions
@@ -3,7 +3,12 @@
 
 # DeepSpeed Team
 
+import os
+import math
 import torch
+import psutil
+from deepspeed import comm as dist
+from deepspeed.accelerator import get_accelerator
 
 
 def _flatten_dense_tensors(tensors):
@@ -40,3 +45,147 @@ def _unflatten_dense_tensors(flat, tensors):
     transposed_tensors = [t.transpose(0, 1) if t.dim() == 2 else t for t in tensors]
     unflat = torch._C._nn.unflatten_dense_tensors(flat, transposed_tensors)
     return [t.transpose(0, 1) if t.dim() == 2 else t for t in unflat]
+
+
+def disable_accelerator():
+    accelerator = get_accelerator()
+    accelerator.is_available = lambda: False
+    accelerator.device_count = lambda: 0
+    accelerator.current_device = lambda: -1
+    # Optionally mark it as initialized if needed
+    if hasattr(accelerator, "_initialized"):
+        accelerator._initialized = True
+
+
+def zenflow_optimizer_process(pipe, param_groups, shared_overlap_grad_map, shared_stale_param_map, zf_affinity):
+    disable_accelerator()
+
+    current_process = psutil.Process()
+    current_process.cpu_affinity(zf_affinity)
+    os.environ['OMP_NUM_THREADS'] = str(len(zf_affinity))
+
+    from deepspeed.ops.adam import ZenFlowCPUAdam
+    optimizer = ZenFlowCPUAdam(param_groups, overlap_step=True)
+
+    pipe.send({"type": "ready"})
+
+    # TODO: replace this with rpc
+
+    while True:
+        cmd = pipe.recv()
+        if cmd["type"] == "step":
+            now_state = cmd["now_state"]
+            micro_step = cmd["micro_step"]
+            group_infos = cmd["group_infos"]
+
+            for group_no, group_info in enumerate(group_infos):
+                original_param_groups = optimizer.param_groups
+                optimizer.param_groups = [original_param_groups[group_no]]
+                group = optimizer.param_groups[0]
+
+                for param_idx, param in enumerate(group["params"]):
+                    key = (group_no, param_idx)
+                    if key in shared_overlap_grad_map:
+                        param.overlap_grad = shared_overlap_grad_map[key]
+                    if key in shared_stale_param_map:
+                        param.stale_param = shared_stale_param_map[key]
+
+                optimizer.step(step_id=micro_step + 1, now_state=now_state, group_info=group_info)
+
+                optimizer.param_groups = original_param_groups
+
+            pipe.send({"type": "done"})
+        elif cmd["type"] == "exit":
+            break
+
+
+def all_tensors_equal(tensor_list):
+    first_tensor = tensor_list[0]
+    for tensor in tensor_list[1:]:
+        if not torch.equal(first_tensor, tensor):
+            return False
+    return True
+
+
+def start_optimizer_process(zf_optimizer):
+    from multiprocessing import Pipe, get_context, Manager
+
+    ctx = get_context("spawn")
+    zf_optimizer.parent_conn, zf_optimizer.child_conn = Pipe()
+
+    manager = Manager()
+    zf_optimizer.shared_overlap_grad_map = manager.dict()
+    zf_optimizer.shared_stale_param_map = manager.dict()
+
+    if zf_optimizer.zf_stage3:
+        params_iter = [((group_no, 0), param)
+                       for group_no, param in enumerate(zf_optimizer.fp32_partitioned_groups_flat)]
+    else:
+        params_iter = [((group_no, param_idx), param)
+                       for group_no, group in enumerate(zf_optimizer.optimizer.param_groups)
+                       for param_idx, param in enumerate(group["params"])]
+
+    for key, param in params_iter:
+        param.data.share_memory_()
+
+        if not hasattr(param, "stale_param"):
+            param.stale_param = torch.zeros_like(param.data, dtype=param.dtype, device=param.device)
+            param.stale_param.data.share_memory_()
+        zf_optimizer.shared_stale_param_map[key] = param.stale_param
+
+        if getattr(param, "overlap_grad", None) is not None:
+            param.overlap_grad[0].data.share_memory_()
+            param.overlap_grad[1].data.share_memory_()
+            zf_optimizer.shared_overlap_grad_map[key] = param.overlap_grad
+
+    param_groups_data = ([{
+        "params": [param]
+    } for param in zf_optimizer.fp32_partitioned_groups_flat]
+                         if zf_optimizer.zf_stage3 else zf_optimizer.optimizer.param_groups)
+
+    curr_rank = dist.get_rank()
+    total_rank = dist.get_world_size()
+
+    current_process = psutil.Process()
+    current_affinity = current_process.cpu_affinity()
+    all_affinities = [
+        torch.zeros(len(current_affinity),
+                    dtype=type(current_affinity[0]),
+                    device=get_accelerator().current_device_name()) for _ in range(total_rank)
+    ]
+    dist.all_gather(
+        all_affinities,
+        torch.tensor(current_affinity, dtype=type(current_affinity[0]),
+                     device=get_accelerator().current_device_name()))
+    # When affinity across all ranks are the same, the workers are not binded. Do a soft bind here
+    if all_tensors_equal(all_affinities):
+        num_phy_cores = psutil.cpu_count(logical=False)
+        available_phy_cores = [i for i in current_affinity if i < num_phy_cores]
+        num_available_phy_cores = len(available_phy_cores)
+        my_rank = curr_rank
+        my_size = total_rank
+        cores_per_rank = num_available_phy_cores // my_size
+        current_affinity = available_phy_cores[my_rank * cores_per_rank:(my_rank + 1) * cores_per_rank]
+    pt_num_cores = math.ceil(zf_optimizer.pt_reserved_cores_perc * len(current_affinity))
+    if pt_num_cores > 0 and pt_num_cores < len(current_affinity):
+        zf_affinity = current_affinity[pt_num_cores:]
+        pt_affinity = current_affinity[:pt_num_cores]
+    else:
+        zf_affinity = current_affinity
+        pt_affinity = current_affinity
+
+    zf_optimizer.process = ctx.Process(
+        target=zenflow_optimizer_process,
+        args=(zf_optimizer.child_conn, param_groups_data, zf_optimizer.shared_overlap_grad_map,
+              zf_optimizer.shared_stale_param_map, zf_affinity),
+    )
+    zf_optimizer.process.daemon = True
+    zf_optimizer.process.start()
+
+    current_process.cpu_affinity(pt_affinity)
+    os.environ['OMP_NUM_THREADS'] = str(len(pt_affinity))
+
+    msg = zf_optimizer.parent_conn.recv()
+    assert msg["type"] == "ready", "Optimizer process did not initialize correctly."
+
+    zf_optimizer.process_optimizer_established = True
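The worker protocol above is plain pipes plus shared-memory tensors: the parent shares its parameter storage, spawns a CPU-pinned worker, and exchanges ready/step/done/exit messages. A self-contained toy sketch of the same pattern (hypothetical names, not ZenFlow code):

```python
import torch
import torch.multiprocessing as mp  # installs tensor-sharing reductions

def worker(pipe, shared):
    # Toy stand-in for zenflow_optimizer_process: announce readiness, then
    # serve "step"/"exit" commands, mutating the shared tensor in place.
    pipe.send({"type": "ready"})
    while True:
        cmd = pipe.recv()
        if cmd["type"] == "step":
            shared.add_(1.0)            # parent sees this: shared memory
            pipe.send({"type": "done"})
        elif cmd["type"] == "exit":
            break

if __name__ == "__main__":
    ctx = mp.get_context("spawn")      # spawn, as in start_optimizer_process
    parent_conn, child_conn = ctx.Pipe()
    shared = torch.zeros(4)
    shared.share_memory_()             # same mechanism as param.data.share_memory_()
    p = ctx.Process(target=worker, args=(child_conn, shared), daemon=True)
    p.start()
    assert parent_conn.recv()["type"] == "ready"
    parent_conn.send({"type": "step"})
    assert parent_conn.recv()["type"] == "done"
    print(shared)                      # tensor([1., 1., 1., 1.])
    parent_conn.send({"type": "exit"})
    p.join()
```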
