def _load_optimizer_and_scheduler()

in optimum/habana/transformers/trainer.py
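
Restores the optimizer and learning-rate scheduler state when resuming from a checkpoint. Three paths are handled: under DeepSpeed, only a scheduler that DeepSpeed does not manage itself is reloaded (the rest is restored together with the model in deepspeed_init); under FSDP, the optimizer state is restored through Accelerate's load_fsdp_optimizer; otherwise both states are loaded from the checkpoint files, on CPU when running on Gaudi, and the optimizer state is then moved to HPU.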


    def _load_optimizer_and_scheduler(self, checkpoint):
        """If optimizer and scheduler states exist, load them."""
        if checkpoint is None:
            return

        if self.is_deepspeed_enabled:
            # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
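            # only a scheduler that DeepSpeed does not manage itself needs to be reloaded manually here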
            if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
                with warnings.catch_warnings(record=True) as caught_warnings:
                    self.lr_scheduler.load_state_dict(
                        torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
                    )
                reissue_pt_warnings(caught_warnings)
            return

        # Look for a consolidated optimizer file (optimizer.pt / optimizer.bin) or, when the state is
        # sharded (e.g. with FSDP), a subdirectory whose name contains "optimizer"
        checkpoint_file_exists = (
            os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
            or os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME_BIN))
            or (
                os.path.isdir(checkpoint)
                and any(
                    OPTIMIZER_NAME_BIN.split(".")[0] in folder_name
                    for folder_name in os.listdir(checkpoint)
                    if os.path.isdir(os.path.join(checkpoint, folder_name))
                )
            )
        )

        if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
            # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.
            # In distributed training, however, we load directly on each GPU and accept the risk of GPU OOM,
            # since loading on CPU is even more likely to OOM (the optimizer state would be loaded num_gpu
            # times into CPU RAM).
            # On Gaudi (use_habana=True), states are always loaded on CPU first and moved to HPU at the
            # end of this method.
            map_location = "cpu" if self.args.use_habana else self.args.device
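            # with FSDP, Accelerate's load_fsdp_optimizer restores the (possibly sharded) optimizer state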
            if self.is_fsdp_enabled:
                load_fsdp_optimizer(
                    self.accelerator.state.fsdp_plugin,
                    self.accelerator,
                    self.optimizer,
                    self.model,
                    checkpoint,
                    **_get_fsdp_ckpt_kwargs(),
                )
            else:
                self.optimizer.load_state_dict(
                    torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location, weights_only=True)
                )

            with warnings.catch_warnings(record=True) as caught_warnings:
                self.lr_scheduler.load_state_dict(
                    torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location=map_location, weights_only=True)
                )
            reissue_pt_warnings(caught_warnings)

            # Move optimizer state to HPU
            if self.args.use_habana:
                to_device_dtype(self.optimizer.state.values(), target_device=torch.device("hpu"))
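
For context, a minimal sketch of how this method is reached: it is not called directly, but runs inside GaudiTrainer.train() when resuming from a checkpoint, and it expects the standard Transformers checkpoint file names (OPTIMIZER_NAME / SCHEDULER_NAME, i.e. optimizer.pt / scheduler.pt) inside the checkpoint directory. The model, Gaudi config, dataset, and paths below are placeholders, not values taken from this file.

    # Hypothetical resume sketch; assumes model, gaudi_config and train_dataset are defined.
    from optimum.habana import GaudiTrainer, GaudiTrainingArguments

    args = GaudiTrainingArguments(
        output_dir="out",
        use_habana=True,   # optimizer/scheduler states are loaded on CPU, then moved to HPU
        use_lazy_mode=True,
    )
    trainer = GaudiTrainer(
        model=model,                  # placeholder: any pretrained model
        gaudi_config=gaudi_config,    # placeholder: a GaudiConfig instance
        args=args,
        train_dataset=train_dataset,  # placeholder dataset
    )

    # Resuming triggers _load_optimizer_and_scheduler("out/checkpoint-500"), which looks for
    # out/checkpoint-500/optimizer.pt and out/checkpoint-500/scheduler.pt.
    trainer.train(resume_from_checkpoint="out/checkpoint-500")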