in src/transformers/trainer.py [0:0]
def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
    """Load model (or adapter) weights from a checkpoint directory.

    Args:
        resume_from_checkpoint: Path to the checkpoint directory to load from.
        model: Model to load the weights into. Defaults to ``self.model``.

    The loading strategy is chosen from the files found in the directory:

    * A single ``WEIGHTS_NAME`` / ``SAFE_WEIGHTS_NAME`` file, or an FSDP checkpoint
      (a sub-directory whose name contains ``FSDP_MODEL_NAME``, or a
      ``f"{FSDP_MODEL_NAME}.bin"`` file). These are loaded through the SageMaker MP
      API, ``load_fsdp_model``, or a non-strict ``model.load_state_dict`` from a
      CPU-resident state dict.
    * Otherwise, for PEFT models, adapters are loaded with ``model.load_adapter``.
      When several adapters were saved, one is loaded per sub-directory.
    * Otherwise, the weights are loaded with ``load_sharded_checkpoint``.

    Raises:
        ValueError: If the checkpoint is an FSDP checkpoint but FSDP is not enabled,
            or if none of the recognised weight / index / adapter files (nor an FSDP
            checkpoint, nor adapter sub-directories) are present.
    """
    if model is None:
        model = self.model

    config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
    adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)
    adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
    weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
    weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
    safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
    safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)

    # List the immediate sub-directories once. Both the FSDP detection and the
    # multi-adapter detection below need them.
    if os.path.isdir(resume_from_checkpoint):
        subdirs = [
            name
            for name in os.listdir(resume_from_checkpoint)
            if os.path.isdir(os.path.join(resume_from_checkpoint, name))
        ]
        is_fsdp_ckpt = (
            # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
            any(FSDP_MODEL_NAME in name for name in subdirs)
            # this checks the FSDP state dict when `FULL_STATE_DICT` is used
            or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
        )
    else:
        subdirs = []
        is_fsdp_ckpt = False

    # If multiple adapters exist, they get saved in sub-directories.
    adapter_subdirs = [
        name
        for name in subdirs
        if os.path.isfile(os.path.join(resume_from_checkpoint, name, ADAPTER_WEIGHTS_NAME))
        or os.path.isfile(os.path.join(resume_from_checkpoint, name, ADAPTER_SAFE_WEIGHTS_NAME))
    ]

    if is_fsdp_ckpt and not self.is_fsdp_enabled:
        raise ValueError(f"Checkpoint found at {resume_from_checkpoint} is only supported when using PyTorch FSDP")

    candidate_files = [
        weights_file,
        safe_weights_file,
        weights_index_file,
        safe_weights_index_file,
        adapter_weights_file,
        adapter_safe_weights_file,
    ]
    if not (any(os.path.isfile(f) for f in candidate_files) or is_fsdp_ckpt or adapter_subdirs):
        raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

    logger.info(f"Loading model from {resume_from_checkpoint}.")

    # Only a warning: a version mismatch is allowed but may cause problems.
    if os.path.isfile(config_file):
        config = PretrainedConfig.from_json_file(config_file)
        checkpoint_version = config.transformers_version
        if checkpoint_version is not None and checkpoint_version != __version__:
            logger.warning(
                f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
                f"Transformers but your current version is {__version__}. This is not recommended and could "
                "yield to errors or unwanted behaviors."
            )

    if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt:
        # If the model is on the GPU, it still works!
        if is_sagemaker_mp_enabled():
            if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
                # If the 'user_content.pt' file exists, load with the new smp api.
                # Checkpoint must have been saved with the new smp api.
                smp.resume_from_checkpoint(
                    path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
                )
            else:
                # If the 'user_content.pt' file does NOT exist, load with the old smp api.
                # Checkpoint must have been saved with the old smp api.
                # NOTE(review): this path assumes the old-API checkpoint was written as
                # WEIGHTS_NAME; a safetensors-only checkpoint would fail here — confirm.
                if hasattr(self.args, "fp16") and self.args.fp16 is True:
                    logger.warning(
                        "Enabling FP16 and loading from smp < 1.10 checkpoint together is not supported."
                    )
                check_torch_load_is_safe()
                state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
                # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
                state_dict["_smp_is_partial"] = False
                model.load_state_dict(state_dict, strict=True)
                # release memory
                del state_dict
        elif self.is_fsdp_enabled:
            load_fsdp_model(
                self.accelerator.state.fsdp_plugin,
                self.accelerator,
                model,
                resume_from_checkpoint,
                **_get_fsdp_ckpt_kwargs(),
            )
        else:
            # We load the model state dict on the CPU to avoid an OOM error.
            # Use safetensors when it is preferred (`save_safetensors`), and also when
            # it is the only single-file weight present. Otherwise torch.load would be
            # pointed at a missing WEIGHTS_NAME file even though the existence check
            # above accepted this checkpoint.
            use_safetensors = os.path.isfile(safe_weights_file) and (
                self.args.save_safetensors or not os.path.isfile(weights_file)
            )
            if use_safetensors:
                state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
            else:
                check_torch_load_is_safe()
                state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
            # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
            # which takes *args instead of **kwargs
            load_result = model.load_state_dict(state_dict, False)
            # release memory
            del state_dict
            self._issue_warnings_after_load(load_result)
    # Load adapters following PR # 24096
    elif _is_peft_model(model):
        # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
        # TODO: in the future support only specific min PEFT versions
        if (hasattr(model, "active_adapter") or hasattr(model, "active_adapters")) and hasattr(
            model, "load_adapter"
        ):
            if os.path.exists(resume_from_checkpoint):
                # For BC for older PEFT versions
                if hasattr(model, "active_adapters"):
                    active_adapters = model.active_adapters
                    if len(active_adapters) > 1:
                        logger.warning("Multiple active adapters detected will only consider the first adapter")
                    active_adapter = active_adapters[0]
                else:
                    active_adapter = model.active_adapter

                if adapter_subdirs:
                    # One adapter per sub-directory; only the active one is made trainable.
                    for subdir_name in adapter_subdirs:
                        peft_id = os.path.join(resume_from_checkpoint, subdir_name)
                        model.load_adapter(peft_id, subdir_name, is_trainable=(subdir_name == active_adapter))
                    # Restore the originally active adapter after loading all of them.
                    model.set_adapter(active_adapter)
                else:
                    model.load_adapter(resume_from_checkpoint, active_adapter, is_trainable=True)
            else:
                logger.warning(
                    "The intermediate checkpoints of PEFT may not be saved correctly, "
                    f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
                    "Check some examples here: https://github.com/huggingface/peft/issues/96"
                )
        else:
            logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
    else:
        # We load the sharded checkpoint
        load_result = load_sharded_checkpoint(
            model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors
        )
        if not is_sagemaker_mp_enabled():
            self._issue_warnings_after_load(load_result)