Skip to content

Commit 76e78fd

Browse files
graph support eager lrs (#6262)
* add multistep lr, refine * add steplr and cosine annealing lr for graph Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
1 parent 74775c1 commit 76e78fd

20 files changed

+287
-12
lines changed

oneflow/core/job/learning_rate_schedule_conf.proto

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ message CosineDecayConf {
3636
optional double alpha = 2 [default = 0.0];
3737
}
3838

39+
// Cosine-annealing LR schedule: the learning rate follows half a cosine
// period from the base LR down toward eta_min over t_max steps
// (see CosineAnnealingDecayedLearningRate in learning_rate_schedule_kernel.cpp).
message CosineAnnealingDecayConf {
  // Number of steps of one annealing period; must be > 0.
  required int64 t_max = 1;
  // Lower bound the learning rate decays toward.
  optional double eta_min = 2 [default = 0.0];
}
43+
3944
message LinearCosineDecayConf {
4045
required int64 decay_batches = 1;
4146
optional double num_periods = 2 [default = 0.5];
@@ -48,6 +53,16 @@ message PiecewiseScalingConf {
4853
repeated double scales = 2;
4954
}
5055

56+
// Step LR schedule: multiply the base LR by gamma once every step_size
// steps, i.e. lr * gamma^(cur_batch_num / step_size)
// (see StepLearningRate in learning_rate_schedule_kernel.cpp).
message StepConf {
  // Decay interval in steps; must be >= 1.
  required int64 step_size = 1;
  // Multiplicative decay factor applied at each interval.
  optional double gamma = 2 [default = 0.1];
}
60+
61+
// Multi-step LR schedule: multiply the base LR by gamma once for each
// milestone already reached
// (see MultiStepLearningRate in learning_rate_schedule_kernel.cpp).
message MultiStepConf {
  // Step indices at which the LR decays; presumably sorted ascending — the
  // kernel's early-exit scan relies on it (TODO confirm at producer side).
  repeated int64 milestones = 1;
  // Multiplicative decay factor applied at each milestone.
  optional double gamma = 2 [default = 0.1];
}
65+
5166
message LearningRateDecayConf {
5267
oneof type {
5368
ExponentialDecayConf exponential_conf = 2000;
@@ -58,6 +73,9 @@ message LearningRateDecayConf {
5873
CosineDecayConf cosine_conf = 2005;
5974
LinearCosineDecayConf linear_cosine_conf = 2006;
6075
PiecewiseScalingConf piecewise_scaling_conf = 2007;
76+
MultiStepConf multi_step_conf = 2008;
77+
StepConf step_conf = 2009;
78+
CosineAnnealingDecayConf cosine_annealing_conf = 2010;
6179
}
6280
}
6381

oneflow/core/kernel/learning_rate_schedule_kernel.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,20 @@ double CosineDecayedLearningRate(const CosineDecayConf& conf, double lr, int64_t
148148
return lr * decayed;
149149
}
150150

151+
double CosineAnnealingDecayedLearningRate(const CosineAnnealingDecayConf& conf, double lr,
152+
int64_t cur_batch_num) {
153+
CHECK_GT(conf.t_max(), 0);
154+
if (0 == cur_batch_num) { return lr; }
155+
156+
const double PI = std::atan(1.0) * 4.0;
157+
const double eta_min = conf.eta_min();
158+
CHECK_LT(eta_min, lr);
159+
const double t_max_d = static_cast<double>(conf.t_max());
160+
const double cur_batch_num_d = static_cast<double>(cur_batch_num);
161+
162+
return eta_min + (((lr - eta_min) * (1 + std::cos(PI * (cur_batch_num_d / t_max_d)))) / 2);
163+
}
164+
151165
double LinearCosineDecayedLearningRate(const LinearCosineDecayConf& conf, double lr,
152166
int64_t cur_batch_num) {
153167
CHECK_GT(conf.decay_batches(), 0);
@@ -174,6 +188,35 @@ double PiecewiseScalingLearningRate(const PiecewiseScalingConf& conf, double lr,
174188
return scales[i] * lr;
175189
}
176190

191+
double StepLearningRate(const StepConf& conf, double lr, int64_t cur_batch_num) {
192+
const int64_t step_size = conf.step_size();
193+
CHECK_GE(step_size, 1);
194+
const double gamma = conf.gamma();
195+
196+
double cur_batch = static_cast<double>(cur_batch_num);
197+
double step = static_cast<double>(step_size);
198+
size_t i = std::floor(cur_batch / step);
199+
200+
return lr * std::pow(gamma, i);
201+
}
202+
203+
double MultiStepLearningRate(const MultiStepConf& conf, double lr, int64_t cur_batch_num) {
204+
const PbRf<int64_t>& milestones = conf.milestones();
205+
CHECK_GE(milestones.size(), 1);
206+
const double gamma = conf.gamma();
207+
208+
size_t i = 0;
209+
if (cur_batch_num < milestones[milestones.size() - 1]) {
210+
for (; i < milestones.size(); ++i) {
211+
if (cur_batch_num < milestones[i]) { break; }
212+
}
213+
} else {
214+
i = milestones.size();
215+
}
216+
217+
return lr * std::pow(gamma, i);
218+
}
219+
177220
double GetDecayedLearningRate(const LearningRateDecayConf& conf, double lr, int64_t cur_batch_num) {
178221
if (conf.has_exponential_conf()) {
179222
return ExponentialDecayedLearningRate(conf.exponential_conf(), lr, cur_batch_num);
@@ -187,10 +230,16 @@ double GetDecayedLearningRate(const LearningRateDecayConf& conf, double lr, int6
187230
return PolynomialDecayedLearningRate(conf.polynomial_conf(), lr, cur_batch_num);
188231
} else if (conf.has_cosine_conf()) {
189232
return CosineDecayedLearningRate(conf.cosine_conf(), lr, cur_batch_num);
233+
} else if (conf.has_cosine_annealing_conf()) {
234+
return CosineAnnealingDecayedLearningRate(conf.cosine_annealing_conf(), lr, cur_batch_num);
190235
} else if (conf.has_linear_cosine_conf()) {
191236
return LinearCosineDecayedLearningRate(conf.linear_cosine_conf(), lr, cur_batch_num);
192237
} else if (conf.has_piecewise_scaling_conf()) {
193238
return PiecewiseScalingLearningRate(conf.piecewise_scaling_conf(), lr, cur_batch_num);
239+
} else if (conf.has_step_conf()) {
240+
return StepLearningRate(conf.step_conf(), lr, cur_batch_num);
241+
} else if (conf.has_multi_step_conf()) {
242+
return MultiStepLearningRate(conf.multi_step_conf(), lr, cur_batch_num);
194243
} else {
195244
UNIMPLEMENTED();
196245
}

python/oneflow/amp/grad_scaler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(
3333
)
3434
self._growth_interval = growth_interval
3535

36-
def generate_conf_for_graph(self, train_conf):
36+
def _generate_conf_for_graph(self, train_conf):
3737
train_conf.mutable_dynamic_loss_scale_policy().set_initial_loss_scale(
3838
self._init_scale
3939
)
@@ -52,5 +52,5 @@ def __init__(self, scale_factor):
5252

5353
self._scale_factor = scale_factor
5454

55-
def generate_conf_for_graph(self, train_conf):
55+
def _generate_conf_for_graph(self, train_conf):
5656
train_conf.set_loss_scale_factor(self._scale_factor)

python/oneflow/nn/graph/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def _generate_config_proto(self):
386386
self.config.proto.set_job_name(self._name)
387387

388388
if self._grad_scaler is not None:
389-
self._grad_scaler.generate_conf_for_graph(
389+
self._grad_scaler._generate_conf_for_graph(
390390
self.config.proto.mutable_train_conf()
391391
)
392392

python/oneflow/nn/graph/optimizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ def __init__(
3636

3737
def generate_optimizer_and_variable_configs(self, train_conf, vars_conf):
3838
if self._optimizer is not None:
39-
opt_confs = self._optimizer.generate_conf_for_graph(train_conf, vars_conf)
39+
opt_confs = self._optimizer._generate_conf_for_graph(train_conf, vars_conf)
4040
if self._lr_scheduler is not None:
41-
self._lr_scheduler.generate_conf_for_graph(opt_confs)
41+
self._lr_scheduler._generate_conf_for_graph(opt_confs)
4242

4343

4444
class VariableConfig(object):

python/oneflow/nn/optimizer/adam.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def step(self, closure: Callable = None):
207207

208208
return loss
209209

210-
def generate_conf_for_graph(self, train_conf, vars_conf):
210+
def _generate_conf_for_graph(self, train_conf, vars_conf):
211211
new_opt_confs = []
212212
for param_group in self.param_groups:
213213
optimizer_conf = train_conf.mutable_optimizer_conf().Add()

python/oneflow/nn/optimizer/adamw.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def step(self, closure: Callable = None):
209209
self._state["step"] += 1
210210
return loss
211211

212-
def generate_conf_for_graph(self, train_conf, vars_conf):
212+
def _generate_conf_for_graph(self, train_conf, vars_conf):
213213
new_opt_confs = []
214214
for param_group in self.param_groups:
215215
optimizer_conf = train_conf.mutable_optimizer_conf().Add()

python/oneflow/nn/optimizer/cosine_annealing_lr.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,13 @@ def get_lr(self):
8484
+ self.eta_min
8585
for group in self._optimizer.param_groups
8686
]
87+
88+
def _generate_conf_for_graph(self, opt_confs):
89+
for opt_conf in opt_confs:
90+
learning_rate_decay_conf = opt_conf.mutable_learning_rate_decay()
91+
learning_rate_decay_conf.mutable_cosine_annealing_conf().set_t_max(
92+
self.T_max
93+
)
94+
learning_rate_decay_conf.mutable_cosine_annealing_conf().set_eta_min(
95+
self.eta_min
96+
)

python/oneflow/nn/optimizer/cosine_decay_lr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def get_lr(self):
8989
else:
9090
return [base_lr * self.alpha for base_lr in self.base_lrs]
9191

92-
def generate_conf_for_graph(self, opt_confs):
92+
def _generate_conf_for_graph(self, opt_confs):
9393
# CosineDecayLR is the same as CosineDecayConf in nn.Graph
9494
for opt_conf in opt_confs:
9595
learning_rate_decay_conf = opt_conf.mutable_learning_rate_decay()

python/oneflow/nn/optimizer/multistep_lr.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,12 @@ def get_lr(self):
6565
return [group["lr"] for group in self._optimizer.param_groups]
6666
else:
6767
return [group["lr"] * self.gamma for group in self._optimizer.param_groups]
68+
69+
def _generate_conf_for_graph(self, opt_confs):
70+
for opt_conf in opt_confs:
71+
learning_rate_decay_conf = opt_conf.mutable_learning_rate_decay()
72+
for milestone in self.milestones:
73+
learning_rate_decay_conf.mutable_multi_step_conf().add_milestones(
74+
milestone
75+
)
76+
learning_rate_decay_conf.mutable_multi_step_conf().set_gamma(self.gamma)

0 commit comments

Comments
 (0)