in optimum/habana/transformers/trainer.py
def _save(self, output_dir: Optional[str] = None, state_dict=None):
    # If we are executing this function, we are the process zero, so we don't check for that.
    output_dir = output_dir if output_dir is not None else self.args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Saving model checkpoint to {output_dir}")

    supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)

    if state_dict is None:
        state_dict = self.model.state_dict()
    if state_dict and self.args.use_habana:
        # state_dict items have to be saved on the CPU
        state_dict = to_device_dtype(state_dict, target_device=torch.device("cpu"))
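
    # A minimal sketch of the move above, assuming `to_device_dtype` is the helper from
    # optimum.habana.utils (the real helper also handles dtype casts and nested containers);
    # the core effect is roughly:
    #   state_dict = {k: v.to("cpu") if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()}
    # Tensors live on the HPU during training, so they are moved to host memory before serialization.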

    # Save a trained model and configuration using `save_pretrained()`.
    # They can then be reloaded using `from_pretrained()`.
    if not isinstance(self.model, supported_classes):
        if isinstance(self.accelerator.unwrap_model(self.model, keep_torch_compile=False), supported_classes):
            self.accelerator.unwrap_model(self.model, keep_torch_compile=False).save_pretrained(
                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            )
        else:
            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
            if self.args.save_safetensors:
                safetensors.torch.save_file(
                    state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
                )
            else:
                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
    else:
        self.model.save_pretrained(
            output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
        )
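
    # Note: with `save_safetensors=True` the weights are written to SAFE_WEIGHTS_NAME
    # ("model.safetensors"); otherwise to WEIGHTS_NAME ("pytorch_model.bin").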

    if self.processing_class is not None:
        self.processing_class.save_pretrained(output_dir)
    elif (
        self.data_collator is not None
        and hasattr(self.data_collator, "tokenizer")
        and self.data_collator.tokenizer is not None
    ):
        logger.info("Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`")
        self.data_collator.tokenizer.save_pretrained(output_dir)

    self.gaudi_config.save_pretrained(output_dir)

    # Good practice: save your training arguments together with the trained model
    torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
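
# Usage sketch (hypothetical setup): `_save` is private and is normally reached through
# the public `Trainer.save_model()`, which performs the process-zero check for you.
#
#   trainer = GaudiTrainer(
#       model=model,
#       gaudi_config=gaudi_config,
#       args=gaudi_training_args,
#       train_dataset=train_dataset,
#   )
#   trainer.save_model()  # dispatches to _save() on the main process only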