in optimum/habana/transformers/trainer.py
def _load_best_model(self):
    logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
    best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
    best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
    best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
    best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
    model = self.model
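    # Dispatch by runtime: DeepSpeed and FSDP restore their own (possibly partitioned) state,
    # otherwise a single-file or sharded state dict is loaded directly below.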
    if self.is_deepspeed_enabled:
        deepspeed_load_checkpoint(
            self.model_wrapped,
            self.state.best_model_checkpoint,
            load_module_strict=not _is_peft_model(self.model),
        )
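    # FSDP checkpoints are reloaded through Accelerate's FSDP utilities so that
    # parameter sharding matches the wrapped model.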
    elif self.is_fsdp_enabled:
        load_result = load_fsdp_model(
            self.accelerator.state.fsdp_plugin,
            self.accelerator,
            model,
            self.state.best_model_checkpoint,
            **_get_fsdp_ckpt_kwargs(),
        )
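    # Non-distributed path: only attempt a load if at least one known weight file exists,
    # either full model weights or PEFT adapter weights, in PyTorch or safetensors format.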
    elif (
        os.path.exists(best_model_path)
        or os.path.exists(best_safe_model_path)
        or os.path.exists(best_adapter_model_path)
        or os.path.exists(best_safe_adapter_model_path)
    ):
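        # Tracks whether a checkpoint was actually loaded; load warnings are only issued if it was.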
        has_been_loaded = True
        if _is_peft_model(model):
            # If the model was trained with PEFT (e.g. LoRA), assume the adapters have been saved properly.
            # TODO: in the future, support only specific minimum PEFT versions
            if (hasattr(model, "active_adapter") or hasattr(model, "active_adapters")) and hasattr(
                model, "load_adapter"
            ):
                # For backward compatibility with older PEFT versions
                if hasattr(model, "active_adapters"):
                    active_adapter = model.active_adapters[0]
                    if len(model.active_adapters) > 1:
                        logger.warning("Detected multiple active adapters, will only consider the first one")
                else:
                    active_adapter = model.active_adapter
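                # Reload only the active adapter weights from the best checkpoint; the base model is left untouched.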
                if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
                    try:
                        model.load_adapter(self.state.best_model_checkpoint, active_adapter)
                    except RuntimeError as exc:
                        if model.peft_config[active_adapter].is_prompt_learning:
                            # For context: https://github.com/huggingface/peft/issues/2256
                            msg = (
                                "When using prompt learning PEFT methods such as "
                                f"{model.peft_config[active_adapter].peft_type.value}, setting "
                                "load_best_model_at_end=True can lead to errors; it is recommended "
                                "to set this to False and to load the model manually from the checkpoint "
                                "directory using PeftModel.from_pretrained(base_model, <path>) after training "
                                "has finished."
                            )
                            raise RuntimeError(msg) from exc
                        else:
                            raise
                    # load_adapter has no return value, so emulate an empty one; update this when appropriate.
                    from torch.nn.modules.module import _IncompatibleKeys

                    load_result = _IncompatibleKeys([], [])
                else:
                    logger.warning(
                        "The intermediate checkpoints of PEFT may not be saved correctly; "
                        f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in the corresponding saving folders. "
                        "Check some examples here: https://github.com/huggingface/peft/issues/96"
                    )
                    has_been_loaded = False
            else:
                logger.warning("Could not load adapter model; make sure `peft>=0.3.0` is installed.")
                has_been_loaded = False
        else:
            # We load the model state dict on the CPU to avoid an OOM error.
            if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
            else:
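                # weights_only=True limits unpickling to tensors and other safe types.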
                state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)
            # If the model is on the GPU, it still works!
            # Workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963:
            # `strict` is passed positionally because the patched load_state_dict takes *args instead of **kwargs.
            load_result = model.load_state_dict(state_dict, False)
        if has_been_loaded:
            self._issue_warnings_after_load(load_result)
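    # Fall back to sharded checkpoints, identified by their safetensors or PyTorch index files.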
    elif os.path.exists(os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_INDEX_NAME)) or os.path.exists(
        os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)
    ):
        load_result = load_sharded_checkpoint(model, self.state.best_model_checkpoint, strict=False)
        self._issue_warnings_after_load(load_result)
    else:
        logger.warning(
            f"Could not locate the best model at {best_model_path}. If you are running distributed training "
            "on multiple nodes, you should activate `--save_on_each_node`."
        )
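
# Usage sketch (illustrative, not part of this file, assuming the public optimum-habana API):
# _load_best_model runs at the end of GaudiTrainer.train() when best-checkpoint tracking is
# enabled, i.e. with arguments along these lines (Gaudi-specific options such as use_habana
# and gaudi_config_name omitted for brevity):
#
#     args = GaudiTrainingArguments(
#         output_dir="out",
#         save_strategy="steps",
#         eval_strategy="steps",  # must match save_strategy for best-model tracking
#         load_best_model_at_end=True,
#         metric_for_best_model="eval_loss",
#     )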