Skip to content

Conversation

@khatwanimohit
Copy link
Collaborator

@khatwanimohit khatwanimohit commented Jan 8, 2026

Description

This PR introduces support for Distributed Low-Communication (DiLoCo) training in MaxText. It implements both standard DiLoCo, enabling efficient model training across disjoint clusters ("islands") by synchronizing gradients infrequently via an outer optimizer.

Key Changes

  • Core Logic: Added src/MaxText/diloco.py, which implements the DiLoCoTrainState, inner/outer optimization steps,
    and communication synchronization using drjax.
  • Training Loop Integration: Modified src/MaxText/train.py to initialize the DiLoCo state and adapt the training
    step when enable_diloco is active. This includes handling data reshaping for multiple replicas.
  • Sharding & Configuration:
    • Updated src/MaxText/sharding.py to support a hierarchical "diloco" sharding axis.
    • Added new flags (e.g., enable_diloco, num_diloco_replicas, diloco_outer_optimizer) to base.yml and types.py.
  • Dependencies: Added drjax to the project requirements.
  • Testing: Added comprehensive unit tests in tests/diloco_test.py.

Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.

Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@khatwanimohit khatwanimohit changed the title Mohit/diloco trainer [Diloco] Diloco trainer Jan 8, 2026
@khatwanimohit khatwanimohit force-pushed the mohit/diloco_trainer branch 2 times, most recently from 364cf4e to cdba187 Compare January 8, 2026 18:13
@codecov
Copy link

codecov bot commented Jan 8, 2026

Codecov Report

❌ Patch coverage is 77.03704% with 31 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/MaxText/diloco.py 83.33% 13 Missing and 2 partials ⚠️
src/MaxText/train_utils.py 22.22% 11 Missing and 3 partials ⚠️
src/MaxText/data_loader.py 60.00% 1 Missing and 1 partial ⚠️

📢 Thoughts on this report? Let us know!

Copy link
Collaborator Author

@khatwanimohit khatwanimohit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add train_compile tests for Diloco

@khatwanimohit khatwanimohit force-pushed the mohit/diloco_trainer branch 5 times, most recently from 889ea25 to c607a2a Compare January 22, 2026 19:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants