3030from flax .linen .spmd import LogicallyPartitioned
3131
3232from MaxText .common_types import Config , MODEL_MODE_TRAIN , EP_AS_CONTEXT , ShardMode
33+ # from MaxText import maxtext_utils
3334from MaxText .sharding import (
3435 maybe_shard_with_logical ,
3536 maybe_shard_with_name ,
@@ -199,12 +200,17 @@ def init_states(self, inputs):
199200
200201 def _init_bsw_from_weights (variables ):
201202 """Buffer space for two copies of weights."""
202- return jax .tree .map (lambda x : jnp .zeros_like (x [:2 ]), variables )
203+ # take idx 0 slice assuming num_layers_per_pipeline_stage=1
204+ return (
205+ jax .tree .map (lambda x : jnp .zeros_like (x [0 ]), variables ),
206+ jax .tree .map (lambda x : jnp .zeros_like (x [0 ]), variables ),
207+ )
203208
204209 if self .is_initializing ():
205210 bsw = None
206211 else :
207- bsw = _init_bsw_from_weights (self .layers .variables )
212+ variables = self ._remove_logically_partition (self .layers .variables )
213+ bsw = _init_bsw_from_weights (variables )
208214
209215 init_loop_state = {
210216 "state_io" : state_io ,
@@ -264,6 +270,31 @@ def select_state_or_input(first_stage_in, shift):
264270 stages_in = self ._maybe_shard_with_logical (stages_in , self .stages_in_logical )
265271 return stages_in
266272
273+ def shard_dim_by_stages (self , x , dim : int , physical_partition_spec : P | None , is_stage_weight : bool = False ):
274+ """Shards x using the provided partition_spec, but adds the "stage" mesh axis to the existing sharding at
275+ the specified dimension."""
276+ # placeholder = None if self.config.shard_mode == ShardMode.EXPLICIT else P.UNCONSTRAINED
277+ # if physical_partition_spec is None:
278+ # dims_mapping = [placeholder] * x.ndim
279+ # else:
280+ # physical_partition_spec = self._remove_fsdp_from_physical_partition_spec(physical_partition_spec)
281+ # dims_mapping = list(physical_partition_spec)
282+ # # If not a stage weight, we handle the repeat dimension offset
283+ # if not is_stage_weight:
284+ # dims_mapping = [placeholder] * (dim + 1) + dims_mapping[dim:] # inflat one dimension for num_repeats
285+ # dims_mapping[dim] = "stage"
286+ # dims_mapping = tuple(dims_mapping)
287+ # # We add reduced rule only when pspec is given for a stage weight
288+ # if physical_partition_spec and is_stage_weight and self.config.shard_mode == ShardMode.EXPLICIT:
289+ # batch_mesh_axis = ["data", "fsdp"]
290+ # reduced_mark = [mesh_axis for mesh_axis in batch_mesh_axis if self.mesh.shape[mesh_axis] > 1]
291+ # pspec = P(*dims_mapping, reduced=set(reduced_mark))
292+ # else:
293+ # pspec = P(*dims_mapping)
294+ # sharding = jax.sharding.NamedSharding(self.mesh, pspec)
295+ # return self._maybe_shard_with_name(x, sharding)
296+ return x
297+
267298 def get_microbatch_and_repeat_ids (self , loop_iteration ):
268299 """Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and
269300 non-circular"""
@@ -273,6 +304,14 @@ def get_microbatch_and_repeat_ids(self, loop_iteration):
273304 repeat_ids = microbatches_processed // self .config .num_pipeline_microbatches
274305 return microbatch_ids , repeat_ids
275306
def get_microbatch_and_repeat_ids_for_bsw(self, loop_iteration):
  """Per-stage microbatch and repeat ids used for selecting weights from the bsw buffer.

  Each stage's processed-microbatch counter is
  ``loop_iteration - forwarding_delay * stage_index``. The repeat id is taken
  from that counter WITHOUT clamping, so a stage whose counter is still negative
  gets a negative repeat id. The microbatch id clamps the counter at zero before
  the modulo.

  Args:
    loop_iteration: current pipeline loop iteration (scalar).

  Returns:
    ``(microbatch_ids, repeat_ids)``, each with one entry per stage.
  """
  num_microbatches = self.config.num_pipeline_microbatches
  stage_offsets = self.forwarding_delay * jnp.arange(self.num_stages)
  raw_processed = loop_iteration - stage_offsets
  # Floor division keeps negative counters negative (e.g. -1 // n == -1).
  repeat_ids = raw_processed // num_microbatches
  microbatch_ids = jnp.maximum(raw_processed, 0) % num_microbatches
  return microbatch_ids, repeat_ids
314+
276315 def vmap_parallel_gather (
277316 self , weights , physical_partition_spec , repeat_ids , repeat_dim_in_weights , stages_dim_in_weights
278317 ):
@@ -295,9 +334,18 @@ def _gather_one(x, repeat_id):
295334 return jnp .squeeze (jax .lax .dynamic_slice_in_dim (x , repeat_id , 1 , repeat_dim_in_weights ), repeat_dim_in_weights )
296335
297336 gathered_weights_stage_dim = 0
337+ repeat_ids = self .shard_dim_by_stages (repeat_ids , 0 , physical_partition_spec = None )
338+ # num_repeats x num_stages x *param_dim
339+ weights = self .shard_dim_by_stages (
340+ weights , stages_dim_in_weights , physical_partition_spec = physical_partition_spec , is_stage_weight = False
341+ )
298342 stage_weights = jax .vmap (_gather_one , in_axes = (stages_dim_in_weights , 0 ), out_axes = gathered_weights_stage_dim )(
299343 weights , repeat_ids
300344 )
345+ # num_stages x *param_dim
346+ stage_weights = self .shard_dim_by_stages (
347+ stage_weights , gathered_weights_stage_dim , physical_partition_spec = physical_partition_spec , is_stage_weight = True
348+ )
301349 return stage_weights
302350
303351 def vmap_gather (self , xs , ids , ids_dim ):
@@ -321,8 +369,9 @@ def _gather_one(x, i):
321369 replicated_sharding = NamedSharding (self .mesh , P ())
322370 return x .at [idx ].get (out_sharding = replicated_sharding )
323371
372+ ids = self .shard_dim_by_stages (ids , 0 , physical_partition_spec = None )
324373 outs = jax .vmap (_gather_one , in_axes = (None , 0 ), out_axes = ids_dim )(xs , ids )
325- return outs
374+ return self . shard_dim_by_stages ( outs , 0 , physical_partition_spec = None )
326375
327376 def get_new_loop_state (self , output , loop_state ):
328377 """
@@ -466,20 +515,53 @@ def get_current_stage_weights(self, pipeline_weights, bsw, loop_iteration, physi
466515 For non-circular pipelines, this simply returns all weights - every weight is used in every iteration. However
467516 for circular pipelines each stage grabs only the weights corresponding to the current repeat.
468517 """
518+ pipeline_weights = self ._remove_logically_partition (pipeline_weights )
469519 if self .config .num_pipeline_repeats > 1 :
470- return self .get_current_weights_from_bsw (bsw , loop_iteration , physical_partition_spec = physical_partition_spec )
471- else :
472- return pipeline_weights
520+ pipeline_weights = self .get_current_weights_from_bsw (
521+ bsw , loop_iteration , physical_partition_spec = physical_partition_spec
522+ )
523+ return pipeline_weights
473524
474- def get_current_weights_from_bsw (self , bsw , loop_iteration , physical_partition_spec = None ):
525+ def get_current_weights_from_bsw (self , bsw , loop_iteration , physical_partition_spec ):
475526 """Collect and gather weights from given bsw (buffer sliding window)"""
527+ bsw_pps = jax .tree .map (self ._remove_fsdp_from_physical_partition_spec , physical_partition_spec )
528+ _ , repeat_ids = self .get_microbatch_and_repeat_ids_for_bsw (loop_iteration )
529+ target_repeat_id = repeat_ids [0 ]
476530
477- def _get_bsw_idx (loop_iteration ):
478- _ , repeat_ids = self .get_microbatch_and_repeat_ids (loop_iteration )
479- bsw_ids = (repeat_ids == repeat_ids [0 ]).astype (
480- jnp .int32
481- ) # For early repeats this might return true when it should be false
482- return bsw_ids
531+ # path = ("params", "mlp", "wi_0", "kernel")
532+ # path = ("params", "weights")
533+
534+ # jax.debug.print(
535+ # "Iteration: {iter} | Global Target Repeat ID: {target} | Repeat_ids: {rids} | "
536+ # "BSW[0] per-stage means: {bsw0} | BSW[1] per-stage means: {bsw1}",
537+ # iter=loop_iteration, target=target_repeat_id, rids=repeat_ids,
538+ # bsw0=maxtext_utils.get_nested_value(bsw[0], path).mean(axis=(1, 2)),
539+ # bsw1=maxtext_utils.get_nested_value(bsw[1], path).mean(axis=(1, 2)),
540+ # )
541+
542+ @jax .shard_map (
543+ mesh = self .mesh ,
544+ in_specs = ((bsw_pps , bsw_pps ), P ("stage" )),
545+ out_specs = (bsw_pps ),
546+ check_vma = True ,
547+ )
548+ def select_weights_from_bsw (bsw , repeat_id ):
549+ weights = jax .tree .map (
550+ lambda x , y : jax .lax .select (repeat_id [0 ] == target_repeat_id , y , x ),
551+ bsw [0 ],
552+ bsw [1 ],
553+ )
554+ # jax.debug.print(
555+ # "Iteration: {iter} | "
556+ # "Selected weights mean for Stage {s} with repeat id {i}: {m}",
557+ # iter=loop_iteration,
558+ # s=jax.lax.axis_index("stage"),
559+ # m=maxtext_utils.get_nested_value(weights, path).mean(),
560+ # i=repeat_id[0],
561+ # )
562+ return weights
563+
564+ weights = select_weights_from_bsw (bsw , repeat_ids )
483565
484566 circular_metadata_params = {
485567 nn .PARTITION_NAME : "circular_repeats" ,
@@ -489,24 +571,10 @@ def _get_bsw_idx(loop_iteration):
489571 "optimizer_dims_mapping" : None ,
490572 }
491573 weights = meta .remove_axis (
492- bsw , 0 , circular_metadata_params
574+ weights , 0 , circular_metadata_params
493575 ) # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, only one circular
494576 # entry per stage.
495- weights = self ._remove_logically_partition (weights )
496577
497- def gather_weights_for_stages_in (w , spec = None ):
498- return self .vmap_parallel_gather (
499- w ,
500- repeat_ids = _get_bsw_idx (loop_iteration ),
501- repeat_dim_in_weights = 0 ,
502- stages_dim_in_weights = 1 ,
503- physical_partition_spec = spec ,
504- )
505-
506- if physical_partition_spec is None :
507- weights = jax .tree .map (gather_weights_for_stages_in , weights )
508- else :
509- weights = jax .tree .map (gather_weights_for_stages_in , weights , physical_partition_spec )
510578 return weights
511579
512580 @staticmethod
@@ -539,40 +607,50 @@ def find_fsdp(pspec):
539607
540608 return jax .tree .map (find_fsdp , physical_partition_spec )
541609
def bsw_all_gather_over_fsdp(self, weights, bsw, physical_partition_spec, loop_iteration):
  """Advance the two-slot bsw buffer by one step, all-gathering the next weights over fsdp.

  The weights for the repeat ids of ``loop_iteration + 1`` are gathered per
  stage, all-gathered along the "fsdp" mesh axis inside a shard_map, and placed
  in slot 1. The previous slot 1 moves to slot 0.

  Args:
    weights: pytree of pipeline weights; dim 0 is treated as the repeat dim and
      dim 1 as the stage dim when gathering.
    bsw: 2-tuple ``(slot0, slot1)`` of weight pytrees.
    physical_partition_spec: pytree of PartitionSpecs matching ``weights``, or None.
    loop_iteration: current pipeline loop iteration.

  Returns:
    The new ``(slot0, slot1)`` tuple: ``(old slot1, freshly gathered weights)``.
  """
  bsw_pps = self._generate_bsw_pps_from_pps(physical_partition_spec)
  # Specs for the gathered per-stage weights: the first spec entry is dropped.
  repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]), physical_partition_spec)
  fsdp_idx = self.get_fsdp_index_pytree(physical_partition_spec)

  # Look one iteration ahead: the buffer is filled for the NEXT iteration.
  _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration + 1)

  # `spec` defaults to None so the spec-less tree.map below can call this with
  # a single leaf argument instead of raising TypeError.
  def gather_weights_for_stages_in(w, spec=None):
    return self.vmap_parallel_gather(
        w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec
    )

  if physical_partition_spec is None:
    repeat_weights = jax.tree.map(gather_weights_for_stages_in, weights)
  else:
    repeat_weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec)

  circular_metadata_params = {
      nn.PARTITION_NAME: "circular_repeats",
      "sub_weight_split_dims_mapping": (None,),
      "is_initializing": self.is_initializing(),
      "x_times": self.config.num_pipeline_repeats,
      "optimizer_dims_mapping": None,
  }
  # Drop the circular-repeat metadata axis from the gathered weights.
  repeat_weights = meta.remove_axis(repeat_weights, 0, circular_metadata_params)

  @jax.shard_map(
      mesh=self.mesh,
      in_specs=(repeat_weights_pps, (bsw_pps, bsw_pps), None),
      out_specs=(bsw_pps, bsw_pps),
      check_vma=True,
  )
  def _all_gather_inner(sharded_weights, cur_bsw, fsdp_idx):
    def _all_gather_invariant(x, i):
      if i >= 0:
        # fsdp_idx was computed against the full spec, but these operands use the spec with its
        # first entry dropped, so the axis shifts down by one.
        # NOTE(review): i == 0 would yield axis=-1 - confirm fsdp never sits on the dropped entry.
        return all_gather_invariant(x, axis_name="fsdp", axis=i - 1, tiled=True)
      return x

    new_variables = jax.tree.map(_all_gather_invariant, sharded_weights, fsdp_idx)

    # Slide the window: old slot 1 becomes slot 0, new weights fill slot 1.
    return (cur_bsw[1], new_variables)

  return _all_gather_inner(repeat_weights, bsw, fsdp_idx)
576654
577655 def get_vmap_func_for_init (self ):
578656 """This vmap func is used to initialize the weights only on init."""
@@ -643,7 +721,7 @@ def run_one_iteration(
643721 deterministic ,
644722 model_mode ,
645723 decoder_layer_instance ,
646- logical_partition_spec = None ,
724+ logical_partition_spec ,
647725 ):
648726 """Run one loop iteration - gets weights and inputs for each stage, run the stages in parallel,
649727 and update the loop state."""
@@ -806,6 +884,13 @@ def _remove_fsdp_from_physical_partition_spec(pps):
806884 return P (* new_spec )
807885 return pps
808886
887+ def _generate_bsw_pps_from_pps (self , physical_partition_spec ):
888+ """Create bsw physical partition spec from weight physical partition spec."""
889+ return jax .tree .map (
890+ lambda pps : P (* self ._remove_fsdp_from_physical_partition_spec (pps )[1 :]),
891+ physical_partition_spec ,
892+ )
893+
809894 @nn .compact
810895 def __call__ (
811896 self ,
@@ -961,8 +1046,9 @@ def run_iteration_scannable(model, loop_state):
9611046 )
9621047
9631048 def run_one_repeat_scannable (model , loop_state ):
1049+ weights = model ._remove_logically_partition (model .layers .variables ) # pylint: disable=protected-access
9641050 loop_state ["bsw" ] = model .bsw_all_gather_over_fsdp (
965- loop_state ["bsw" ], physical_partition_spec , loop_state ["loop_iteration" ]
1051+ weights , loop_state ["bsw" ], physical_partition_spec , loop_state ["loop_iteration" ]
9661052 )
9671053
9681054 if model .config .scan_pipeline_iterations :
@@ -992,65 +1078,6 @@ def run_one_repeat_scannable(model, loop_state):
9921078 policy = self .get_pipeline_remat_policy (),
9931079 )
9941080
995- def run_real_repeats (model , loop_state ):
996- if self .config .scan_pipeline_repeats :
997- run_repeats_scanned = nn .scan (
998- run_one_repeat_scannable ,
999- variable_axes = {
1000- "summaries" : 0 ,
1001- "aux_loss" : 0 ,
1002- "intermediates" : 0 ,
1003- "hyper_params" : 0 ,
1004- },
1005- variable_broadcast = variable_broadcast ,
1006- variable_carry = variable_carry ,
1007- split_rngs = {"random" : True },
1008- length = model .config .num_pipeline_repeats ,
1009- )
1010- loop_state , _ = run_repeats_scanned (model , loop_state )
1011- else :
1012- for _ in range (model .config .num_pipeline_repeats ): # remat and scan outer loop
1013- loop_state , _ = run_one_repeat_scannable (model , loop_state )
1014- return loop_state
1015-
1016- run_real_repeats = nn .remat (
1017- run_real_repeats ,
1018- prevent_cse = not self .config .scan_pipeline_iterations ,
1019- policy = self .get_pipeline_remat_policy (),
1020- )
1021-
1022- def run_bubble_iterations_scannable (model , loop_state ):
1023- loop_state ["bsw" ] = model .bsw_all_gather_over_fsdp (
1024- loop_state ["bsw" ], physical_partition_spec , loop_state ["loop_iteration" ]
1025- )
1026-
1027- if model .config .scan_pipeline_iterations :
1028- run_one_repeat_scanned = nn .scan (
1029- run_iteration_scannable ,
1030- variable_axes = {
1031- "summaries" : 0 ,
1032- "aux_loss" : 0 ,
1033- "intermediates" : 0 ,
1034- "hyper_params" : 0 ,
1035- },
1036- variable_broadcast = variable_broadcast ,
1037- variable_carry = variable_carry ,
1038- # Dropout/aqt keys will be split for each iteration.
1039- split_rngs = {"random" : True },
1040- length = bubble_iterations ,
1041- )
1042- loop_state , _ = run_one_repeat_scanned (model , loop_state )
1043- else :
1044- for _ in range (model .config .num_pipeline_microbatches ):
1045- loop_state , _ = run_iteration_scannable (model , loop_state )
1046- return loop_state , None
1047-
1048- run_bubble_iterations_scannable = nn .remat (
1049- run_bubble_iterations_scannable ,
1050- prevent_cse = not self .config .scan_pipeline_iterations ,
1051- policy = self .get_pipeline_remat_policy (),
1052- )
1053-
10541081 def run_all_iterations (model , loop_state ):
10551082 if self .config .scan_pipeline_repeats :
10561083 run_repeats_scanned = nn .scan (
@@ -1068,7 +1095,7 @@ def run_all_iterations(model, loop_state):
10681095 )
10691096
10701097 run_bubbles_scanned = nn .scan (
1071- run_bubble_iterations_scannable ,
1098+ run_iteration_scannable ,
10721099 variable_axes = {
10731100 "summaries" : 0 ,
10741101 "aux_loss" : 0 ,
@@ -1078,9 +1105,10 @@ def run_all_iterations(model, loop_state):
10781105 variable_broadcast = variable_broadcast ,
10791106 variable_carry = variable_carry ,
10801107 split_rngs = {"random" : True },
1081- length = model . config . num_pipeline_repeats ,
1108+ length = bubble_iterations ,
10821109 )
10831110 loop_state , _ = run_repeats_scanned (model , loop_state )
1111+ loop_state ["bsw" ] = (loop_state ["bsw" ][1 ], jax .tree .map (jnp .zeros_like , loop_state ["bsw" ][1 ]))
10841112 loop_state , _ = run_bubbles_scanned (model , loop_state )
10851113 else :
10861114 for _ in range (model .config .num_pipeline_repeats ): # remat and scan outer loop
0 commit comments