
Conversation

@bzantium (Collaborator) commented Dec 24, 2025

Description

This PR implements the Warmup-Stable-Decay (WSD) learning rate schedule as a configurable option alongside the existing Cosine schedule. This allows users to choose between the standard cosine decay and a schedule that maintains a stable peak learning rate for the majority of training before a rapid decay.

Additionally, this implementation introduces a wsd_decay_style parameter, giving users the flexibility to choose the decay profile (linear or cosine) for the final annealing phase.

Details and Context:

  • Why: The WSD schedule is a widely adopted training strategy where the learning rate warms up, stays constant (stable) to maximize training throughput, and then decays rapidly to converge. Separating the stable and decay phases allows for "infinite" training horizons and flexible checkpointing.
  • Implementation:
    • Configuration (src/MaxText/configs/base.yml):
      • Added lr_schedule_type (options: 'cosine', 'wsd').
      • Added WSD-specific parameters: wsd_learning_rate_final_fraction, wsd_decay_steps_fraction.
      • Added wsd_decay_style: Supports 'linear' (default, standard for WSD) or 'cosine' decay for the final phase.
    • Types (src/MaxText/configs/types.py):
      • Added LearningRateScheduleType and WsdDecayStyle Enums.
      • Updated the Optimizer class to include validation for these new fields.
    • Logic (src/MaxText/maxtext_utils.py):
      • Refactored create_learning_rate_schedule to switch between Cosine and WSD logic.
      • Implemented WSD construction: Linear Warmup -> Constant Stable -> Decay (a sketch of this construction follows this list).
      • The decay phase dynamically selects between optax.linear_schedule and a custom cosine schedule based on wsd_decay_style.
      • Added validation to ensure warmup_steps_fraction + wsd_decay_steps_fraction <= 1.0.
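
As a rough illustration of how the pieces above fit together, here is a minimal sketch of a WSD schedule assembled with optax. The helper name make_wsd_schedule and its parameters are hypothetical (they only mirror the config fields above); the actual implementation lives in src/MaxText/maxtext_utils.py and may handle endpoints and edge cases differently.

```python
# Illustrative sketch only, not the MaxText implementation.
import jax.numpy as jnp
import optax


def make_wsd_schedule(peak_lr, total_steps, warmup_fraction, decay_fraction,
                      final_fraction, decay_style="linear"):
  """Warmup -> Stable -> Decay, with a linear or cosine decay tail."""
  if warmup_fraction + decay_fraction > 1.0:
    raise ValueError("warmup_fraction + decay_fraction must not exceed 1.0")

  warmup_steps = int(total_steps * warmup_fraction)
  decay_steps = int(total_steps * decay_fraction)
  stable_steps = total_steps - warmup_steps - decay_steps
  final_lr = peak_lr * final_fraction

  warmup = optax.linear_schedule(0.0, peak_lr, warmup_steps)
  stable = optax.constant_schedule(peak_lr)

  if decay_style == "linear":
    # transition_steps chosen so final_lr is reached exactly at the last step.
    decay = optax.linear_schedule(peak_lr, final_lr, max(decay_steps - 1, 1))
  else:  # cosine decay from peak_lr down to final_lr

    def decay(step):
      pct = step / (decay_steps - 1) if decay_steps > 1 else 1.0
      a = 0.5 * (jnp.cos(jnp.pi * pct) + 1)
      return peak_lr * a + final_lr * (1 - a)

  # join_schedules hands each phase a step count relative to its own start.
  return optax.join_schedules(
      schedules=[warmup, stable, decay],
      boundaries=[warmup_steps, warmup_steps + stable_steps],
  )
```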

Tests

I have added a comprehensive test suite, TestLearningRateSchedules, in tests/maxtext_utils_test.py.

  • Unit Tests:
    • Cosine Schedule: Verified standard behavior (Warmup -> Cosine Decay).
    • WSD Schedule: Verified the 3-phase structure (Warmup -> Stable -> Decay) for both linear and cosine decay styles (an illustrative structural check appears after this list).
    • Checked that the learning rate hits the correct peak, stable values, and final fraction values.
  • Edge Cases: Verified that invalid configurations (e.g., sum of fractions > 1.0) raise a ValueError.
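
For reference, a minimal illustration of the kind of structural assertions described above, built directly with optax rather than through the MaxText config plumbing; the class name and step counts below are made up and do not come from the actual TestLearningRateSchedules suite.

```python
# Illustrative only; the real tests live in tests/maxtext_utils_test.py.
import unittest

import optax


class TestWsdScheduleShape(unittest.TestCase):
  """Checks the Warmup -> Stable -> Decay structure of a WSD-style schedule."""

  def test_three_phases(self):
    peak_lr, final_lr = 3e-4, 3e-5
    warmup_steps, stable_steps, decay_steps = 100, 700, 200
    total_steps = warmup_steps + stable_steps + decay_steps
    schedule = optax.join_schedules(
        schedules=[
            optax.linear_schedule(0.0, peak_lr, warmup_steps),
            optax.constant_schedule(peak_lr),
            optax.linear_schedule(peak_lr, final_lr, decay_steps - 1),
        ],
        boundaries=[warmup_steps, warmup_steps + stable_steps],
    )

    self.assertLess(float(schedule(0)), peak_lr)                        # warming up
    self.assertAlmostEqual(float(schedule(warmup_steps)), peak_lr)      # peak reached
    self.assertAlmostEqual(float(schedule(500)), peak_lr)               # stable phase
    self.assertAlmostEqual(float(schedule(total_steps - 1)), final_lr)  # final value


if __name__ == "__main__":
  unittest.main()
```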

To reproduce/test:

python3 -m unittest tests/maxtext_utils_test.py

Fixes: #2882

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov bot commented Dec 24, 2025

Codecov Report

❌ Patch coverage is 0% with 21 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/MaxText/maxtext_utils.py | 0.00% | 21 Missing ⚠️ |


@bzantium force-pushed the feature/#2882 branch 3 times, most recently from 9a359f0 to 76b8800 on December 26, 2025 at 02:08

@gagika (Collaborator) left a comment

Thanks for the PR!

I have a few minor comments.

@bzantium (Collaborator, Author) commented Jan 7, 2026

@gagika Thanks for the review! I've updated the code to use types instead of strings.

@bzantium force-pushed the feature/#2882 branch 2 times, most recently from d560850 to acbec6e on January 7, 2026 at 23:29

@bzantium (Collaborator, Author) commented Jan 8, 2026

One integration test failed with the error below but this seems unrelated to my changes.

FAILED tests/integration_tests/generate_param_only_checkpoint_test.py::test_param_ckpt_generation_with_autoselected_attention[int8] - FileNotFoundError: Checkpoint at gs://runner-maxtext-logs/runner_2026-01-07-23-32-33/checkpoints/0/items not found.

@khatwanimohit (Collaborator) commented:

> One integration test failed with the error below but this seems unrelated to my changes.
>
> FAILED tests/integration_tests/generate_param_only_checkpoint_test.py::test_param_ckpt_generation_with_autoselected_attention[int8] - FileNotFoundError: Checkpoint at gs://runner-maxtext-logs/runner_2026-01-07-23-32-33/checkpoints/0/items not found.

Seems like a transient issue; I retriggered the tests on your PR.

@gagika (Collaborator) left a comment

thanks

@A9isha (Collaborator) left a comment

Thank you so much for such a detailed and thoughtful PR. I have just a few small comments. Really appreciate the contribution!

@bzantium force-pushed the feature/#2882 branch 12 times, most recently from bd8ee26 to 82da545 on January 9, 2026 at 09:00

@bzantium (Collaborator, Author) commented Jan 9, 2026

Thanks for the detailed review, @A9isha! I've addressed all your comments:

  • Consolidated the specific final fractions into a single learning_rate_final_fraction.
  • Moved the validation/error checks into types.py (instead of pyconfig.py) to clean up the logic.
  • Added assertions to verify the final learning rate value at the very last step for both schedule types.

In addition, I implemented two logic fixes:

  1. Cosine Schedule: I adjusted the denominator to len_steps - 1.
```diff
  def make_cos_schedule(init_lr, final_lr, len_steps):
    def schedule(step):
-     pct = (step) / len_steps
+     pct = step / (len_steps - 1) if len_steps > 1 else 1.0
      a = 0.5 * (jnp.cos(jnp.pi * pct) + 1)
      lr = init_lr * a + final_lr * (1 - a)
      return lr

    return schedule
```

This ensures we reach the exact final learning rate at the last step, since step iterates from 0 to len_steps - 1 (a quick numerical check appears after this list).

  2. Warmup Schedule: I changed the starting value to lr / warmup_steps instead of 0. If it starts at 0, the model effectively receives no update during the first step.
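
For illustration, here is a self-contained check of the boundary behavior from point 1; it reuses the fixed make_cos_schedule above, and the learning-rate values and step count are arbitrary.

```python
# Standalone numerical check of the adjusted cosine schedule (requires jax).
import jax.numpy as jnp


def make_cos_schedule(init_lr, final_lr, len_steps):
  def schedule(step):
    pct = step / (len_steps - 1) if len_steps > 1 else 1.0
    a = 0.5 * (jnp.cos(jnp.pi * pct) + 1)
    return init_lr * a + final_lr * (1 - a)

  return schedule


sched = make_cos_schedule(init_lr=3e-4, final_lr=3e-5, len_steps=100)
print(sched(0))   # 3e-4: cos(0) = 1, so the schedule starts at init_lr
print(sched(99))  # 3e-5: the last step (len_steps - 1) lands exactly on final_lr
```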

@A9isha (Collaborator) commented Jan 13, 2026

> 2. lr / warmup_steps

Thank you for the changes!

Let us keep init_value = 0.0; this is a common pattern in many JAX/Flax training loops where state.step starts at 0 and the schedule is queried before the first update.
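
For concreteness, a tiny optax example of the behavior under discussion (the peak value and warmup length are arbitrary):

```python
# With init_value=0.0, the first value the training loop reads (state.step == 0)
# is exactly 0.0, matching the common JAX/Flax pattern described above.
import optax

warmup = optax.linear_schedule(init_value=0.0, end_value=3e-4, transition_steps=100)
print(warmup(0))    # 0.0 at step 0
print(warmup(50))   # 1.5e-4 halfway through warmup
print(warmup(100))  # 3e-4 once warmup completes
```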

…le stable and decay phases

Signed-off-by: bzantium <ryumin93@gmail.com>

@bzantium (Collaborator, Author) commented Jan 13, 2026

@A9isha That makes sense! I've reverted the change to keep init_value = 0.0.

@A9isha (Collaborator) left a comment

Thank you!

@copybara-service bot merged commit 1137c42 into AI-Hypercomputer:main on Jan 20, 2026
23 of 24 checks passed
@bzantium deleted the feature/#2882 branch on January 21, 2026 at 01:00