# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An implementation of Distributed Low-Communication (DiLoCo) training.

This module contains implementations of:

- DiLoCo: Distributed Low-Communication Training of Language Models
  https://arxiv.org/abs/2311.08105
- Streaming DiLoCo with overlapping communication: Towards a Distributed Free Lunch
  https://arxiv.org/abs/2501.18512
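
Typical usage (illustrative): construct the replicated state with
`build_diloco_state`, wrap an existing per-replica train step with
`build_diloco_train_step`, and reshape each incoming batch with
`reshape_first_axis_with_diloco` so its leading dimension maps over the
'diloco' mesh axis.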
| 23 | +""" |
| 24 | + |
| 25 | +from collections.abc import Sequence |
| 26 | +from typing import Any, Callable |
| 27 | + |
| 28 | +import drjax |
| 29 | +from flax import struct |
| 30 | +from flax.training import train_state |
| 31 | +import jax |
| 32 | +import jax.numpy as jnp |
| 33 | +from jaxtyping import Array, Int32, Key, PyTree, UInt32 |
| 34 | +import optax |
| 35 | + |
| 36 | +from MaxText import pyconfig |
| 37 | + |
| 38 | +Batch = Any |
| 39 | +Params = PyTree |
| 40 | +Metrics = PyTree |
| 41 | +OptState = optax.OptState |
| 42 | +InnerOptStates = optax.OptState |
| 43 | +PRNGKey = Key[Array, ""] | UInt32[Array, "2"] |
| 44 | +Step = Int32[Array, ""] |
| 45 | + |
| 46 | + |
| 47 | +class DiLoCoTrainState(struct.PyTreeNode): |
| 48 | + """The state of the DiLoCo training process. |
| 49 | +
|
| 50 | + Attributes: |
| 51 | + inner_state: A `flax.training.train_state.TrainState` of the state for each |
| 52 | + step of the inner optimization. All arrays are expected to have a leading |
| 53 | + dimension with size of the number of diloco replicas so that training |
| 54 | + steps can be mapped over this dimension. |
| 55 | + outer_params: A PyTree of the global model weights. These will mimic a |
| 56 | + sub-PyTree in `inner_state`, which rank-1 shape. |
| 57 | + outer_opt_state: The state for the outer Nesterov momentum optimizer. |
| 58 | + step: The step counter of the training process. |
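
  For example, if each leaf of `inner_state.params` has shape
  `(num_diloco_replicas, *leaf_shape)`, the corresponding leaf of
  `outer_params` has shape `leaf_shape`.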
| 59 | + """ |
| 60 | + |
| 61 | + inner_state: train_state.TrainState |
| 62 | + outer_params: Params |
| 63 | + outer_opt_state: OptState |
| 64 | + step: Step |
| 65 | + |
| 66 | + |
| 67 | +def reshape_first_axis_with_diloco(num_diloco_replicas: int, pytree: PyTree) -> PyTree: |
| 68 | + """Reshapes the first dimension of each array in the PyTree to include a DiLoCo axis. |
| 69 | +
|
  This function takes a batch of data represented as a PyTree and reshapes the
  leading dimension of each array within it. The purpose is to introduce a new
  'diloco' axis, which is used for distributing data across DiLoCo replicas.

  Args:
    num_diloco_replicas: The number of DiLoCo replicas. This determines the
      size of the new leading dimension.
    pytree: The input PyTree, where each array is expected to have a batch
      dimension as its first axis.

  Returns:
    A new PyTree with the same structure as the input, but with each array's
    first dimension reshaped to
    `(num_diloco_replicas, original_batch_dim // num_diloco_replicas, ...)`.
    The sharding specification is also updated to include the 'diloco' axis.
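
  Example (illustrative): with `num_diloco_replicas=2`, an array of shape
  `(8, 128)` becomes `(2, 4, 128)`, and a spec of `PartitionSpec(("data",))`
  becomes `PartitionSpec("diloco", ("data",))`.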
| 85 | + """ |
| 86 | + |
| 87 | + def extend_pspec(pspec: jax.sharding.PartitionSpec | Sequence[str | Sequence[str]] = ()) -> jax.sharding.PartitionSpec: |
| 88 | + if tuple(*pspec)[0] == "diloco": |
| 89 | + # pull out diloco axis if already present |
| 90 | + return jax.sharding.PartitionSpec("diloco", (*pspec[0][1:],), (*pspec[1:],)) |
| 91 | + return jax.sharding.PartitionSpec("diloco", *pspec) |

  def reshape_for_diloco(arr):
    # Split the leading batch dimension into (num_diloco_replicas, per-replica
    # batch) and constrain the result's sharding to include the 'diloco' axis.
    batch_dim, *example_shape = arr.shape
    diloco_shape = (num_diloco_replicas, batch_dim // num_diloco_replicas, *example_shape)
    s = arr.sharding
    s = jax.sharding.NamedSharding(mesh=s.mesh, spec=extend_pspec(s.spec))
    return jax.lax.with_sharding_constraint(jnp.reshape(arr, shape=diloco_shape), s)

  return jax.tree.map(reshape_for_diloco, pytree)


def build_diloco_state(
    config: pyconfig.HyperParameters,
    initialize_state: Callable[[], train_state.TrainState],
) -> tuple[DiLoCoTrainState, PyTree]:
  """Given a non-DiLoCo train state, construct a DiLoCo training state."""
  outer_optimizer = optax.sgd(
      config.diloco_outer_lr,
      momentum=config.diloco_outer_momentum,
      nesterov=True,
  )

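  # DrJAX interprets the 'diloco' placement as the replica axis: values built
  # with `drjax.broadcast` below gain a leading dimension of
  # `config.num_diloco_replicas`.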
  @drjax.program(placements={"diloco": config.num_diloco_replicas})
  def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]:
    state = initialize_state()
    # Inner state must be broadcast across clients.
    inner_state = drjax.broadcast(state)
    # Outer state retains a single copy of the model parameters and optimizer state.
    outer_params = state.params
    outer_opt_state = outer_optimizer.init(outer_params)
    outer_opt_state_sharding = jax.tree_util.tree_map(lambda x: x.sharding, outer_opt_state)
    return (
        DiLoCoTrainState(
            inner_state=inner_state, outer_params=outer_params, outer_opt_state=outer_opt_state, step=state.step
        ),
        outer_opt_state_sharding,
    )

  return init_diloco_state()


def build_diloco_train_step(
    config: pyconfig.HyperParameters,
    train_step: Callable[[train_state.TrainState, Batch, PRNGKey], tuple[train_state.TrainState, Metrics]],
) -> Callable[[DiLoCoTrainState, Batch, PRNGKey], tuple[DiLoCoTrainState, Metrics]]:
  """Convert a local state and train step into DiLoCo-compatible versions.

  This is an implementation of the original (non-streaming) DiLoCo algorithm,
  which synchronizes all model parameters across replicas every
  `config.diloco_sync_period` steps. The change accumulated by each replica
  between synchronizations is treated as a pseudo-gradient and applied to the
  "global" model with SGD plus Nesterov momentum.

  Args:
    config: The config used to set up training.
    train_step: A local train step. This will be executed independently within
      each replica.
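
  Returns:
    A train step that operates on a `DiLoCoTrainState` and a batch whose
    leading axis is the 'diloco' dimension, returning the updated state and
    metrics averaged across replicas.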
| 149 | + """ |
| 150 | + outer_optimizer = optax.sgd( |
| 151 | + config.diloco_outer_lr, |
| 152 | + momentum=config.diloco_outer_momentum, |
| 153 | + nesterov=True, |
| 154 | + ) |
| 155 | + |
  def synchronize(state):
    # Calculate the delta between each replica's params and the global params,
    # accumulated since the last synchronization.
    broadcast_outer_params = drjax.broadcast(state.outer_params)
    model_delta = jax.tree.map(lambda x, y: y - x, state.inner_state.params, broadcast_outer_params)
    # Treat the average delta as the outer optimizer's gradient and apply it to
    # the global (outer) model params.
    averaged_pseudo_grad = drjax.reduce_mean(model_delta)
    updates, new_opt_state = outer_optimizer.update(averaged_pseudo_grad, state.outer_opt_state, state.outer_params)
    new_outer_params = optax.apply_updates(state.outer_params, updates)
    # Replace inner model params with the new global model params.
    # NOTE: inner optimizer state is retained despite the change in parameters,
    # see section 6.1 in https://arxiv.org/pdf/2311.08105.
    new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state)
    return state.replace(
        outer_params=new_outer_params,
        outer_opt_state=new_opt_state,
        inner_state=new_inner_state,
    )

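  # Mean across replicas that casts the result back to each leaf's original
  # dtype (the division could otherwise promote, e.g. integer metrics to float).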
  def typed_reduce_mean(in_tree):
    total = drjax.reduce_sum(in_tree)
    avg = jax.tree.map(lambda x: (x / config.num_diloco_replicas).astype(x.dtype), total)
    return avg

  @drjax.program(placements={"diloco": config.num_diloco_replicas})
  def diloco_train_step(state, batch, prng):
    # Broadcast the RNG across replicas.
    broadcast_rng = drjax.broadcast(prng)
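    # Run the local train step independently within each replica.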
    inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng))
    avg_metrics = typed_reduce_mean(metrics)
    state = state.replace(
        inner_state=inner_state,
        step=inner_state.step[0],
    )
    # Either synchronize the model, or no-op, depending on whether the current
    # step falls on the synchronization period.
    state = jax.lax.cond(
        inner_state.step[0] % config.diloco_sync_period == 0,
        synchronize,
        lambda x: x,  # no-op
        state,
    )
    return state, avg_metrics

  return diloco_train_step