@@ -160,9 +160,11 @@ def load_checkpoint(
160160 raise FileNotFoundError (f"{ path } .scaler not exist." )
161161
162162 # load state dict
163- param_dict = paddle .load (f"{ path } .pdparams" )
163+ model_dict = paddle .load (f"{ path } .pdparams" )
164164 optim_dict = paddle .load (f"{ path } .pdopt" )
165- metric_dict = paddle .load (f"{ path } .pdstates" )
165+ metric_dict = {}
166+ if os .path .exists (f"{ path } .pdstates" ):
167+ metric_dict = paddle .load (f"{ path } .pdstates" )
166168 if grad_scaler is not None :
167169 scaler_dict = paddle .load (f"{ path } .pdscaler" )
168170 if equation is not None :
@@ -172,9 +174,9 @@ def load_checkpoint(
172174 else :
173175 equation_dict = paddle .load (f"{ path } .pdeqn" )
174176
175- # set state dict
177+ # set model state dict
176178 logger .message (f"* Loading model checkpoint from { path } .pdparams" )
177- missing_keys , unexpected_keys = model .set_state_dict (param_dict )
179+ missing_keys , unexpected_keys = model .set_state_dict (model_dict )
178180 if missing_keys :
179181 logger .warning (
180182 f"There are missing keys when loading checkpoint: { missing_keys } , "
@@ -186,20 +188,23 @@ def load_checkpoint(
186188 "and corresponding weights will be ignored."
187189 )
188190
191+ # set optimizer state dict
189192 logger .message (f"* Loading optimizer checkpoint from { path } .pdopt" )
190193 optimizer .set_state_dict (optim_dict )
194+
191195 if grad_scaler is not None :
192196 logger .message (f"* Loading grad scaler checkpoint from { path } .pdscaler" )
193197 grad_scaler .load_state_dict (scaler_dict )
198+
194199 if equation is not None and equation_dict is not None :
195200 logger .message (f"* Loading equation checkpoint from { path } .pdeqn" )
196201 for name , _equation in equation .items ():
197202 _equation .set_state_dict (equation_dict [name ])
198203
199- if ema_model :
204+ if ema_model is not None :
200205 logger .message (f"* Loading EMA checkpoint from { path } _ema.pdparams" )
201- avg_param_dict = paddle .load (f"{ path } _ema.pdparams" )
202- ema_model .set_state_dict (avg_param_dict )
206+ avg_model_dict = paddle .load (f"{ path } _ema.pdparams" )
207+ ema_model .set_state_dict (avg_model_dict )
203208
204209 if aggregator is not None and aggregator .should_persist :
205210 logger .message (f"* Loading loss aggregator checkpoint from { path } .pdagg" )
@@ -213,7 +218,7 @@ def load_checkpoint(
213218def save_checkpoint (
214219 model : nn .Layer ,
215220 optimizer : Optional [optimizer .Optimizer ],
216- metric : Dict [str , float ],
221+ metric : Optional [ Dict [str , float ]] = None ,
217222 grad_scaler : Optional [amp .GradScaler ] = None ,
218223 output_dir : Optional [str ] = None ,
219224 prefix : str = "model" ,
@@ -228,7 +233,7 @@ def save_checkpoint(
228233 Args:
229234 model (nn.Layer): Model with parameters.
230235 optimizer (Optional[optimizer.Optimizer]): Optimizer for model.
231- metric (Dict[str, float]): Metric information, such as {"RMSE": 0.1, "MAE": 0.2}.
236+ metric (Optional[ Dict[str, float]] ): Metric information, such as {"RMSE": 0.1, "MAE": 0.2}. Defaults to None .
232237 grad_scaler (Optional[amp.GradScaler]): GradScaler for AMP. Defaults to None.
233238 output_dir (Optional[str]): Directory for checkpoint storage.
234239 prefix (str, optional): Prefix for storage. Defaults to "model".
@@ -259,11 +264,16 @@ def save_checkpoint(
259264 os .makedirs (ckpt_dir , exist_ok = True )
260265
261266 paddle .save (model .state_dict (), f"{ ckpt_path } .pdparams" )
262- if optimizer :
267+
268+ if optimizer is not None :
263269 paddle .save (optimizer .state_dict (), f"{ ckpt_path } .pdopt" )
264- paddle .save (metric , f"{ ckpt_path } .pdstates" )
270+
271+ if metric is not None and len (metric ) > 0 :
272+ paddle .save (metric , f"{ ckpt_path } .pdstates" )
273+
265274 if grad_scaler is not None :
266275 paddle .save (grad_scaler .state_dict (), f"{ ckpt_path } .pdscaler" )
276+
267277 if equation is not None :
268278 num_learnable_params = sum (
269279 [len (eq .learnable_parameters ) for eq in equation .values ()]
@@ -274,10 +284,10 @@ def save_checkpoint(
274284 f"{ ckpt_path } .pdeqn" ,
275285 )
276286
277- if ema_model :
287+ if ema_model is not None :
278288 paddle .save (ema_model .state_dict (), f"{ ckpt_path } _ema.pdparams" )
279289
280- if aggregator and aggregator .should_persist :
290+ if aggregator is not None and aggregator .should_persist :
281291 paddle .save (aggregator .state_dict (), f"{ ckpt_path } .pdagg" )
282292
283293 if print_log :
0 commit comments