@@ -170,15 +170,53 @@ def __init__(
170170 )
171171
def _select_layers_for_pipeline_parallel(self, layer_type_list):
    """Select the slice of ``layer_type_list`` owned by this pipeline rank.

    Two placement schemes are supported:

    * Even split (neither ``num_layers_in_first_pipeline_stage`` nor
      ``num_layers_in_last_pipeline_stage`` is set on the config): every
      rank receives ``num_layers // pp_size`` consecutive layers.
    * Uneven split (at least one of the two is set): the first and/or
      last rank receives the explicitly configured number of layers and
      the remaining layers are divided, using floor division, among the
      remaining ("middle") ranks.

    Args:
        layer_type_list: Sequence describing every layer of the full
            model, indexed by global layer number.

    Returns:
        A tuple ``(offset, selected_list)`` where ``offset`` is the global
        index of the first layer on this rank and ``selected_list`` is
        ``layer_type_list[offset : offset + <layers on this rank>]``.

    Raises:
        AssertionError: If ``virtual_pipeline_model_parallel_size`` is set
            on the config.
    """
    assert self.config.virtual_pipeline_model_parallel_size is None, (
        "The Mamba hybrid model does not currently support "
        "virtual/interleaved pipeline parallelism"
    )

    pp_rank = self.pp_group.rank()
    pp_size = self.pp_group.size()

    num_layers_in_first = self.config.num_layers_in_first_pipeline_stage
    num_layers_in_last = self.config.num_layers_in_last_pipeline_stage

    if num_layers_in_first is not None or num_layers_in_last is not None:
        # Uneven pipeline parallelism: mirror the logic in
        # get_transformer_layer_offset so that MambaStack and
        # TransformerLayer agree on layer placement.
        first = 0 if num_layers_in_first is None else num_layers_in_first
        last = 0 if num_layers_in_last is None else num_layers_in_last
        middle_num_layers = self.config.num_layers - first - last

        # Every stage with an explicit layer count is excluded from the
        # pool of stages that share the remaining layers.
        middle_pipeline_stages = pp_size - sum(
            1 for x in (num_layers_in_first, num_layers_in_last) if x is not None
        )

        if middle_pipeline_stages > 0:
            layers_per_middle = middle_num_layers // middle_pipeline_stages
        else:
            # All stages are explicitly sized; there is no middle stage.
            layers_per_middle = 0

        is_first_stage = num_layers_in_first is not None and pp_rank == 0
        is_last_stage = num_layers_in_last is not None and pp_rank == pp_size - 1

        if is_first_stage:
            offset = 0
            num_layers_this_rank = first
        elif is_last_stage:
            offset = self.config.num_layers - last
            num_layers_this_rank = last
        else:
            # Rank 0 only counts as a middle stage when the first stage
            # has no explicit size; otherwise middle stages start at rank 1.
            middle_rank = pp_rank if num_layers_in_first is None else pp_rank - 1
            offset = middle_rank * layers_per_middle + first
            num_layers_this_rank = layers_per_middle
    else:
        # Even split across all pipeline ranks.
        num_layers_per_pipeline_rank = self.config.num_layers // pp_size
        offset = pp_rank * num_layers_per_pipeline_rank
        num_layers_this_rank = num_layers_per_pipeline_rank

    selected_list = layer_type_list[offset : offset + num_layers_this_rank]

    return offset, selected_list
184222
0 commit comments