Commit 3e87edf

diloco utils

committed · 1 parent dfa57f9 · commit 3e87edf

File tree: 9 files changed, +550 -1 lines changed

dependencies/requirements/base_requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ array-record
 cloud-accelerator-diagnostics
 cloud-tpu-diagnostics
 datasets
+drjax
 flax
 gcsfs
 google-api-python-client

dependencies/requirements/generated_requirements/cuda12-requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ dill>=0.4.0
 distlib>=0.4.0
 dm-tree>=0.1.9
 docstring-parser>=0.17.0
+drjax>=0.1.4
 editdistance>=0.8.1
 einops>=0.8.1
 einshape>=1.0

dependencies/requirements/generated_requirements/tpu-requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ dill>=0.4.0
 distlib>=0.4.0
 dm-tree>=0.1.9
 docstring-parser>=0.17.0
+drjax>=0.1.4
 editdistance>=0.8.1
 einops>=0.8.1
 einshape>=1.0

dependencies/requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ array-record
 cloud-accelerator-diagnostics
 cloud-tpu-diagnostics
 datasets
+drjax>=0.1.4
 flax
 gcsfs
 google-api-python-client

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+# Requirements for Building the MaxText Docker Image
+# These requirements are additional to the dependencies present in the JAX AI base image.
+datasets @ https://github.com/huggingface/datasets/archive/6790e138c00b87a1ddc72184f89e7814cf784360.zip
+drjax>=0.1.4
+flax>=0.11.0
+google-api-python-client
+google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
+grain[parquet]>=0.2.13
+jaxtyping
+jsonlines
+mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
+omegaconf
+orbax-checkpoint>=0.11.22
+pathwaysutils>=0.1.1
+pillow>=11.1.0
+pre-commit
+protobuf>=5.29.5
+pyink
+pylint
+pytest
+pytype
+qwix
+sentencepiece>=0.2.0
+tensorflow-datasets
+tensorflow-text>=2.17.0
+tiktoken
+tokamax>=0.0.4
+transformers

src/MaxText/configs/base.yml

Lines changed: 9 additions & 1 deletion
@@ -383,7 +383,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess'
 
 # Parallelism
 shard_mode: "auto" # can be either auto or explicit
-mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
+mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
 logical_axis_rules: [
   ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
@@ -460,6 +460,7 @@ logical_axis_rules: [
   ['paged_kv_head_dim_size', []],
   ['dense_layers', []],
   ['moe_layers', []],
+  ['diloco', 'diloco'],
 ]
 # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
 data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
@@ -472,6 +473,7 @@ sharding_tolerance: 0.02
 # value to auto-shard based on available slices and devices.
 # By default, product of the DCN axes should equal number of slices
 # and product of the ICI axes should equal number of devices per slice.
+dcn_diloco_parallelism: 1
 dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
 dcn_fsdp_parallelism: 1
 dcn_fsdp_transpose_parallelism: 1
@@ -484,6 +486,7 @@ dcn_tensor_sequence_parallelism: 1 # never recommended
 dcn_pipeline_parallelism: 1
 dcn_expert_parallelism: 1
 dcn_autoregressive_parallelism: 1 # never recommended
+ici_diloco_parallelism: 1
 ici_data_parallelism: 1
 ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
 ici_fsdp_transpose_parallelism: 1
@@ -696,6 +699,11 @@ enable_data_shuffling: True
 data_shuffle_seed: 0
 init_weights_seed: 0
 
+# DiLoCo params.
+diloco_sync_period: 36
+diloco_outer_lr: 0.3
+diloco_outer_momentum: 0.9
+
 # You may disable clipping by setting gradient_clipping_threshold to zero.
 gradient_clipping_threshold: 1.0
 
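
With both DiLoCo knobs left at 1, the new axis is inert. Below is a minimal sketch of what a non-trivial setting produces (not part of this commit; the device count and axis subset are made up for illustration):

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # simulate 8 devices on CPU

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# 'diloco' leads the mesh, mirroring its first position in mesh_axes above:
# 2 DiLoCo replicas x 4-way data parallelism over the 8 simulated devices.
mesh = Mesh(np.asarray(jax.devices()).reshape(2, 4), axis_names=("diloco", "data"))

# The new ['diloco', 'diloco'] logical-axis rule maps a leading replica
# dimension of the batch onto the 'diloco' mesh axis.
batch = jax.device_put(
    np.zeros((2, 16, 128)),  # (replicas, per-replica batch, features)
    NamedSharding(mesh, PartitionSpec("diloco", "data")),
)
print(batch.sharding.spec)  # PartitionSpec('diloco', 'data')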

src/MaxText/configs/types.py

Lines changed: 18 additions & 0 deletions
@@ -733,6 +733,7 @@ class LayoutAndSharding(BaseModel):
 class DcnParallelism(BaseModel):
   """Parallelism dimensions across the DCN (Data Center Network)."""
 
+  dcn_diloco_parallelism: int = Field(-1, description="DCN axis for Diloco parallelism.")
   dcn_data_parallelism: int = Field(-1, description="DCN axis for data parallelism.")
   dcn_fsdp_parallelism: int = Field(1, description="DCN axis for FSDP.")
   dcn_fsdp_transpose_parallelism: int = Field(1, description="DCN axis for FSDP transpose.")
@@ -752,6 +753,7 @@ class DcnParallelism(BaseModel):
 class IciParallelism(BaseModel):
   """Parallelism dimensions within the ICI (Inter-Chip Interconnect)."""
 
+  ici_diloco_parallelism: int = Field(-1, description="ICI axis for Diloco parallelism.")
   ici_data_parallelism: int = Field(1, description="ICI axis for data parallelism.")
   ici_fsdp_parallelism: int = Field(-1, description="ICI axis for FSDP.")
   ici_fsdp_transpose_parallelism: int = Field(1, description="ICI axis for FSDP transpose.")
@@ -1000,6 +1002,14 @@ class TrainingLoop(BaseModel):
   init_weights_seed: int = Field(0, description="Seed for model weight initialization.")
 
 
+class DilocoParams(BaseModel):
+  """Diloco Hyperparameters"""
+
+  diloco_sync_period: int = Field(36, description="Diloco sync period.")
+  diloco_outer_lr: float = Field(0.3, description="learning rate for outer optimizer.")
+  diloco_outer_momentum: float = Field(0.9, description="momentum for outer optimizer.")
+
+
 class Optimizer(BaseModel):
   """Configuration for the optimizer and learning rate schedule."""
 
@@ -1631,6 +1641,7 @@ class MaxTextConfig(
     # Training, Optimization, and Fine-Tuning
     RematAndOffload,
     TrainingLoop,
+    DilocoParams,
     Optimizer,
     AdamW,
     Muon,
@@ -2152,6 +2163,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     # Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
     if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage":
       self.ici_parallelism = [
+          self.ici_diloco_parallelism,
           self.ici_pipeline_parallelism,
           self.ici_data_parallelism,
           self.ici_fsdp_parallelism,
@@ -2166,6 +2178,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
           self.ici_autoregressive_parallelism,
       ]
       self.dcn_parallelism = [
+          self.dcn_diloco_parallelism,
           self.dcn_pipeline_parallelism,
           self.dcn_data_parallelism,
           self.dcn_fsdp_parallelism,
@@ -2181,6 +2194,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
       ]
     else:
       ici_map = {
+          "diloco": self.ici_diloco_parallelism,
           "data": self.ici_data_parallelism,
           "stage": self.ici_pipeline_parallelism,
           "fsdp": self.ici_fsdp_parallelism,
@@ -2198,6 +2212,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
       self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
 
       dcn_map = {
+          "diloco": self.dcn_diloco_parallelism,
           "data": self.dcn_data_parallelism,
           "stage": self.dcn_pipeline_parallelism,
           "fsdp": self.dcn_fsdp_parallelism,
@@ -2214,6 +2229,9 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
       }
       self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]
 
+      # Diloco params
+      self.num_diloco_replicas = int(self.ici_diloco_parallelism * self.dcn_diloco_parallelism)
+
     # Final string-to-enum conversions if they haven't been coerced by pydantic yet.
     if isinstance(self.decoder_block, str):
       self.decoder_block = DecoderBlockType(self.decoder_block.lower())
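
A toy recomputation of what the validator above derives from these fields (standalone Python that mirrors, but does not import, the MaxText logic; the axis list is shortened for illustration):

mesh_axes = ["diloco", "data", "fsdp"]
ici = {"diloco": 1, "data": 1, "fsdp": 4}  # extents within each slice
dcn = {"diloco": 2, "data": 1, "fsdp": 1}  # extents across slices

ici_parallelism = [ici[axis] for axis in mesh_axes]  # [1, 1, 4]
dcn_parallelism = [dcn[axis] for axis in mesh_axes]  # [2, 1, 1]

# As in the validator: the replica count is the product of the two diloco
# extents, here 1 * 2 = 2, i.e. one independent inner optimization per slice.
num_diloco_replicas = int(ici["diloco"] * dcn["diloco"])
print(num_diloco_replicas)  # 2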

src/MaxText/diloco.py

Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""An implementation of Distributed Low-Communication (DiLoCo) training.
+
+This module contains implementations of:
+
+- DiLoCo: Distributed Low-Communication Training of Language Models
+  https://arxiv.org/abs/2311.08105
+- Streaming DiLoCo with overlapping communication: Towards a Distributed Free Lunch
+  https://arxiv.org/abs/2501.18512
+"""
+
+from collections.abc import Sequence
+from typing import Any, Callable
+
+import drjax
+from flax import struct
+from flax.training import train_state
+import jax
+import jax.numpy as jnp
+from jaxtyping import Array, Int32, Key, PyTree, UInt32
+import optax
+
+from MaxText import pyconfig
+
+Batch = Any
+Params = PyTree
+Metrics = PyTree
+OptState = optax.OptState
+InnerOptStates = optax.OptState
+PRNGKey = Key[Array, ""] | UInt32[Array, "2"]
+Step = Int32[Array, ""]
+
+
+class DiLoCoTrainState(struct.PyTreeNode):
+  """The state of the DiLoCo training process.
+
+  Attributes:
+    inner_state: A `flax.training.train_state.TrainState` of the state for each
+      step of the inner optimization. All arrays are expected to have a leading
+      dimension with size of the number of diloco replicas so that training
+      steps can be mapped over this dimension.
+    outer_params: A PyTree of the global model weights. These mimic a
+      sub-PyTree of `inner_state`, but without the leading replica dimension.
+    outer_opt_state: The state for the outer Nesterov momentum optimizer.
+    step: The step counter of the training process.
+  """
+
+  inner_state: train_state.TrainState
+  outer_params: Params
+  outer_opt_state: OptState
+  step: Step
+
+
+def reshape_first_axis_with_diloco(num_diloco_replicas: int, pytree: PyTree) -> PyTree:
+  """Reshapes the first dimension of each array in the PyTree to include a DiLoCo axis.
+
+  This function takes a batch of data represented as a PyTree and reshapes
+  the leading dimension of each array within it. The purpose is to introduce
+  a new 'diloco' axis, which is used for distributing data across DiLoCo
+  replicas.
+
+  Args:
+    num_diloco_replicas: The number of DiLoCo replicas. This determines the
+      size of the new leading dimension.
+    pytree: The input PyTree, where each array is expected to have a batch
+      dimension as its first axis.
+
+  Returns:
+    A new PyTree with the same structure as the input, but with each array's
+    first dimension reshaped to
+    `(num_diloco_replicas, original_batch_dim // num_diloco_replicas, ...)`.
+    The sharding specification is also updated to include the 'diloco' axis.
+  """
+
+  def extend_pspec(pspec: jax.sharding.PartitionSpec | Sequence[str | Sequence[str]] = ()) -> jax.sharding.PartitionSpec:
+    first = pspec[0] if pspec else None
+    # Pull the diloco axis out into its own leading entry if already present.
+    if first == "diloco":
+      return jax.sharding.PartitionSpec("diloco", *pspec[1:])
+    if isinstance(first, (tuple, list)) and first and first[0] == "diloco":
+      return jax.sharding.PartitionSpec("diloco", tuple(first[1:]), *pspec[1:])
+    return jax.sharding.PartitionSpec("diloco", *pspec)
+
+  def reshape_for_diloco(arr):
+    batch_dim, *example_shape = arr.shape
+    diloco_shape = (num_diloco_replicas, batch_dim // num_diloco_replicas, *example_shape)
+    s = arr.sharding
+    s = jax.sharding.NamedSharding(mesh=s.mesh, spec=extend_pspec(s.spec))
+    return jax.lax.with_sharding_constraint(jnp.reshape(arr, shape=diloco_shape), s)
+
+  return jax.tree.map(reshape_for_diloco, pytree)
+
+
+def build_diloco_state(
+    config: "pyconfig.HyperParameters",
+    initialize_state: Callable[[], train_state.TrainState],
+) -> tuple[DiLoCoTrainState, PyTree]:
+  """Given a non-DiLoCo train state, construct a DiLoCo training state."""
+  outer_optimizer = optax.sgd(
+      config.diloco_outer_lr,
+      momentum=config.diloco_outer_momentum,
+      nesterov=True,
+  )
+
+  @drjax.program(placements={"diloco": config.num_diloco_replicas})
+  def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]:
+    state = initialize_state()
+    # Inner state must be broadcast across clients.
+    inner_state = drjax.broadcast(state)
+    # Outer state retains a single copy of the model parameters and optimizer state.
+    outer_params = state.params
+    outer_opt_state = outer_optimizer.init(outer_params)
+    outer_opt_state_sharding = jax.tree_util.tree_map(lambda x: x.sharding, outer_opt_state)
+    return (
+        DiLoCoTrainState(
+            inner_state=inner_state, outer_params=outer_params, outer_opt_state=outer_opt_state, step=state.step
+        ),
+        outer_opt_state_sharding,
+    )
+
+  return init_diloco_state()
+
+
+def build_diloco_train_step(
+    config: pyconfig.HyperParameters,
+    train_step: Callable[[train_state.TrainState, Batch, PRNGKey], tuple[train_state.TrainState, Metrics]],
+) -> Callable[[DiLoCoTrainState, Batch, PRNGKey], tuple[DiLoCoTrainState, Metrics]]:
+  """Convert a local state and train step into DiLoCo-compatible versions.
+
+  This is an implementation of the original (non-streaming) DiLoCo algorithm
+  which syncs all model parameters across the replicas every
+  `config.diloco_sync_period` steps, treating the difference accumulated over
+  non-sync steps as a pseudo gradient and applying SGD with Nesterov momentum on
+  the "global" model.
+
+  Args:
+    config: The config used to set up training.
+    train_step: A local train step. This will be executed independently within
+      each replica.
+  """
+  outer_optimizer = optax.sgd(
+      config.diloco_outer_lr,
+      momentum=config.diloco_outer_momentum,
+      nesterov=True,
+  )
+
+  def synchronize(state):
+    # Calculate the delta between the current replica's state and the global
+    # state (since last synchronization).
+    broadcast_outer_params = drjax.broadcast(state.outer_params)
+    model_delta = jax.tree.map(lambda x, y: y - x, state.inner_state.params, broadcast_outer_params)
+    # Treat the average delta as the outer optimizer's gradient and apply to
+    # the global (outer) model params.
+    averaged_pseudo_grad = drjax.reduce_mean(model_delta)
+    updates, new_opt_state = outer_optimizer.update(averaged_pseudo_grad, state.outer_opt_state, state.outer_params)
+    new_outer_params = optax.apply_updates(state.outer_params, updates)
+    # Replace inner model params with the new global model params.
+    # NOTE: inner optimizer state is retained despite the change in parameters,
+    # see section 6.1 in https://arxiv.org/pdf/2311.08105.
+    new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state)
+    return state.replace(
+        outer_params=new_outer_params,
+        outer_opt_state=new_opt_state,
+        inner_state=new_inner_state,
+    )
+
+  def typed_reduce_mean(in_tree):
+    total = drjax.reduce_sum(in_tree)
+    avg = jax.tree.map(lambda x: (x / config.num_diloco_replicas).astype(x.dtype), total)
+    return avg
+
+  @drjax.program(placements={"diloco": config.num_diloco_replicas})
+  def diloco_train_step(state, batch, prng):
+    # Broadcast the RNG across replicas.
+    broadcast_rng = drjax.broadcast(prng)
+    inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng))
+    avg_metrics = typed_reduce_mean(metrics)
+    state = state.replace(
+        inner_state=inner_state,
+        step=inner_state.step[0],
+    )
+    # Either synchronize the model, or no-op, depending on whether the current
+    # step falls on the synchronization period.
+    state = jax.lax.cond(
+        inner_state.step[0] % config.diloco_sync_period == 0,
+        synchronize,
+        lambda x: x,  # no-op
+        state,
+    )
+    return state, avg_metrics
+
+  return diloco_train_step
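
To make the outer step concrete, here is a self-contained sketch of the arithmetic `synchronize` performs every `diloco_sync_period` steps, using plain optax with no drjax placements (the two-replica values are made up):

import jax
import jax.numpy as jnp
import optax

outer_opt = optax.sgd(0.3, momentum=0.9, nesterov=True)  # the outer optimizer above

outer_params = {"w": jnp.array([1.0, 1.0])}
# Two replicas whose inner steps have drifted apart since the last sync.
inner_params = {"w": jnp.array([[0.8, 0.8], [0.6, 1.0]])}  # leading dim = replicas

# Pseudo-gradient: outer params minus the replica mean, matching
# reduce_mean(outer - inner) in `synchronize`.
pseudo_grad = jax.tree.map(lambda o, i: o - i.mean(axis=0), outer_params, inner_params)

opt_state = outer_opt.init(outer_params)
updates, opt_state = outer_opt.update(pseudo_grad, opt_state, outer_params)
new_outer_params = optax.apply_updates(outer_params, updates)
print(new_outer_params["w"])  # moves from [1. 1.] toward the replica mean [0.7 0.9]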
