in optimum/intel/neural_compressor/trainer.py [0:0]
def _save(self, output_dir=None, state_dict=None):
    """Write a full checkpoint (configs, tokenizer, training args, weights, INC config) to disk.

    This method is expected to run on process zero only, so no rank check is performed here.

    Args:
        output_dir: Destination directory. When ``None``, ``self.args.output_dir`` is used.
        state_dict: Weights to serialize. When ``None``, ``self.model.state_dict()`` is used and,
            if the compression manager's model exposes a ``q_config`` attribute, that value is
            stored in the dict under the ``"best_configure"`` key.
    """
    output_dir = self.args.output_dir if output_dir is None else output_dir

    # An existing regular file at the target path is reported and the save is skipped.
    if os.path.isfile(output_dir):
        logger.error(f"Provided path ({output_dir}) should be a directory, not a file")
        return

    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Saving model checkpoint to {output_dir}")

    # Saving under the predefined file name keeps the checkpoint loadable with `from_pretrained`.
    output_model_file = os.path.join(output_dir, WEIGHTS_NAME)

    model = self.model

    # --- configuration files ---
    if model.can_generate():
        if is_transformers_version(">=", "4.44.99"):
            # Generation parameters found on the model config are copied onto the generation
            # config, and the model-config copy is reset to None.
            misplaced = model.config._get_non_default_generation_parameters()
            if misplaced:
                logger.warning(
                    "Moving the following attributes in the config to the generation config: "
                    f"{misplaced}. You are seeing this warning because you've set "
                    "generation parameters in the model config, as opposed to in the generation config.",
                )
                for name, value in misplaced.items():
                    setattr(model.generation_config, name, value)
                    setattr(model.config, name, None)
        model.generation_config.save_pretrained(output_dir)

    if model.config is not None:
        model.config.save_pretrained(output_dir)

    if self.tokenizer is not None:
        self.tokenizer.save_pretrained(output_dir)

    # Keeping the training arguments next to the weights makes the run reproducible.
    torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

    # --- model weights ---
    if state_dict is None:
        state_dict = self.model.state_dict()
        manager = self._compression_manager
        # Only a dict built here is augmented; a caller-supplied state_dict is saved as given.
        if manager is not None and hasattr(manager.model, "q_config"):
            state_dict["best_configure"] = manager.model.q_config
    torch.save(state_dict, output_model_file)

    # --- Neural Compressor config ---
    if self.pruning_config is not None:
        # Record the measured sparsity, rounded to two decimals, before the config is written.
        self.inc_config.pruning["sparsity"] = round(self.get_model_sparsity(), 2)
    self.inc_config.save_pretrained(output_dir)

    logger.info(f"Model weights saved in {output_model_file}")