Skip to content

Commit 5d6c552

Browse files
committed
working all gather insertion
1 parent 37c34f7 commit 5d6c552

File tree

2 files changed

+139
-111
lines changed

2 files changed

+139
-111
lines changed

src/MaxText/configs/base.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -839,7 +839,7 @@ prometheus_port: 0
839839
enable_jax_profiler: False
840840
jax_profiler_port: 9999
841841

842-
log_config: True # Prints the config (after defaults have been set by pyconfig logic)
842+
log_config: False # Prints the config (after defaults have been set by pyconfig logic)
843843
debug_sharding: False # Prints model weights sharding info
844844

845845
# Checkpoint Structured logging

src/MaxText/layers/pipeline.py

Lines changed: 138 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from flax.linen.spmd import LogicallyPartitioned
3131

3232
from MaxText.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT, ShardMode
33+
# from MaxText import maxtext_utils
3334
from MaxText.sharding import (
3435
maybe_shard_with_logical,
3536
maybe_shard_with_name,
@@ -199,12 +200,17 @@ def init_states(self, inputs):
199200

200201
def _init_bsw_from_weights(variables):
201202
"""Buffer space for two copies of weights."""
202-
return jax.tree.map(lambda x: jnp.zeros_like(x[:2]), variables)
203+
# take idx 0 slice assuming num_layers_per_pipeline_stage=1
204+
return (
205+
jax.tree.map(lambda x: jnp.zeros_like(x[0]), variables),
206+
jax.tree.map(lambda x: jnp.zeros_like(x[0]), variables),
207+
)
203208

204209
if self.is_initializing():
205210
bsw = None
206211
else:
207-
bsw = _init_bsw_from_weights(self.layers.variables)
212+
variables = self._remove_logically_partition(self.layers.variables)
213+
bsw = _init_bsw_from_weights(variables)
208214

209215
init_loop_state = {
210216
"state_io": state_io,
@@ -264,6 +270,31 @@ def select_state_or_input(first_stage_in, shift):
264270
stages_in = self._maybe_shard_with_logical(stages_in, self.stages_in_logical)
265271
return stages_in
266272

273+
def shard_dim_by_stages(self, x, dim: int, physical_partition_spec: P | None, is_stage_weight: bool = False):
274+
"""Shards x using the provided partition_spec, but adds the "stage" mesh axis to the existing sharding at
275+
the specified dimension."""
276+
# placeholder = None if self.config.shard_mode == ShardMode.EXPLICIT else P.UNCONSTRAINED
277+
# if physical_partition_spec is None:
278+
# dims_mapping = [placeholder] * x.ndim
279+
# else:
280+
# physical_partition_spec = self._remove_fsdp_from_physical_partition_spec(physical_partition_spec)
281+
# dims_mapping = list(physical_partition_spec)
282+
# # If not a stage weight, we handle the repeat dimension offset
283+
# if not is_stage_weight:
284+
# dims_mapping = [placeholder] * (dim + 1) + dims_mapping[dim:] # inflat one dimension for num_repeats
285+
# dims_mapping[dim] = "stage"
286+
# dims_mapping = tuple(dims_mapping)
287+
# # We add reduced rule only when pspec is given for a stage weight
288+
# if physical_partition_spec and is_stage_weight and self.config.shard_mode == ShardMode.EXPLICIT:
289+
# batch_mesh_axis = ["data", "fsdp"]
290+
# reduced_mark = [mesh_axis for mesh_axis in batch_mesh_axis if self.mesh.shape[mesh_axis] > 1]
291+
# pspec = P(*dims_mapping, reduced=set(reduced_mark))
292+
# else:
293+
# pspec = P(*dims_mapping)
294+
# sharding = jax.sharding.NamedSharding(self.mesh, pspec)
295+
# return self._maybe_shard_with_name(x, sharding)
296+
return x
297+
267298
def get_microbatch_and_repeat_ids(self, loop_iteration):
268299
"""Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and
269300
non-circular"""
@@ -273,6 +304,14 @@ def get_microbatch_and_repeat_ids(self, loop_iteration):
273304
repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches
274305
return microbatch_ids, repeat_ids
275306

307+
def get_microbatch_and_repeat_ids_for_bsw(self, loop_iteration):
308+
"""Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and
309+
non-circular"""
310+
raw_processed = loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages)
311+
repeat_ids = raw_processed // self.config.num_pipeline_microbatches
312+
microbatch_ids = jnp.maximum(raw_processed, 0) % self.config.num_pipeline_microbatches
313+
return microbatch_ids, repeat_ids
314+
276315
def vmap_parallel_gather(
277316
self, weights, physical_partition_spec, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights
278317
):
@@ -295,9 +334,18 @@ def _gather_one(x, repeat_id):
295334
return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, repeat_dim_in_weights), repeat_dim_in_weights)
296335

297336
gathered_weights_stage_dim = 0
337+
repeat_ids = self.shard_dim_by_stages(repeat_ids, 0, physical_partition_spec=None)
338+
# num_repeats x num_stages x *param_dim
339+
weights = self.shard_dim_by_stages(
340+
weights, stages_dim_in_weights, physical_partition_spec=physical_partition_spec, is_stage_weight=False
341+
)
298342
stage_weights = jax.vmap(_gather_one, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim)(
299343
weights, repeat_ids
300344
)
345+
# num_stages x *param_dim
346+
stage_weights = self.shard_dim_by_stages(
347+
stage_weights, gathered_weights_stage_dim, physical_partition_spec=physical_partition_spec, is_stage_weight=True
348+
)
301349
return stage_weights
302350

303351
def vmap_gather(self, xs, ids, ids_dim):
@@ -321,8 +369,9 @@ def _gather_one(x, i):
321369
replicated_sharding = NamedSharding(self.mesh, P())
322370
return x.at[idx].get(out_sharding=replicated_sharding)
323371

372+
ids = self.shard_dim_by_stages(ids, 0, physical_partition_spec=None)
324373
outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)
325-
return outs
374+
return self.shard_dim_by_stages(outs, 0, physical_partition_spec=None)
326375

327376
def get_new_loop_state(self, output, loop_state):
328377
"""
@@ -466,20 +515,53 @@ def get_current_stage_weights(self, pipeline_weights, bsw, loop_iteration, physi
466515
For non-circular pipelines, this simply returns all weights - every weight is used in every iteraiton. However
467516
for circular pipelines each stage grabs only the weights corresponding to the current repeat.
468517
"""
518+
pipeline_weights = self._remove_logically_partition(pipeline_weights)
469519
if self.config.num_pipeline_repeats > 1:
470-
return self.get_current_weights_from_bsw(bsw, loop_iteration, physical_partition_spec=physical_partition_spec)
471-
else:
472-
return pipeline_weights
520+
pipeline_weights = self.get_current_weights_from_bsw(
521+
bsw, loop_iteration, physical_partition_spec=physical_partition_spec
522+
)
523+
return pipeline_weights
473524

474-
def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_spec=None):
525+
def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_spec):
475526
"""Collect and gather weights from given bsw (buffer sliding window)"""
527+
bsw_pps = jax.tree.map(self._remove_fsdp_from_physical_partition_spec, physical_partition_spec)
528+
_, repeat_ids = self.get_microbatch_and_repeat_ids_for_bsw(loop_iteration)
529+
target_repeat_id = repeat_ids[0]
476530

477-
def _get_bsw_idx(loop_iteration):
478-
_, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration)
479-
bsw_ids = (repeat_ids == repeat_ids[0]).astype(
480-
jnp.int32
481-
) # For early repeats this might return true when it should be false
482-
return bsw_ids
531+
# path = ("params", "mlp", "wi_0", "kernel")
532+
# path = ("params", "weights")
533+
534+
# jax.debug.print(
535+
# "Iteration: {iter} | Global Target Repeat ID: {target} | Repeat_ids: {rids} | "
536+
# "BSW[0] per-stage means: {bsw0} | BSW[1] per-stage means: {bsw1}",
537+
# iter=loop_iteration, target=target_repeat_id, rids=repeat_ids,
538+
# bsw0=maxtext_utils.get_nested_value(bsw[0], path).mean(axis=(1, 2)),
539+
# bsw1=maxtext_utils.get_nested_value(bsw[1], path).mean(axis=(1, 2)),
540+
# )
541+
542+
@jax.shard_map(
543+
mesh=self.mesh,
544+
in_specs=((bsw_pps, bsw_pps), P("stage")),
545+
out_specs=(bsw_pps),
546+
check_vma=True,
547+
)
548+
def select_weights_from_bsw(bsw, repeat_id):
549+
weights = jax.tree.map(
550+
lambda x, y: jax.lax.select(repeat_id[0] == target_repeat_id, y, x),
551+
bsw[0],
552+
bsw[1],
553+
)
554+
# jax.debug.print(
555+
# "Iteration: {iter} | "
556+
# "Selected weights mean for Stage {s} with repeat id {i}: {m}",
557+
# iter=loop_iteration,
558+
# s=jax.lax.axis_index("stage"),
559+
# m=maxtext_utils.get_nested_value(weights, path).mean(),
560+
# i=repeat_id[0],
561+
# )
562+
return weights
563+
564+
weights = select_weights_from_bsw(bsw, repeat_ids)
483565

484566
circular_metadata_params = {
485567
nn.PARTITION_NAME: "circular_repeats",
@@ -489,24 +571,10 @@ def _get_bsw_idx(loop_iteration):
489571
"optimizer_dims_mapping": None,
490572
}
491573
weights = meta.remove_axis(
492-
bsw, 0, circular_metadata_params
574+
weights, 0, circular_metadata_params
493575
) # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, only one circular
494576
# entry per stage.
495-
weights = self._remove_logically_partition(weights)
496577

497-
def gather_weights_for_stages_in(w, spec=None):
498-
return self.vmap_parallel_gather(
499-
w,
500-
repeat_ids=_get_bsw_idx(loop_iteration),
501-
repeat_dim_in_weights=0,
502-
stages_dim_in_weights=1,
503-
physical_partition_spec=spec,
504-
)
505-
506-
if physical_partition_spec is None:
507-
weights = jax.tree.map(gather_weights_for_stages_in, weights)
508-
else:
509-
weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec)
510578
return weights
511579

512580
@staticmethod
@@ -539,40 +607,50 @@ def find_fsdp(pspec):
539607

540608
return jax.tree.map(find_fsdp, physical_partition_spec)
541609

542-
def bsw_all_gather_over_fsdp(self, bsw, physical_partition_spec, loop_iteration):
610+
def bsw_all_gather_over_fsdp(self, weights, bsw, physical_partition_spec, loop_iteration):
543611
"""All gather bsw over fsdp mesh axis using shardmap."""
544-
pps_no_fsdp = jax.tree.map(self._remove_fsdp_from_physical_partition_spec, physical_partition_spec)
612+
bsw_pps = self._generate_bsw_pps_from_pps(physical_partition_spec)
613+
repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]), physical_partition_spec)
545614
fsdp_idx = self.get_fsdp_index_pytree(physical_partition_spec)
546615

547616
_, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration + 1)
548617

618+
def gather_weights_for_stages_in(w, spec):
619+
return self.vmap_parallel_gather(
620+
w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec
621+
)
622+
623+
if physical_partition_spec is None:
624+
repeat_weights = jax.tree.map(gather_weights_for_stages_in, weights)
625+
else:
626+
repeat_weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec)
627+
628+
circular_metadata_params = {
629+
nn.PARTITION_NAME: "circular_repeats",
630+
"sub_weight_split_dims_mapping": (None,),
631+
"is_initializing": self.is_initializing(),
632+
"x_times": self.config.num_pipeline_repeats,
633+
"optimizer_dims_mapping": None,
634+
}
635+
repeat_weights = meta.remove_axis(repeat_weights, 0, circular_metadata_params)
636+
549637
@jax.shard_map(
550638
mesh=self.mesh,
551-
in_specs=(physical_partition_spec, pps_no_fsdp, None, None),
552-
out_specs=pps_no_fsdp,
639+
in_specs=(repeat_weights_pps, (bsw_pps, bsw_pps), None),
640+
out_specs=(bsw_pps, bsw_pps),
553641
check_vma=True,
554642
)
555-
def _all_gather_inner(variables, cur_bsw, repeat_idx, fsdp_idx):
556-
new_variables = jax.tree.map(
557-
lambda x: jax.lax.dynamic_slice_in_dim(x, repeat_idx, 1),
558-
variables,
559-
)
560-
643+
def _all_gather_inner(sharded_weights, cur_bsw, fsdp_idx):
561644
def _all_gather_invariant(x, i):
562645
if i >= 0:
563-
return all_gather_invariant(x, axis_name="fsdp", axis=i, tiled=True)
646+
return all_gather_invariant(x, axis_name="fsdp", axis=i - 1, tiled=True)
564647
return x
565648

566-
new_variables = jax.tree.map(_all_gather_invariant, new_variables, fsdp_idx)
567-
568-
def shift_and_insert(bsw_leaf, new_leaf):
569-
updated_bsw = bsw_leaf.at[0].set(bsw_leaf[1])
570-
updated_bsw = updated_bsw.at[1].set(jnp.squeeze(new_leaf, axis=0))
571-
return updated_bsw
649+
new_variables = jax.tree.map(_all_gather_invariant, sharded_weights, fsdp_idx)
572650

573-
return jax.tree.map(shift_and_insert, cur_bsw, new_variables)
651+
return (cur_bsw[1], new_variables)
574652

575-
return _all_gather_inner(self.layers.variables, bsw, repeat_ids[0], fsdp_idx)
653+
return _all_gather_inner(repeat_weights, bsw, fsdp_idx)
576654

577655
def get_vmap_func_for_init(self):
578656
"""This vmap func is used to initialize the weights only on init."""
@@ -643,7 +721,7 @@ def run_one_iteration(
643721
deterministic,
644722
model_mode,
645723
decoder_layer_instance,
646-
logical_partition_spec=None,
724+
logical_partition_spec,
647725
):
648726
"""Run one loop iteration - gets weights and inputs for each stage, run the stages in parallel,
649727
and update the loop state."""
@@ -806,6 +884,13 @@ def _remove_fsdp_from_physical_partition_spec(pps):
806884
return P(*new_spec)
807885
return pps
808886

887+
def _generate_bsw_pps_from_pps(self, physical_partition_spec):
888+
"""Create bsw physical partition spec from weight physical partition spec."""
889+
return jax.tree.map(
890+
lambda pps: P(*self._remove_fsdp_from_physical_partition_spec(pps)[1:]),
891+
physical_partition_spec,
892+
)
893+
809894
@nn.compact
810895
def __call__(
811896
self,
@@ -961,8 +1046,9 @@ def run_iteration_scannable(model, loop_state):
9611046
)
9621047

9631048
def run_one_repeat_scannable(model, loop_state):
1049+
weights = model._remove_logically_partition(model.layers.variables) # pylint: disable=protected-access
9641050
loop_state["bsw"] = model.bsw_all_gather_over_fsdp(
965-
loop_state["bsw"], physical_partition_spec, loop_state["loop_iteration"]
1051+
weights, loop_state["bsw"], physical_partition_spec, loop_state["loop_iteration"]
9661052
)
9671053

9681054
if model.config.scan_pipeline_iterations:
@@ -992,65 +1078,6 @@ def run_one_repeat_scannable(model, loop_state):
9921078
policy=self.get_pipeline_remat_policy(),
9931079
)
9941080

995-
def run_real_repeats(model, loop_state):
996-
if self.config.scan_pipeline_repeats:
997-
run_repeats_scanned = nn.scan(
998-
run_one_repeat_scannable,
999-
variable_axes={
1000-
"summaries": 0,
1001-
"aux_loss": 0,
1002-
"intermediates": 0,
1003-
"hyper_params": 0,
1004-
},
1005-
variable_broadcast=variable_broadcast,
1006-
variable_carry=variable_carry,
1007-
split_rngs={"random": True},
1008-
length=model.config.num_pipeline_repeats,
1009-
)
1010-
loop_state, _ = run_repeats_scanned(model, loop_state)
1011-
else:
1012-
for _ in range(model.config.num_pipeline_repeats): # remat and scan outer loop
1013-
loop_state, _ = run_one_repeat_scannable(model, loop_state)
1014-
return loop_state
1015-
1016-
run_real_repeats = nn.remat(
1017-
run_real_repeats,
1018-
prevent_cse=not self.config.scan_pipeline_iterations,
1019-
policy=self.get_pipeline_remat_policy(),
1020-
)
1021-
1022-
def run_bubble_iterations_scannable(model, loop_state):
1023-
loop_state["bsw"] = model.bsw_all_gather_over_fsdp(
1024-
loop_state["bsw"], physical_partition_spec, loop_state["loop_iteration"]
1025-
)
1026-
1027-
if model.config.scan_pipeline_iterations:
1028-
run_one_repeat_scanned = nn.scan(
1029-
run_iteration_scannable,
1030-
variable_axes={
1031-
"summaries": 0,
1032-
"aux_loss": 0,
1033-
"intermediates": 0,
1034-
"hyper_params": 0,
1035-
},
1036-
variable_broadcast=variable_broadcast,
1037-
variable_carry=variable_carry,
1038-
# Dropout/aqt keys will be split for each iteration.
1039-
split_rngs={"random": True},
1040-
length=bubble_iterations,
1041-
)
1042-
loop_state, _ = run_one_repeat_scanned(model, loop_state)
1043-
else:
1044-
for _ in range(model.config.num_pipeline_microbatches):
1045-
loop_state, _ = run_iteration_scannable(model, loop_state)
1046-
return loop_state, None
1047-
1048-
run_bubble_iterations_scannable = nn.remat(
1049-
run_bubble_iterations_scannable,
1050-
prevent_cse=not self.config.scan_pipeline_iterations,
1051-
policy=self.get_pipeline_remat_policy(),
1052-
)
1053-
10541081
def run_all_iterations(model, loop_state):
10551082
if self.config.scan_pipeline_repeats:
10561083
run_repeats_scanned = nn.scan(
@@ -1068,7 +1095,7 @@ def run_all_iterations(model, loop_state):
10681095
)
10691096

10701097
run_bubbles_scanned = nn.scan(
1071-
run_bubble_iterations_scannable,
1098+
run_iteration_scannable,
10721099
variable_axes={
10731100
"summaries": 0,
10741101
"aux_loss": 0,
@@ -1078,9 +1105,10 @@ def run_all_iterations(model, loop_state):
10781105
variable_broadcast=variable_broadcast,
10791106
variable_carry=variable_carry,
10801107
split_rngs={"random": True},
1081-
length=model.config.num_pipeline_repeats,
1108+
length=bubble_iterations,
10821109
)
10831110
loop_state, _ = run_repeats_scanned(model, loop_state)
1111+
loop_state["bsw"] = (loop_state["bsw"][1], jax.tree.map(jnp.zeros_like, loop_state["bsw"][1]))
10841112
loop_state, _ = run_bubbles_scanned(model, loop_state)
10851113
else:
10861114
for _ in range(model.config.num_pipeline_repeats): # remat and scan outer loop

0 commit comments

Comments
 (0)