in optimum/neuron/trainers.py [0:0]
def _save_xla(self, output_dir: Optional[str] = None):
    output_dir = output_dir if output_dir is not None else self.args.output_dir
    if is_main_worker():
        logger.info(f"Saving model checkpoint to {output_dir}")

        os.makedirs(output_dir, exist_ok=True)
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

    # Save a trained model and configuration using `save_pretrained()`.
    # They can then be reloaded using `from_pretrained()`.
    # All workers synchronize here before writing the checkpoint.
    xm.rendezvous("saving_checkpoint")

    if self.accelerator.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
        if is_main_worker():
            logger.info(
                "Model parallelism is enabled, saving the model sharded state dict instead of the full state dict."
            )

        model_to_save = self.model.original_torch_module if isinstance(self.model, NxDPPModel) else self.model
        # This mark_step is needed to avoid hang issues.
        xm.mark_step()
        model_to_save.save_pretrained(
            output_dir,
            optimizer=self.optimizer if not self.args.save_only_model else None,
        )
    else:
        if not isinstance(self.model, PreTrainedModel):
            if isinstance(unwrap_model(self.model), PreTrainedModel):
                # The wrapped model is a `PreTrainedModel`: save it using the wrapper's state dict.
                unwrap_model(self.model).save_pretrained(
                    output_dir,
                    is_main_process=self.args.should_save,
                    state_dict=self.model.state_dict(),
                    save_function=xm.save,
                )
            else:
                if is_main_worker():
                    logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                state_dict = self.model.state_dict()
                xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(
                output_dir,
                is_main_process=self.args.should_save,
                save_function=xm.save,
            )

    if self.tokenizer is not None and self.args.should_save:
        self.tokenizer.save_pretrained(output_dir)
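
The comment above notes that checkpoints written with `save_pretrained()` can be reloaded with `from_pretrained()`. A minimal reload sketch for the non-model-parallel branch, assuming the checkpoint directory holds a standard `transformers` config plus the saved weights (the model class and path below are illustrative, not taken from this file):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical checkpoint directory, i.e. the `output_dir` used above.
checkpoint_dir = "path/to/checkpoint"

# `from_pretrained()` reads the config and weights written by `save_pretrained()`.
model = AutoModelForCausalLM.from_pretrained(checkpoint_dir)

# Tokenizer files are written at the end of `_save_xla` when `should_save` is set.
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)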