
Commit 295c238

Add Warmup-Stable-Decay (WSD) learning rate scheduler with configurable stable and decay phases
Signed-off-by: bzantium <[email protected]>
1 parent 08216c6 commit 295c238

4 files changed (+163, -31 lines)

src/MaxText/configs/base.yml

Lines changed: 19 additions & 7 deletions

@@ -607,7 +607,7 @@ grain_file_type: 'arrayrecord' # arrayrecord or parquet
 grain_packing_type: 'first_fit' # 'first_fit' or 'concat_then_split'. See details of the corresponding module in https://google-grain.readthedocs.io/en/latest/grain.experimental.html
 grain_worker_count: 1 # Set to -1 to enable auto-tuning: automatically determines optimal worker count. See https://google-grain.readthedocs.io/en/latest/_autosummary/grain.experimental.pick_performance_config.html
 grain_per_worker_buffer_size: 1
-# num_threads and prefetch_buffer_size are per-worker per-dataset.
+# num_threads and prefetch_buffer_size are per-worker per-dataset.
 # When using array_records, they are used in ReadOptions (https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html#per-worker-readoptions)
 # The default value matches that in the Grain package. If mixing multiple data sources, consider lowering these values to reduce memory usage.
 # When using parquet, grain_num_threads is the number of files to read and interleave in parallel
@@ -635,15 +635,27 @@ skip_jax_distributed_system: False # If True we will not initialize the jax dist
 # However when run on google internal TPUs the coordination service is started automatically
 # and we should set this to True so we won't try to initialize a second time manually.

-# We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
-# Learning rate schedule has either two or three parts:
+# Learning rate schedule structure depends on lr_schedule_type:
+#
+# Cosine schedule (lr_schedule_type='cosine'):
+# Inspired by Llama2's learning rate schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
+# 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
+# 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] until learning_rate_schedule_steps
+# 3) Constant learning rate of 0 from learning_rate_schedule_steps to steps (if steps > learning_rate_schedule_steps)
+#
+# WSD schedule (lr_schedule_type='wsd', Warmup-Stable-Decay):
 # 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
-# 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] from warmup to learning_rate_schedule_steps
-# 3) Constant learning rate of 0 from learning_rate_schedule_steps to steps.
+# 2) Stable phase at [learning_rate] for the majority of training
+# 3) Linear decay from [learning_rate] to [learning_rate * wsd_learning_rate_final_fraction] over [learning_rate_schedule_steps * wsd_decay_steps_fraction] steps
+# 4) Constant learning rate of 0 from learning_rate_schedule_steps to steps (if steps > learning_rate_schedule_steps)
+#
 # The zero learning rate section can be used to more accurately measure the fully trained model's performance.
 learning_rate: 3.e-5
-cosine_learning_rate_final_fraction: 0.1
-warmup_steps_fraction: 0.1
+lr_schedule_type: 'cosine' # Options: 'cosine' or 'wsd'
+cosine_learning_rate_final_fraction: 0.1 # Final LR as fraction of peak LR for cosine schedule
+wsd_learning_rate_final_fraction: 0.1 # Final LR as fraction of peak LR for WSD schedule
+wsd_decay_steps_fraction: 0.1 # Fraction of learning_rate_schedule_steps used for decay phase in WSD (e.g., 0.1 = 10%)
+warmup_steps_fraction: 0.1 # Fraction of learning_rate_schedule_steps used for warmup phase (applies to both schedules)
 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
 # However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
 # dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.
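As an editorial aside (not part of the commit), here is a minimal sketch of the phase arithmetic implied by the fractions above, using an illustrative 10,000-step budget rather than any MaxText default:

  # Hypothetical values: learning_rate_schedule_steps=10000, warmup_steps_fraction=0.1,
  # wsd_decay_steps_fraction=0.1, mirroring the formulas documented in base.yml.
  learning_rate_schedule_steps = 10_000
  warmup_steps_fraction = 0.1
  wsd_decay_steps_fraction = 0.1

  warmup_steps = int(learning_rate_schedule_steps * warmup_steps_fraction)    # 1000
  decay_steps = int(learning_rate_schedule_steps * wsd_decay_steps_fraction)  # 1000
  stable_steps = learning_rate_schedule_steps - warmup_steps - decay_steps    # 8000
  print(warmup_steps, stable_steps, decay_steps)  # 1000 8000 1000

With the default fractions, warmup and decay each take 10% of the schedule and the stable phase covers the remaining 80%.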

src/MaxText/configs/types.py

Lines changed: 16 additions & 0 deletions

@@ -124,6 +124,13 @@ class OptimizerType(str, Enum):
   MUON = "muon"


+class LearningRateScheduleType(str, Enum):
+  """Supported learning rate schedule types."""
+
+  COSINE = "cosine"
+  WSD = "wsd"
+
+
 class RopeType(str, Enum):
   """Supported Rotary Positional Embedding (RoPE) implementations."""

@@ -1005,9 +1012,18 @@ class Optimizer(BaseModel):
       1.0, description="The threshold for gradient clipping. 0 disables clipping."
   )
   learning_rate: NonNegativeFloat = Field(3.0e-5, description="The peak learning rate.")
+  lr_schedule_type: LearningRateScheduleType = Field(
+      LearningRateScheduleType.COSINE, description="The type of learning rate schedule to use."
+  )
   cosine_learning_rate_final_fraction: float = Field(
       0.1, description="Final LR as a fraction of peak LR in cosine decay."
   )
+  wsd_learning_rate_final_fraction: float = Field(
+      0.1, description="Final LR as a fraction of peak LR in WSD decay phase."
+  )
+  wsd_decay_steps_fraction: float = Field(
+      0.1, ge=0.0, le=1.0, description="Fraction of total steps for decay phase in WSD schedule."
+  )
   warmup_steps_fraction: float = Field(0.1, ge=0.0, le=1.0, description="Fraction of total steps for LR warmup.")
   learning_rate_schedule_steps: int = Field(
       -1,
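A brief note, added for this write-up rather than taken from the commit: because LearningRateScheduleType subclasses str, its members compare equal to their plain string values, which is why the string comparisons in create_learning_rate_schedule keep working whether the config field carries the enum member or a raw string. A minimal sketch:

  from enum import Enum

  class LearningRateScheduleType(str, Enum):
    """Same shape as the enum added in types.py."""
    COSINE = "cosine"
    WSD = "wsd"

  assert LearningRateScheduleType.WSD == "wsd"                                  # str comparison holds
  assert LearningRateScheduleType("cosine") is LearningRateScheduleType.COSINE  # lookup by value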

src/MaxText/maxtext_utils.py

Lines changed: 56 additions & 24 deletions

@@ -1103,44 +1103,76 @@ def create_device_mesh(config, devices=None):


 def create_learning_rate_schedule(config):
-  """Creates a warmup and cosine decay learning rate schedule:
-  We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
-  Learning rate schedule has either two or three parts:
+  """Creates a learning rate schedule with warmup and decay.
+
+  Supports two schedule types:
+  - Cosine: Inspired by Llama2's learning rate schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
+  - WSD (Warmup-Stable-Decay): Maintains constant learning rate for most of training before final decay
+
+  Schedule structure:
   1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
-  2) Cosine from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] until learning_rate_schedule_steps
+  2) Decay from [learning_rate] to a final value until learning_rate_schedule_steps
+     - Cosine: decays to [learning_rate * cosine_learning_rate_final_fraction]
+     - WSD: maintains [learning_rate] for a stable phase, then linearly decays to [learning_rate * wsd_learning_rate_final_fraction]
   3) Constant learning rate of 0 from learning_rate_schedule_steps to steps.
   The zero learning rate section can be used to more accurately measure the fully trained model's performance.
   """

-  def make_cos_schedule(init_lr, final_lr, len_steps):
-    def schedule(step):
-      pct = (step) / len_steps
-      a = 0.5 * (jnp.cos(jnp.pi * pct) + 1)
-      lr = init_lr * a + final_lr * (1 - a)
-      return lr
-
-    return schedule
-
   lr = config.learning_rate
-  cos_final_lr = lr * config.cosine_learning_rate_final_fraction
-
   warmup_steps = int(config.learning_rate_schedule_steps * config.warmup_steps_fraction)
-  cos_steps = config.learning_rate_schedule_steps - warmup_steps
   constant_zero_steps = config.steps - config.learning_rate_schedule_steps

   warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps)
-  cos_schedule = make_cos_schedule(lr, cos_final_lr, cos_steps)
-  constant_schedule = optax.constant_schedule(0.0)

-  pieces = [warmup_schedule, cos_schedule]
-  boundaries = [
-      warmup_steps,
-      warmup_steps + cos_steps,
-  ]
+  if config.lr_schedule_type == "cosine":
+
+    def make_cos_schedule(init_lr, final_lr, len_steps):
+      def schedule(step):
+        pct = (step) / len_steps
+        a = 0.5 * (jnp.cos(jnp.pi * pct) + 1)
+        lr = init_lr * a + final_lr * (1 - a)
+        return lr
+
+      return schedule
+
+    cos_final_lr = lr * config.cosine_learning_rate_final_fraction
+    cos_steps = config.learning_rate_schedule_steps - warmup_steps
+    cos_schedule = make_cos_schedule(lr, cos_final_lr, cos_steps)
+
+    pieces = [warmup_schedule, cos_schedule]
+    boundaries = [warmup_steps, warmup_steps + cos_steps]
+
+  elif config.lr_schedule_type == "wsd":
+    wsd_final_lr = lr * config.wsd_learning_rate_final_fraction
+    decay_steps = int(config.learning_rate_schedule_steps * config.wsd_decay_steps_fraction)
+    stable_steps = config.learning_rate_schedule_steps - warmup_steps - decay_steps
+
+    if stable_steps < 0:
+      raise ValueError(
+          f"Invalid WSD schedule: warmup_steps_fraction ({config.warmup_steps_fraction}) + "
+          f"wsd_decay_steps_fraction ({config.wsd_decay_steps_fraction}) must not exceed 1.0. "
+          f"Current sum: {config.warmup_steps_fraction + config.wsd_decay_steps_fraction}"
+      )
+
+    stable_schedule = optax.constant_schedule(lr)
+    decay_schedule = optax.linear_schedule(init_value=lr, end_value=wsd_final_lr, transition_steps=decay_steps)
+
+    pieces = [warmup_schedule, stable_schedule, decay_schedule]
+    boundaries = [warmup_steps, warmup_steps + stable_steps]
+
+    max_logging.log(
+        f"WSD Learning Rate Schedule: warmup_steps={warmup_steps}, "
+        f"stable_steps={stable_steps}, decay_steps={decay_steps}, "
+        f"peak_lr={lr:.2e}, final_lr={wsd_final_lr:.2e}"
+    )
+
+  else:
+    raise ValueError(f"Invalid lr_schedule_type: {config.lr_schedule_type}. " "Must be either 'cosine' or 'wsd'.")

   if constant_zero_steps > 0:
+    constant_schedule = optax.constant_schedule(0.0)
     pieces.append(constant_schedule)
-    boundaries.append(warmup_steps + cos_steps + constant_zero_steps)
+    boundaries.append(boundaries[-1] + constant_zero_steps)

   return optax.join_schedules(pieces, boundaries)
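For readers unfamiliar with optax's piecewise schedules, the following standalone sketch (added for this write-up, not part of the commit) rebuilds the WSD schedule with the same public optax calls the patch uses; the hyperparameter values are illustrative, not MaxText defaults:

  import optax

  peak_lr = 3e-4
  schedule_steps = 1_000                    # stands in for learning_rate_schedule_steps
  warmup_steps = int(schedule_steps * 0.1)  # warmup_steps_fraction = 0.1
  decay_steps = int(schedule_steps * 0.1)   # wsd_decay_steps_fraction = 0.1
  stable_steps = schedule_steps - warmup_steps - decay_steps
  final_lr = peak_lr * 0.1                  # wsd_learning_rate_final_fraction = 0.1

  warmup = optax.linear_schedule(init_value=0.0, end_value=peak_lr, transition_steps=warmup_steps)
  stable = optax.constant_schedule(peak_lr)
  decay = optax.linear_schedule(init_value=peak_lr, end_value=final_lr, transition_steps=decay_steps)

  # join_schedules switches pieces at each boundary and re-bases the step count,
  # so every piece sees a counter that starts at 0 when it becomes active.
  schedule = optax.join_schedules(
      schedules=[warmup, stable, decay],
      boundaries=[warmup_steps, warmup_steps + stable_steps],
  )

  print(float(schedule(0)))               # ~0.0, start of warmup
  print(float(schedule(warmup_steps)))    # peak_lr, start of the stable phase
  print(float(schedule(schedule_steps)))  # final_lr, end of the linear decay

The patch appends one more constant-zero piece after these when config.steps exceeds learning_rate_schedule_steps.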

tests/maxtext_utils_test.py

Lines changed: 72 additions & 0 deletions

@@ -682,5 +682,77 @@ def test_bytes_from_pytree_empty_dict(self):
     self.assertEqual(max_utils.calculate_bytes_from_pytree({}), 0)


+class TestLearningRateSchedules(unittest.TestCase):
+  """Test suite for learning rate schedule functions."""
+
+  def setUp(self):
+    """Set up common configuration for scheduler tests."""
+    self.config = pyconfig.initialize(
+        [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], enable_checkpointing=False
+    )
+    self.config.learning_rate = 1e-3
+    self.config.learning_rate_schedule_steps = 1000
+    self.config.steps = 1200
+    self.config.warmup_steps_fraction = 0.1
+
+  def test_cosine_schedule(self):
+    """Tests cosine learning rate schedule."""
+    self.config.lr_schedule_type = "cosine"
+    self.config.cosine_learning_rate_final_fraction = 0.1
+
+    schedule_fn = maxtext_utils.create_learning_rate_schedule(self.config)
+    warmup_steps = int(self.config.learning_rate_schedule_steps * self.config.warmup_steps_fraction)
+
+    # Warmup phase: 0 -> peak
+    self.assertAlmostEqual(float(schedule_fn(0)), 0.0, places=6)
+    self.assertAlmostEqual(float(schedule_fn(warmup_steps)), self.config.learning_rate, places=6)
+
+    # Cosine decay phase
+    lr_end = schedule_fn(self.config.learning_rate_schedule_steps - 1)
+    expected_final = self.config.learning_rate * self.config.cosine_learning_rate_final_fraction
+    self.assertLess(float(lr_end), self.config.learning_rate)
+    self.assertGreater(float(lr_end), expected_final * 0.9)
+
+    # Zero phase
+    self.assertAlmostEqual(float(schedule_fn(self.config.steps - 1)), 0.0, places=6)
+
+  def test_wsd_schedule(self):
+    """Tests WSD learning rate schedule."""
+    self.config.lr_schedule_type = "wsd"
+    self.config.wsd_learning_rate_final_fraction = 0.1
+    self.config.wsd_decay_steps_fraction = 0.1
+
+    schedule_fn = maxtext_utils.create_learning_rate_schedule(self.config)
+
+    warmup_steps = int(self.config.learning_rate_schedule_steps * self.config.warmup_steps_fraction)
+    decay_steps = int(self.config.learning_rate_schedule_steps * self.config.wsd_decay_steps_fraction)
+    stable_steps = self.config.learning_rate_schedule_steps - warmup_steps - decay_steps
+    decay_start = warmup_steps + stable_steps
+
+    # Warmup phase: 0 -> peak
+    self.assertAlmostEqual(float(schedule_fn(0)), 0.0, places=6)
+    self.assertAlmostEqual(float(schedule_fn(warmup_steps)), self.config.learning_rate, places=6)
+
+    # Stable phase: constant at peak
+    self.assertAlmostEqual(float(schedule_fn(warmup_steps + 10)), self.config.learning_rate, places=6)
+    self.assertAlmostEqual(float(schedule_fn(warmup_steps + stable_steps // 2)), self.config.learning_rate, places=6)
+    self.assertAlmostEqual(float(schedule_fn(decay_start - 1)), self.config.learning_rate, places=6)
+
+    # Decay phase: peak -> final
+    lr_mid_decay = schedule_fn(decay_start + decay_steps // 2)
+    expected_final = self.config.learning_rate * self.config.wsd_learning_rate_final_fraction
+    self.assertLess(float(lr_mid_decay), self.config.learning_rate)
+    self.assertGreater(float(lr_mid_decay), expected_final)
+
+    # Zero phase
+    self.assertAlmostEqual(float(schedule_fn(self.config.steps - 1)), 0.0, places=6)
+
+    # Test invalid fractions
+    self.config.warmup_steps_fraction = 0.6
+    self.config.wsd_decay_steps_fraction = 0.5  # Sum > 1.0
+    with self.assertRaises(ValueError):
+      maxtext_utils.create_learning_rate_schedule(self.config)
+
+
 if __name__ == "__main__":
   unittest.main()
