Skip to content

Commit bce706c

Browse files
Cherry-pick #3399 for Mamba Uneven PP fix (#3544)
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent 21c1053 commit bce706c

File tree

1 file changed

+42
-4
lines changed

1 file changed

+42
-4
lines changed

megatron/core/ssm/mamba_block.py

Lines changed: 42 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -170,15 +170,53 @@ def __init__(
170170
)
171171

172172
def _select_layers_for_pipeline_parallel(self, layer_type_list):
    """Pick the slice of ``layer_type_list`` owned by this pipeline rank.

    Two placement schemes are supported:

    * Even split (both ``num_layers_in_first_pipeline_stage`` and
      ``num_layers_in_last_pipeline_stage`` are ``None``): every rank gets
      ``num_layers // pp_size`` consecutive layers.
    * Uneven split (at least one of the two is set): the first and/or last
      rank gets the explicitly configured number of layers, and the
      remaining layers are divided equally among the remaining ("middle")
      ranks.

    Both schemes use floor division and do not check divisibility here, so
    any remainder layers are not assigned to any rank by this method.

    Args:
        layer_type_list: Sequence with one entry per layer of the full
            model; it is only sliced, never inspected.

    Returns:
        A tuple ``(offset, selected_list)`` where ``offset`` is the index of
        this rank's first layer in ``layer_type_list`` and ``selected_list``
        is the corresponding slice.

    Raises:
        AssertionError: If ``virtual_pipeline_model_parallel_size`` is set.
    """
    assert self.config.virtual_pipeline_model_parallel_size is None, (
        "The Mamba hybrid model does not currently support "
        "virtual/interleaved pipeline parallelism"
    )

    pp_rank = self.pp_group.rank()
    pp_size = self.pp_group.size()

    num_layers_in_first = self.config.num_layers_in_first_pipeline_stage
    num_layers_in_last = self.config.num_layers_in_last_pipeline_stage

    if num_layers_in_first is not None or num_layers_in_last is not None:
        # Uneven pipeline parallelism: mirror the logic in
        # get_transformer_layer_offset so that MambaStack and
        # TransformerLayer agree on layer placement.
        first = 0 if num_layers_in_first is None else num_layers_in_first
        last = 0 if num_layers_in_last is None else num_layers_in_last
        middle_num_layers = self.config.num_layers - first - last

        # Every explicitly sized end stage removes one rank from the pool
        # of ranks that share the middle layers.
        middle_pipeline_stages = pp_size - sum(
            1 for x in (num_layers_in_first, num_layers_in_last) if x is not None
        )

        # Guard the division: with no middle ranks there is nothing to split.
        if middle_pipeline_stages > 0:
            layers_per_middle = middle_num_layers // middle_pipeline_stages
        else:
            layers_per_middle = 0

        is_first_stage = num_layers_in_first is not None and pp_rank == 0
        is_last_stage = num_layers_in_last is not None and pp_rank == pp_size - 1

        if is_first_stage:
            offset = 0
            num_layers_this_rank = first
        elif is_last_stage:
            offset = self.config.num_layers - last
            num_layers_this_rank = last
        else:
            # Rank 0 only counts as a middle rank when the first stage has
            # no explicit size; otherwise middle ranks start at rank 1.
            middle_rank = pp_rank if num_layers_in_first is None else pp_rank - 1
            offset = middle_rank * layers_per_middle + first
            num_layers_this_rank = layers_per_middle
    else:
        # Even pipeline parallelism: identical layer count on every rank.
        num_layers_per_pipeline_rank = self.config.num_layers // pp_size
        offset = pp_rank * num_layers_per_pipeline_rank
        num_layers_this_rank = num_layers_per_pipeline_rank

    selected_list = layer_type_list[offset : offset + num_layers_this_rank]

    return offset, selected_list
184222

0 commit comments

Comments
 (0)