@@ -148,6 +148,20 @@ double CosineDecayedLearningRate(const CosineDecayConf& conf, double lr, int64_t
148148 return lr * decayed;
149149}
150150
151+ double CosineAnnealingDecayedLearningRate (const CosineAnnealingDecayConf& conf, double lr,
152+ int64_t cur_batch_num) {
153+ CHECK_GT (conf.t_max (), 0 );
154+ if (0 == cur_batch_num) { return lr; }
155+
156+ const double PI = std::atan (1.0 ) * 4.0 ;
157+ const double eta_min = conf.eta_min ();
158+ CHECK_LT (eta_min, lr);
159+ const double t_max_d = static_cast <double >(conf.t_max ());
160+ const double cur_batch_num_d = static_cast <double >(cur_batch_num);
161+
162+ return eta_min + (((lr - eta_min) * (1 + std::cos (PI * (cur_batch_num_d / t_max_d)))) / 2 );
163+ }
164+
151165double LinearCosineDecayedLearningRate (const LinearCosineDecayConf& conf, double lr,
152166 int64_t cur_batch_num) {
153167 CHECK_GT (conf.decay_batches (), 0 );
@@ -174,6 +188,35 @@ double PiecewiseScalingLearningRate(const PiecewiseScalingConf& conf, double lr,
174188 return scales[i] * lr;
175189}
176190
191+ double StepLearningRate (const StepConf& conf, double lr, int64_t cur_batch_num) {
192+ const int64_t step_size = conf.step_size ();
193+ CHECK_GE (step_size, 1 );
194+ const double gamma = conf.gamma ();
195+
196+ double cur_batch = static_cast <double >(cur_batch_num);
197+ double step = static_cast <double >(step_size);
198+ size_t i = std::floor (cur_batch / step);
199+
200+ return lr * std::pow (gamma, i);
201+ }
202+
203+ double MultiStepLearningRate (const MultiStepConf& conf, double lr, int64_t cur_batch_num) {
204+ const PbRf<int64_t >& milestones = conf.milestones ();
205+ CHECK_GE (milestones.size (), 1 );
206+ const double gamma = conf.gamma ();
207+
208+ size_t i = 0 ;
209+ if (cur_batch_num < milestones[milestones.size () - 1 ]) {
210+ for (; i < milestones.size (); ++i) {
211+ if (cur_batch_num < milestones[i]) { break ; }
212+ }
213+ } else {
214+ i = milestones.size ();
215+ }
216+
217+ return lr * std::pow (gamma, i);
218+ }
219+
177220double GetDecayedLearningRate (const LearningRateDecayConf& conf, double lr, int64_t cur_batch_num) {
178221 if (conf.has_exponential_conf ()) {
179222 return ExponentialDecayedLearningRate (conf.exponential_conf (), lr, cur_batch_num);
@@ -187,10 +230,16 @@ double GetDecayedLearningRate(const LearningRateDecayConf& conf, double lr, int6
187230 return PolynomialDecayedLearningRate (conf.polynomial_conf (), lr, cur_batch_num);
188231 } else if (conf.has_cosine_conf ()) {
189232 return CosineDecayedLearningRate (conf.cosine_conf (), lr, cur_batch_num);
233+ } else if (conf.has_cosine_annealing_conf ()) {
234+ return CosineAnnealingDecayedLearningRate (conf.cosine_annealing_conf (), lr, cur_batch_num);
190235 } else if (conf.has_linear_cosine_conf ()) {
191236 return LinearCosineDecayedLearningRate (conf.linear_cosine_conf (), lr, cur_batch_num);
192237 } else if (conf.has_piecewise_scaling_conf ()) {
193238 return PiecewiseScalingLearningRate (conf.piecewise_scaling_conf (), lr, cur_batch_num);
239+ } else if (conf.has_step_conf ()) {
240+ return StepLearningRate (conf.step_conf (), lr, cur_batch_num);
241+ } else if (conf.has_multi_step_conf ()) {
242+ return MultiStepLearningRate (conf.multi_step_conf (), lr, cur_batch_num);
194243 } else {
195244 UNIMPLEMENTED ();
196245 }
0 commit comments