Implement Warmup-Stable-Decay (WSD) Learning Rate Schedule #2883
Conversation
Force-pushed from 0bdb5ea to 295c238.
Force-pushed from 9a359f0 to 76b8800.
gagika left a comment:
Thanks for the PR!
I have a few minor comments.
Force-pushed from 76b8800 to cbc4557.
@gagika Thanks for the review! I've updated the code to use types instead of strings.
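For context, a hypothetical sketch of what the typed options might look like, using Python's `enum` module; the actual definitions live in `src/MaxText/configs/types.py`, per the description below:

```python
import enum

# Illustrative only: enum names and values taken from the PR description,
# not copied from the repository.
class LearningRateScheduleType(str, enum.Enum):
  COSINE = "cosine"
  WSD = "wsd"

class WsdDecayStyle(str, enum.Enum):
  LINEAR = "linear"
  COSINE = "cosine"
```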
Force-pushed from d560850 to acbec6e.
One integration test failed with the error below, but this seems unrelated to my changes.

Seems like a transient issue; I retriggered the tests on your PR.
gagika left a comment:
Thanks!
A9isha left a comment:
Thank you so much for such a detailed and thoughtful PR. I have just a few small comments. Really appreciate the contribution!
Force-pushed from bd8ee26 to 82da545.
Thanks for the detailed review, @A9isha! I've addressed all your comments.
In addition, I implemented two logic fixes:
```diff
 def make_cos_schedule(init_lr, final_lr, len_steps):
   def schedule(step):
-    pct = (step) / len_steps
+    pct = step / (len_steps - 1) if len_steps > 1 else 1.0
     a = 0.5 * (jnp.cos(jnp.pi * pct) + 1)
     lr = init_lr * a + final_lr * (1 - a)
     return lr
   return schedule
```

This ensures we reach the exact final learning rate at the last step, since `pct` now equals 1.0 at `step = len_steps - 1`, which makes `a = 0` and `lr = final_lr`.
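As a quick sanity check of the endpoint behavior, here is the fixed schedule in an illustrative, self-contained form (not the PR's exact code):

```python
import jax.numpy as jnp

def make_cos_schedule(init_lr, final_lr, len_steps):
  def schedule(step):
    # Fraction of training completed; reaches exactly 1.0 at the final step.
    pct = step / (len_steps - 1) if len_steps > 1 else 1.0
    a = 0.5 * (jnp.cos(jnp.pi * pct) + 1)
    return init_lr * a + final_lr * (1 - a)
  return schedule

sched = make_cos_schedule(1e-3, 1e-5, len_steps=100)
print(sched(0))   # 0.001  (cos(0) = 1  -> a = 1 -> init_lr)
print(sched(99))  # 1e-05  (cos(pi) = -1 -> a = 0 -> final_lr exactly)
```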
Thank you for the changes! Let us keep …
Commit: "…le stable and decay phases" (Signed-off-by: bzantium <ryumin93@gmail.com>)
Force-pushed from 82da545 to e886dd2.
@A9isha That makes sense! I've reverted the change to keep …
A9isha left a comment:
Thank you!
Merged commit 1137c42 into AI-Hypercomputer:main.
Description
This PR implements the Warmup-Stable-Decay (WSD) learning rate schedule as a configurable option alongside the existing Cosine schedule. This allows users to choose between the standard cosine decay and a schedule that maintains a stable peak learning rate for the majority of training before a rapid decay.
Additionally, this implementation introduces a `wsd_decay_style` parameter, giving users the flexibility to choose the decay profile (linear or cosine) for the final annealing phase.

Details and Context:
- Config (`src/MaxText/configs/base.yml`):
  - Added `lr_schedule_type` (options: `'cosine'`, `'wsd'`).
  - Added `wsd_learning_rate_final_fraction` and `wsd_decay_steps_fraction`.
  - Added `wsd_decay_style`: supports `'linear'` (default, standard for WSD) or `'cosine'` decay for the final phase.
- Types (`src/MaxText/configs/types.py`):
  - Added `LearningRateScheduleType` and `WsdDecayStyle` Enums.
  - Updated the `Optimizer` class to include validation for these new fields.
- Schedule logic (`src/MaxText/maxtext_utils.py`):
  - Updated `create_learning_rate_schedule` to switch between Cosine and WSD logic.
  - WSD follows Linear Warmup -> Constant Stable -> Decay (a sketch follows this list).
  - The decay phase uses `optax.linear_schedule` or a custom cosine schedule, based on `wsd_decay_style`.
  - Validates that `warmup_steps_fraction + wsd_decay_steps_fraction <= 1.0`.
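For illustration, here is a minimal sketch of the three-phase structure built from `optax` primitives. The function name and argument names are hypothetical; the PR's actual implementation lives in `create_learning_rate_schedule`:

```python
import optax

def make_wsd_schedule(peak_lr, final_lr, total_steps,
                      warmup_fraction=0.1, decay_fraction=0.2,
                      decay_style="linear"):
  """Warmup-Stable-Decay: linear warmup -> constant -> linear/cosine decay."""
  warmup_steps = int(total_steps * warmup_fraction)
  decay_steps = int(total_steps * decay_fraction)
  stable_steps = total_steps - warmup_steps - decay_steps
  # Mirrors the PR's validation: warmup + decay fractions must not exceed 1.0.
  assert stable_steps >= 0, "warmup_fraction + decay_fraction must be <= 1.0"

  warmup = optax.linear_schedule(0.0, peak_lr, warmup_steps)
  stable = optax.constant_schedule(peak_lr)
  if decay_style == "linear":
    decay = optax.linear_schedule(peak_lr, final_lr, decay_steps)
  else:
    # Cosine decay from peak_lr down to final_lr (alpha is the final fraction).
    decay = optax.cosine_decay_schedule(peak_lr, decay_steps,
                                        alpha=final_lr / peak_lr)
  return optax.join_schedules(
      schedules=[warmup, stable, decay],
      boundaries=[warmup_steps, warmup_steps + stable_steps])
```

`optax.join_schedules` switches between the phases at the given step boundaries, so the peak rate is held constant for the entire stable phase before the final annealing begins.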
Tests

I have added a comprehensive test suite, `TestLearningRateSchedules`, in `tests/maxtext_utils_test.py`. It covers both `linear` and `cosine` decay styles and verifies that invalid configurations raise a `ValueError`.

To reproduce/test:
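A hypothetical invocation, assuming the suite runs under pytest: `python3 -m pytest tests/maxtext_utils_test.py -k TestLearningRateSchedules`.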
Fixes: #2882
Checklist
Before submitting this PR, please make sure (put X in square brackets):
… `gemini-review` label.