
Commit cbc4557

Add Warmup-Stable-Decay (WSD) learning rate scheduler with configurable stable and decay phases
Signed-off-by: bzantium <[email protected]>
1 parent 08216c6 commit cbc4557

File tree: 4 files changed, +194 −23 lines

src/MaxText/configs/base.yml

Lines changed: 21 additions & 7 deletions
@@ -607,7 +607,7 @@ grain_file_type: 'arrayrecord' # arrayrecord or parquet
 grain_packing_type: 'first_fit' # 'first_fit' or 'concat_then_split'. See details of the corresponding module in https://google-grain.readthedocs.io/en/latest/grain.experimental.html
 grain_worker_count: 1 # Set to -1 to enable auto-tuning: automatically determines optimal worker count. See https://google-grain.readthedocs.io/en/latest/_autosummary/grain.experimental.pick_performance_config.html
 grain_per_worker_buffer_size: 1
-# num_threads and prefetch_buffer_size are per-worker per-dataset.
+# num_threads and prefetch_buffer_size are per-worker per-dataset.
 # When using array_records, they are used in ReadOptions (https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html#per-worker-readoptions)
 # The default value matches that in the Grain package. If mixing multiple data sources, consider lowering these values to reduce memory usage.
 # When using parquet, grain_num_threads is the number of files to read and interleave in parallel
@@ -635,15 +635,29 @@ skip_jax_distributed_system: False # If True we will not initialize the jax dist
 # However when run on google internal TPUs the coordination service is started automatically
 # and we should set this to True so we won't try to initialize a second time manually.
 
-# We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
-# Learning rate schedule has either two or three parts:
+# Learning rate schedule structure depends on lr_schedule_type:
+#
+# Cosine schedule (lr_schedule_type='cosine'):
+# Inspired by Llama2's learning rate schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
+# 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
+# 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] until learning_rate_schedule_steps
+# 3) Constant learning rate of 0 from learning_rate_schedule_steps to steps (if steps > learning_rate_schedule_steps)
+#
+# WSD schedule (lr_schedule_type='wsd', Warmup-Stable-Decay):
 # 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
-# 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] from warmup to learning_rate_schedule_steps
-# 3) Constant learning rate of 0 from learning_rate_schedule_steps to steps.
+# 2) Stable phase at [learning_rate] for the majority of training
+# 3) Decay from [learning_rate] to [learning_rate * wsd_learning_rate_final_fraction] over [learning_rate_schedule_steps * wsd_decay_steps_fraction] steps
+#    The decay can be either linear or cosine based on wsd_decay_style
+# 4) Constant learning rate of 0 from learning_rate_schedule_steps to steps (if steps > learning_rate_schedule_steps)
+#
 # The zero learning rate section can be used to more accurately measure the fully trained model's performance.
 learning_rate: 3.e-5
-cosine_learning_rate_final_fraction: 0.1
-warmup_steps_fraction: 0.1
+lr_schedule_type: 'cosine' # Options: 'cosine' or 'wsd'
+cosine_learning_rate_final_fraction: 0.1 # Final LR as fraction of peak LR for cosine schedule
+wsd_learning_rate_final_fraction: 0.1 # Final LR as fraction of peak LR for WSD schedule
+wsd_decay_steps_fraction: 0.1 # Fraction of learning_rate_schedule_steps used for decay phase in WSD (e.g., 0.1 = 10%)
+wsd_decay_style: 'linear' # Decay style for WSD schedule: 'linear' or 'cosine'
+warmup_steps_fraction: 0.1 # Fraction of learning_rate_schedule_steps used for warmup phase (applies to both schedules)
 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
 # However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
 # dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.
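Not part of the commit, but for orientation: a minimal sketch of how the fractions above carve a WSD schedule into phases. The 10,000-step total is an illustrative number, not a MaxText default.

```python
# Illustrative arithmetic only; MaxText derives these counts the same way from config.
learning_rate_schedule_steps = 10_000        # assumed total schedule length for this example
warmup_steps_fraction = 0.1                  # default above
wsd_decay_steps_fraction = 0.1               # default above

warmup_steps = int(learning_rate_schedule_steps * warmup_steps_fraction)    # 1,000 steps: 0 -> learning_rate
decay_steps = int(learning_rate_schedule_steps * wsd_decay_steps_fraction)  # 1,000 steps: learning_rate -> final fraction
stable_steps = learning_rate_schedule_steps - warmup_steps - decay_steps    # 8,000 steps held at learning_rate
```

With the cosine schedule the decay phase simply spans everything after warmup, so only warmup_steps_fraction affects the split.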

src/MaxText/configs/types.py

Lines changed: 26 additions & 0 deletions
@@ -124,6 +124,20 @@ class OptimizerType(str, Enum):
   MUON = "muon"
 
 
+class LearningRateScheduleType(str, Enum):
+  """Supported learning rate schedule types."""
+
+  COSINE = "cosine"
+  WSD = "wsd"
+
+
+class WsdDecayStyle(str, Enum):
+  """Supported decay styles for WSD schedule."""
+
+  LINEAR = "linear"
+  COSINE = "cosine"
+
+
 class RopeType(str, Enum):
   """Supported Rotary Positional Embedding (RoPE) implementations."""
 
@@ -1005,9 +1019,21 @@ class Optimizer(BaseModel):
       1.0, description="The threshold for gradient clipping. 0 disables clipping."
   )
   learning_rate: NonNegativeFloat = Field(3.0e-5, description="The peak learning rate.")
+  lr_schedule_type: LearningRateScheduleType = Field(
+      LearningRateScheduleType.COSINE, description="The type of learning rate schedule to use."
+  )
   cosine_learning_rate_final_fraction: float = Field(
       0.1, description="Final LR as a fraction of peak LR in cosine decay."
   )
+  wsd_learning_rate_final_fraction: float = Field(
+      0.1, description="Final LR as a fraction of peak LR in WSD decay phase."
+  )
+  wsd_decay_steps_fraction: float = Field(
+      0.1, ge=0.0, le=1.0, description="Fraction of total steps for decay phase in WSD schedule."
+  )
+  wsd_decay_style: WsdDecayStyle = Field(
+      WsdDecayStyle.LINEAR, description="The decay style for WSD schedule ('linear' or 'cosine')."
+  )
   warmup_steps_fraction: float = Field(0.1, ge=0.0, le=1.0, description="Fraction of total steps for LR warmup.")
   learning_rate_schedule_steps: int = Field(
       -1,
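A small standalone sketch (not from the commit) of the behaviour these str-backed enums give the config layer: raw YAML or CLI strings validate into members and still compare equal to the original string.

```python
from enum import Enum


class LearningRateScheduleType(str, Enum):
  """Local re-declaration mirroring the enum added above, for illustration only."""

  COSINE = "cosine"
  WSD = "wsd"


schedule = LearningRateScheduleType("wsd")       # value lookup from a raw config string
assert schedule is LearningRateScheduleType.WSD
assert schedule == "wsd"                         # str subclass: equal to the plain string value
```

Pydantic fields typed with these enums (as in the Optimizer model above) accept either the enum member or its string value and reject anything else with a validation error.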

src/MaxText/maxtext_utils.py

Lines changed: 51 additions & 16 deletions
@@ -40,6 +40,7 @@
 from MaxText import max_utils
 from MaxText import multimodal_utils
 from MaxText import sharding
+from MaxText.configs import types
 from MaxText.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
 from MaxText.inference.page_manager import PageState
 
@@ -1103,16 +1104,25 @@ def create_device_mesh(config, devices=None):
 
 
 def create_learning_rate_schedule(config):
-  """Creates a warmup and cosine decay learning rate schedule:
-  We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
-  Learning rate schedule has either two or three parts:
+  """Creates a learning rate schedule with warmup and decay.
+
+  Supports two schedule types:
+  - Cosine: Inspired by Llama2's learning rate schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
+  - WSD (Warmup-Stable-Decay): Maintains constant learning rate for most of training before final decay
+
+  Schedule structure:
   1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
-  2) Cosine from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] until learning_rate_schedule_steps
+  2) Decay from [learning_rate] to a final value until learning_rate_schedule_steps
+     - Cosine: decays to [learning_rate * cosine_learning_rate_final_fraction]
+     - WSD: maintains [learning_rate] for a stable phase, then decays to [learning_rate * wsd_learning_rate_final_fraction]
+       using either linear or cosine decay based on wsd_decay_style
   3) Constant learning rate of 0 from learning_rate_schedule_steps to steps.
   The zero learning rate section can be used to more accurately measure the fully trained model's performance.
   """
 
   def make_cos_schedule(init_lr, final_lr, len_steps):
+    """Creates a cosine decay schedule from init_lr to final_lr over len_steps."""
+
     def schedule(step):
       pct = (step) / len_steps
       a = 0.5 * (jnp.cos(jnp.pi * pct) + 1)
@@ -1122,25 +1132,50 @@ def schedule(step):
     return schedule
 
   lr = config.learning_rate
-  cos_final_lr = lr * config.cosine_learning_rate_final_fraction
-
   warmup_steps = int(config.learning_rate_schedule_steps * config.warmup_steps_fraction)
-  cos_steps = config.learning_rate_schedule_steps - warmup_steps
   constant_zero_steps = config.steps - config.learning_rate_schedule_steps
-
   warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps)
-  cos_schedule = make_cos_schedule(lr, cos_final_lr, cos_steps)
-  constant_schedule = optax.constant_schedule(0.0)
 
-  pieces = [warmup_schedule, cos_schedule]
-  boundaries = [
-      warmup_steps,
-      warmup_steps + cos_steps,
-  ]
+  if config.lr_schedule_type == types.LearningRateScheduleType.COSINE:
+    cos_final_lr = lr * config.cosine_learning_rate_final_fraction
+    cos_steps = config.learning_rate_schedule_steps - warmup_steps
+    cos_schedule = make_cos_schedule(lr, cos_final_lr, cos_steps)
+
+    pieces = [warmup_schedule, cos_schedule]
+    boundaries = [warmup_steps, warmup_steps + cos_steps]
+
+  elif config.lr_schedule_type == types.LearningRateScheduleType.WSD:
+    wsd_final_lr = lr * config.wsd_learning_rate_final_fraction
+    decay_steps = int(config.learning_rate_schedule_steps * config.wsd_decay_steps_fraction)
+    stable_steps = config.learning_rate_schedule_steps - warmup_steps - decay_steps
+
+    if stable_steps < 0:
+      raise ValueError(
+          f"Invalid WSD schedule: warmup_steps_fraction ({config.warmup_steps_fraction}) + "
+          f"wsd_decay_steps_fraction ({config.wsd_decay_steps_fraction}) must not exceed 1.0. "
+          f"Current sum: {config.warmup_steps_fraction + config.wsd_decay_steps_fraction}"
+      )
+
+    stable_schedule = optax.constant_schedule(lr)
+
+    # Create decay schedule based on wsd_decay_style
+    if config.wsd_decay_style == types.WsdDecayStyle.LINEAR:
+      decay_schedule = optax.linear_schedule(init_value=lr, end_value=wsd_final_lr, transition_steps=decay_steps)
+    elif config.wsd_decay_style == types.WsdDecayStyle.COSINE:
+      decay_schedule = make_cos_schedule(lr, wsd_final_lr, decay_steps)
+    else:
+      raise ValueError(f"Invalid wsd_decay_style: {config.wsd_decay_style}. " "Must be either 'linear' or 'cosine'.")
+
+    pieces = [warmup_schedule, stable_schedule, decay_schedule]
+    boundaries = [warmup_steps, warmup_steps + stable_steps]
+
+  else:
+    raise ValueError(f"Invalid lr_schedule_type: {config.lr_schedule_type}. " "Must be either 'cosine' or 'wsd'.")
 
   if constant_zero_steps > 0:
+    constant_schedule = optax.constant_schedule(0.0)
     pieces.append(constant_schedule)
-    boundaries.append(warmup_steps + cos_steps + constant_zero_steps)
+    boundaries.append(boundaries[-1] + constant_zero_steps)
 
   return optax.join_schedules(pieces, boundaries)
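As a usage sketch (not part of the commit), the same warmup-stable-decay shape can be assembled directly with optax, mirroring the pieces/boundaries composition in the WSD branch above; all numbers are illustrative rather than MaxText defaults.

```python
import optax

peak_lr, final_lr = 3e-4, 3e-5                          # assumed peak and final learning rates
warmup_steps, stable_steps, decay_steps = 100, 800, 100  # assumed phase lengths

wsd_schedule = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, peak_lr, transition_steps=warmup_steps),      # warmup
        optax.constant_schedule(peak_lr),                                        # stable
        optax.linear_schedule(peak_lr, final_lr, transition_steps=decay_steps),  # linear decay
    ],
    boundaries=[warmup_steps, warmup_steps + stable_steps],
)

for step in (0, 100, 500, 900, 999):
  # 0.0 at step 0, peak_lr through the stable phase, approaching final_lr by step 999
  print(step, float(wsd_schedule(step)))
```

The committed implementation additionally appends an optax.constant_schedule(0.0) piece when steps exceeds learning_rate_schedule_steps, which this sketch omits.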

tests/maxtext_utils_test.py

Lines changed: 96 additions & 0 deletions
@@ -682,5 +682,101 @@ def test_bytes_from_pytree_empty_dict(self):
     self.assertEqual(max_utils.calculate_bytes_from_pytree({}), 0)
 
 
+class TestLearningRateSchedules(unittest.TestCase):
+  """Test suite for learning rate schedule functions."""
+
+  def test_cosine_schedule(self):
+    """Tests cosine learning rate schedule."""
+    config = pyconfig.initialize(
+        [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
+        enable_checkpointing=False,
+        learning_rate=1e-3,
+        learning_rate_schedule_steps=1000,
+        steps=1200,
+        warmup_steps_fraction=0.1,
+        lr_schedule_type="cosine",
+        cosine_learning_rate_final_fraction=0.1,
+    )
+
+    schedule_fn = maxtext_utils.create_learning_rate_schedule(config)
+    warmup_steps = int(config.learning_rate_schedule_steps * config.warmup_steps_fraction)
+
+    # Warmup phase: 0 -> peak
+    self.assertAlmostEqual(float(schedule_fn(0)), 0.0, places=6)
+    self.assertAlmostEqual(float(schedule_fn(warmup_steps)), config.learning_rate, places=6)
+
+    # Cosine decay phase
+    lr_end = schedule_fn(config.learning_rate_schedule_steps - 1)
+    expected_final = config.learning_rate * config.cosine_learning_rate_final_fraction
+    self.assertLess(float(lr_end), config.learning_rate)
+    self.assertGreater(float(lr_end), expected_final * 0.9)
+
+    # Zero phase
+    self.assertAlmostEqual(float(schedule_fn(config.steps - 1)), 0.0, places=6)
+
+  def test_wsd_schedule(self):
+    """Tests WSD learning rate schedule with both linear and cosine decay styles."""
+    learning_rate = 1e-3
+    learning_rate_schedule_steps = 1000
+    steps = 1200
+    warmup_steps_fraction = 0.1
+    wsd_learning_rate_final_fraction = 0.1
+    wsd_decay_steps_fraction = 0.1
+
+    warmup_steps = int(learning_rate_schedule_steps * warmup_steps_fraction)
+    decay_steps = int(learning_rate_schedule_steps * wsd_decay_steps_fraction)
+    stable_steps = learning_rate_schedule_steps - warmup_steps - decay_steps
+    decay_start = warmup_steps + stable_steps
+
+    # Test both decay styles: linear and cosine
+    for decay_style in ["linear", "cosine"]:
+      config = pyconfig.initialize(
+          [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
+          enable_checkpointing=False,
+          learning_rate=learning_rate,
+          learning_rate_schedule_steps=learning_rate_schedule_steps,
+          steps=steps,
+          warmup_steps_fraction=warmup_steps_fraction,
+          lr_schedule_type="wsd",
+          wsd_learning_rate_final_fraction=wsd_learning_rate_final_fraction,
+          wsd_decay_steps_fraction=wsd_decay_steps_fraction,
+          wsd_decay_style=decay_style,
+      )
+      schedule_fn = maxtext_utils.create_learning_rate_schedule(config)
+
+      # Warmup phase: 0 -> peak
+      self.assertAlmostEqual(float(schedule_fn(0)), 0.0, places=6)
+      self.assertAlmostEqual(float(schedule_fn(warmup_steps)), learning_rate, places=6)
+
+      # Stable phase: constant at peak
+      self.assertAlmostEqual(float(schedule_fn(warmup_steps + 10)), learning_rate, places=6)
+      self.assertAlmostEqual(float(schedule_fn(warmup_steps + stable_steps // 2)), learning_rate, places=6)
+      self.assertAlmostEqual(float(schedule_fn(decay_start - 1)), learning_rate, places=6)
+
+      # Decay phase: peak -> final
+      lr_mid_decay = schedule_fn(decay_start + decay_steps // 2)
+      expected_final = learning_rate * wsd_learning_rate_final_fraction
+      self.assertLess(float(lr_mid_decay), learning_rate)
+      self.assertGreater(float(lr_mid_decay), expected_final)
+
+      # Zero phase
+      self.assertAlmostEqual(float(schedule_fn(steps - 1)), 0.0, places=6)
+
+    # Test invalid fractions
+    config_invalid_fractions = pyconfig.initialize(
+        [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
+        enable_checkpointing=False,
+        learning_rate=learning_rate,
+        learning_rate_schedule_steps=learning_rate_schedule_steps,
+        steps=steps,
+        warmup_steps_fraction=0.6,
+        lr_schedule_type="wsd",
+        wsd_learning_rate_final_fraction=wsd_learning_rate_final_fraction,
+        wsd_decay_steps_fraction=0.5,  # Sum > 1.0
+    )
+    with self.assertRaises(ValueError):
+      maxtext_utils.create_learning_rate_schedule(config_invalid_fractions)
+
+
 if __name__ == "__main__":
   unittest.main()
