in optimum/habana/transformers/trainer.py [0:0]
def _save_optimizer_and_scheduler(self, output_dir):
    if self.is_deepspeed_enabled:
        # Under ZeRO-3 the model file itself doesn't get saved since it's bogus, unless the DeepSpeed
        # config option `stage3_gather_16bit_weights_on_model_save` is True.
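        # Probe the signature because not every DeepSpeed release accepts the
        # `exclude_frozen_parameters` keyword on `save_checkpoint`.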
        accept_exclude_frozen_parameters = "exclude_frozen_parameters" in set(
            inspect.signature(self.model_wrapped.save_checkpoint).parameters.keys()
        )
        if accept_exclude_frozen_parameters and _is_peft_model(self.model):
            self.model_wrapped.save_checkpoint(output_dir, exclude_frozen_parameters=True)
        else:
            self.model_wrapped.save_checkpoint(output_dir)
    elif self.is_fsdp_enabled:
        if isinstance(self.model, torch._dynamo.eval_frame.OptimizedModule):
            # TODO: for some reason the FSDP model is not unwrapped correctly here; self.model
            # shouldn't be an OptimizedModule at this point.
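            # `torch.compile` wraps the model in an OptimizedModule; `_orig_mod` exposes the underlying module.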
            model = self.model._orig_mod
        else:
            model = self.model
        # Save an FSDP-specific checkpoint so training can resume from it
        save_fsdp_model(
            self.accelerator.state.fsdp_plugin, self.accelerator, model, output_dir, **_get_fsdp_ckpt_kwargs()
        )
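        # Under FSDP the optimizer state is sharded across ranks as well, so it is saved with its own helper.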
        save_fsdp_optimizer(
            self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, model, output_dir
        )
    elif self.args.should_save:
        # deepspeed.save_checkpoint above saves model/optim/sched
        # This block is executed by the main process only
        optim_dict = self.optimizer.state_dict()
        if self.args.use_habana:
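            # Moving the tensors to CPU keeps the saved checkpoint loadable on machines without an HPU backend.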
            # Move the state dict from HPU to CPU before saving
            optim_dict = to_device_dtype(optim_dict, target_device=torch.device("cpu"))
        torch.save(optim_dict, os.path.join(output_dir, OPTIMIZER_NAME))

    # Save SCHEDULER & SCALER
    is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance(
        self.lr_scheduler, DeepSpeedSchedulerWrapper
    )
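    # When DeepSpeed is enabled it checkpoints its own scheduler above; a custom (non-DeepSpeed)
    # scheduler still has to be saved here explicitly.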
    if self.args.should_save and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler):
        scheduler_dict = self.lr_scheduler.state_dict()
        if self.args.use_habana:
            # Move the state dict from HPU to CPU before saving
            scheduler_dict = to_device_dtype(scheduler_dict, target_device=torch.device("cpu"))
        with warnings.catch_warnings(record=True) as caught_warnings:
            torch.save(scheduler_dict, os.path.join(output_dir, SCHEDULER_NAME))
        reissue_pt_warnings(caught_warnings)
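
# A minimal, hedged sketch of what `to_device_dtype` is assumed to do in the branches above:
# return a copy of a (possibly nested) state dict with every tensor moved to the target
# device/dtype and every non-tensor entry passed through. The name `to_device_dtype_sketch`
# and the exact traversal are illustrative assumptions, not the library's implementation.
import torch

def to_device_dtype_sketch(obj, target_device=None, target_dtype=None):
    if isinstance(obj, torch.Tensor):
        # Keep the tensor's own device/dtype when no target is given.
        return obj.to(
            device=target_device if target_device is not None else obj.device,
            dtype=target_dtype if target_dtype is not None else obj.dtype,
        )
    if isinstance(obj, dict):
        return {k: to_device_dtype_sketch(v, target_device, target_dtype) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_device_dtype_sketch(v, target_device, target_dtype) for v in obj)
    # Scalars, strings, None, etc. are returned unchanged.
    return obj

# Hypothetical usage mirroring the calls above:
#     optim_dict = to_device_dtype_sketch(optimizer.state_dict(), target_device=torch.device("cpu"))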