src/MaxText/configs/base.yml (19 additions, 7 deletions)
@@ -607,7 +607,7 @@ grain_file_type: 'arrayrecord' # arrayrecord or parquet
 grain_packing_type: 'first_fit' # 'first_fit' or 'concat_then_split'. See details of the corresponding module in https://google-grain.readthedocs.io/en/latest/grain.experimental.html
 grain_worker_count: 1 # Set to -1 to enable auto-tuning: automatically determines optimal worker count. See https://google-grain.readthedocs.io/en/latest/_autosummary/grain.experimental.pick_performance_config.html
 grain_per_worker_buffer_size: 1
-# num_threads and prefetch_buffer_size are per-worker per-dataset.
+# num_threads and prefetch_buffer_size are per-worker per-dataset.
 # When using array_records, they are used in ReadOptions (https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html#per-worker-readoptions)
 # The default value matches that in the Grain package. If mixing multiple data sources, consider lowering these values to reduce memory usage.
 # When using parquet, grain_num_threads is the number of files to read and interleave in parallel
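As context for the knobs in this hunk, here is a minimal sketch (not MaxText's data pipeline) of how grain_worker_count, grain_per_worker_buffer_size, and the per-worker num_threads/prefetch_buffer_size settings roughly map onto Grain's ReadOptions and DataLoader, following the linked Grain docs. The file path, sampler settings, and concrete values are illustrative assumptions.

# Illustrative sketch only -- not MaxText code. Shows where the grain_*
# config values above would land in a plain Grain DataLoader.
import grain.python as grain

# Per-worker, per-dataset read options (used for array_record sources).
read_options = grain.ReadOptions(
    num_threads=16,            # analogous to grain_num_threads
    prefetch_buffer_size=500,  # analogous to grain_prefetch_buffer_size
)

# Hypothetical source and sampler, purely to make the example self-contained.
source = grain.ArrayRecordDataSource(["/tmp/example.array_record"])
sampler = grain.IndexSampler(
    num_records=len(source),
    shard_options=grain.NoSharding(),
    shuffle=True,
    num_epochs=1,
    seed=0,
)

loader = grain.DataLoader(
    data_source=source,
    sampler=sampler,
    worker_count=1,          # grain_worker_count (see the -1 auto-tuning note above)
    worker_buffer_size=1,    # grain_per_worker_buffer_size
    read_options=read_options,
)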
@@ -635,15 +635,27 @@ skip_jax_distributed_system: False # If True we will not initialize the jax dist
 # However when run on google internal TPUs the coordination service is started automatically
 # and we should set this to True so we won't try to initialize a second time manually.
 
-# We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
-# Learning rate schedule has either two or three parts:
+# Learning rate schedule structure depends on lr_schedule_type:
+#
+# Cosine schedule (lr_schedule_type='cosine'):
+# Inspired by Llama2's learning rate schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
+# 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
+# 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] until learning_rate_schedule_steps
+# 3) Constant learning rate of 0 from learning_rate_schedule_steps to steps (if steps > learning_rate_schedule_steps)
+#
+# WSD schedule (lr_schedule_type='wsd'):
 # 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
-# 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] from warmup to learning_rate_schedule_steps
-# 3) Constant learning rate of 0 from learning_rate_schedule_steps to steps.
+# 2) Stable phase at [learning_rate] for the majority of training
+# 3) Linear decay from [learning_rate] to [learning_rate * wsd_learning_rate_final_fraction] over [learning_rate_schedule_steps * wsd_decay_steps_fraction] steps
+# 4) Constant learning rate of 0 from learning_rate_schedule_steps to steps (if steps > learning_rate_schedule_steps)
+#
 # The zero learning rate section can be used to more accurately measure the fully trained model's performance.
 learning_rate: 3.e-5
-cosine_learning_rate_final_fraction: 0.1
-warmup_steps_fraction: 0.1
+lr_schedule_type: 'cosine' # Options: 'cosine' or 'wsd'
+cosine_learning_rate_final_fraction: 0.1 # Final LR as fraction of peak LR for cosine schedule
+wsd_learning_rate_final_fraction: 0.1 # Final LR as fraction of peak LR for WSD schedule
+wsd_decay_steps_fraction: 0.1 # Fraction of learning_rate_schedule_steps used for decay phase in WSD (e.g., 0.1 = 10%)
+warmup_steps_fraction: 0.1 # Fraction of learning_rate_schedule_steps used for warmup phase (applies to both schedules)
 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
 # However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
 # dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.
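To make the two schedule shapes above concrete, here is a minimal sketch of how they could be assembled from these config values using optax's schedule combinators. This is an illustrative helper written for this note, not MaxText's actual schedule code; the function name and default arguments are assumptions.

# Minimal sketch (not MaxText's implementation) of the cosine and WSD
# schedules described in the config comments, built with optax.
import optax

def make_schedule(
    lr_schedule_type="cosine",
    learning_rate=3e-5,
    learning_rate_schedule_steps=10_000,
    warmup_steps_fraction=0.1,
    cosine_learning_rate_final_fraction=0.1,
    wsd_learning_rate_final_fraction=0.1,
    wsd_decay_steps_fraction=0.1,
):
  warmup_steps = int(learning_rate_schedule_steps * warmup_steps_fraction)
  warmup = optax.linear_schedule(0.0, learning_rate, warmup_steps)

  if lr_schedule_type == "cosine":
    # Cosine decay from the peak LR down to peak * cosine_learning_rate_final_fraction.
    decay = optax.cosine_decay_schedule(
        init_value=learning_rate,
        decay_steps=learning_rate_schedule_steps - warmup_steps,
        alpha=cosine_learning_rate_final_fraction,
    )
    pieces = [warmup, decay]
    boundaries = [warmup_steps]
  else:  # 'wsd': warmup -> stable -> linear decay
    decay_steps = int(learning_rate_schedule_steps * wsd_decay_steps_fraction)
    stable_steps = learning_rate_schedule_steps - warmup_steps - decay_steps
    stable = optax.constant_schedule(learning_rate)
    decay = optax.linear_schedule(
        learning_rate, learning_rate * wsd_learning_rate_final_fraction, decay_steps)
    pieces = [warmup, stable, decay]
    boundaries = [warmup_steps, warmup_steps + stable_steps]

  # Steps beyond learning_rate_schedule_steps get a constant LR of 0.
  pieces.append(optax.constant_schedule(0.0))
  boundaries.append(learning_rate_schedule_steps)
  return optax.join_schedules(pieces, boundaries)

Either variant can then be handed to an optax optimizer, e.g. optax.adamw(learning_rate=make_schedule(lr_schedule_type="wsd")).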