def _save_optimizer_and_scheduler()

in optimum/habana/transformers/trainer.py


    def _save_optimizer_and_scheduler(self, output_dir):
        if self.is_deepspeed_enabled:
            # Under ZeRO-3 the model file itself doesn't get saved since it's bogus, unless the
            # DeepSpeed config `stage3_gather_16bit_weights_on_model_save` is True
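            # Feature-detect whether the installed DeepSpeed's `save_checkpoint` accepts
            # `exclude_frozen_parameters`; for PEFT models this avoids writing the frozen
            # base-model weights into the checkpoint.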
            accept_exclude_frozen_parameters = "exclude_frozen_parameters" in set(
                inspect.signature(self.model_wrapped.save_checkpoint).parameters.keys()
            )
            if accept_exclude_frozen_parameters and _is_peft_model(self.model):
                self.model_wrapped.save_checkpoint(output_dir, exclude_frozen_parameters=True)
            else:
                self.model_wrapped.save_checkpoint(output_dir)
        elif self.is_fsdp_enabled:
            if isinstance(self.model, torch._dynamo.eval_frame.OptimizedModule):
                # TODO: for some reason the FSDP model is not unwrapped correctly here; self.model
                # shouldn't be an OptimizedModule at this point.
                model = self.model._orig_mod
            else:
                model = self.model
            # save an FSDP-specific checkpoint so training can be resumed from it
            save_fsdp_model(
                self.accelerator.state.fsdp_plugin, self.accelerator, model, output_dir, **_get_fsdp_ckpt_kwargs()
            )
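            # also save the (possibly sharded) optimizer state so it can be restored together with the model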
            save_fsdp_optimizer(
                self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, model, output_dir
            )
        elif self.args.should_save:
            # deepspeed.save_checkpoint above saves model/optim/sched
            # This block is executed by the main process only
            optim_dict = self.optimizer.state_dict()
            if self.args.use_habana:
                # Move the state dict from HPU to CPU before saving
                optim_dict = to_device_dtype(optim_dict, target_device=torch.device("cpu"))
            torch.save(optim_dict, os.path.join(output_dir, OPTIMIZER_NAME))

        # Save SCHEDULER & SCALER
        is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance(
            self.lr_scheduler, DeepSpeedSchedulerWrapper
        )
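        # When DeepSpeed wraps the scheduler, its state is already part of the DeepSpeed checkpoint saved
        # above, so only a custom (non-DeepSpeed) scheduler is saved here, and only by the main process.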
        if self.args.should_save and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler):
            scheduler_dict = self.lr_scheduler.state_dict()
            if self.args.use_habana:
                # Move the state dict from HPU to CPU before saving
                scheduler_dict = to_device_dtype(scheduler_dict, target_device=torch.device("cpu"))
            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(scheduler_dict, os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
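
For context, `to_device_dtype` (imported from `optimum.habana.utils` in this file) is what detaches the saved state dicts from HPU memory before `torch.save`. The sketch below is an illustration only, not the library's implementation: `move_state_dict_to_cpu` is a hypothetical name, and it covers just the nested dict/list/tensor shapes that optimizer and scheduler state dicts typically contain.

    import torch

    def move_state_dict_to_cpu(obj):
        # Hypothetical sketch: recursively copy every tensor in a (possibly nested)
        # state dict to CPU so torch.save does not serialize device (HPU) storage.
        if isinstance(obj, torch.Tensor):
            return obj.to(torch.device("cpu"))
        if isinstance(obj, dict):
            return {key: move_state_dict_to_cpu(value) for key, value in obj.items()}
        if isinstance(obj, (list, tuple)):
            return type(obj)(move_state_dict_to_cpu(value) for value in obj)
        return obj

    # Usage mirroring the trainer code above (names taken from the method body):
    # optim_dict = move_state_dict_to_cpu(self.optimizer.state_dict())
    # torch.save(optim_dict, os.path.join(output_dir, OPTIMIZER_NAME))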