src/accelerate/utils/modeling.py
def load_state_dict(checkpoint_file, device_map=None):
"""
Load a checkpoint from a given file. If the checkpoint is in the safetensors format and a device map is passed, the
weights can be fast-loaded directly on the GPU.
Args:
checkpoint_file (`str`): The path to the checkpoint to load.
device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
name, once a given module name is inside, every submodule of it will be sent to the same device.
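
    Example (illustrative sketch; the checkpoint name and module names below are placeholders):

    ```python
    # Send every weight under `encoder` directly onto device 0; `decoder` (and
    # anything not covered by the map) stays on CPU.
    state_dict = load_state_dict(
        "checkpoint.safetensors",
        device_map={"encoder": 0, "decoder": "cpu"},
    )
    ```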
"""
    if checkpoint_file.endswith(".safetensors"):
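        # `safe_open` only reads the file header here, so the metadata and tensor
        # names are available without loading any tensor data.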
        with safe_open(checkpoint_file, framework="pt") as f:
            metadata = f.metadata()
            weight_names = f.keys()

        if metadata is None:
            logger.warning(
                f"The safetensors archive passed at {checkpoint_file} does not contain metadata. "
                "Make sure to save your model with the `save_pretrained` method. Defaulting to 'pt' metadata."
            )
            metadata = {"format": "pt"}
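        # "pt", "tf" and "flax" are the framework tags written into safetensors
        # metadata at save time; only the PyTorch format can be loaded here.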
        if metadata.get("format") not in ["pt", "tf", "flax"]:
            raise OSError(
                f"The safetensors archive passed at {checkpoint_file} does not contain valid metadata. Make sure "
                "you save your model with the `save_pretrained` method."
            )
        elif metadata["format"] != "pt":
            raise ValueError(f"The checkpoint passed was saved with {metadata['format']}, we need the pt format.")
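        # No device map: load the full state dict on CPU in one shot.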
        if device_map is None:
            return safe_load_file(checkpoint_file)
        else:
            # if we only have one device we can load everything directly
            if len(set(device_map.values())) == 1:
                device = list(device_map.values())[0]
                target_device = device
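                # A bare integer device is treated as a CUDA index by safetensors,
                # so translate it explicitly for the NPU/HPU backends.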
                if isinstance(device, int):
                    if is_npu_available():
                        target_device = f"npu:{device}"
                    elif is_hpu_available():
                        target_device = "hpu"

                return safe_load_file(checkpoint_file, device=target_device)
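            # Several target devices: bucket the weight names per device so each
            # bucket can be loaded with a single `safe_open` on that device.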
            devices = list(set(device_map.values()) - {"disk"})
            # cpu device should always exist as fallback option
            if "cpu" not in devices:
                devices.append("cpu")
            # For each device, get the weights that go there
            device_weights = {device: [] for device in devices}
            for module_name, device in device_map.items():
                if device in devices:
                    device_weights[device].extend(
                        [k for k in weight_names if k == module_name or k.startswith(module_name + ".")]
                    )
            # all weights that don't have an assigned device should be loaded on CPU
            device_weights["cpu"].extend([k for k in weight_names if k not in sum(device_weights.values(), [])])

            tensors = {}
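            # `tqdm` here is Accelerate's wrapper (hence the `main_process_only`
            # flag), so every process reports its own loading progress.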
            if is_tqdm_available():
                progress_bar = tqdm(
                    main_process_only=False,
                    total=sum([len(device_weights[device]) for device in devices]),
                    unit="w",
                    smoothing=0,
                    leave=False,
                )
            else:
                progress_bar = None
            for device in devices:
                target_device = device
                if isinstance(device, int):
                    if is_npu_available():
                        target_device = f"npu:{device}"
                    elif is_hpu_available():
                        target_device = "hpu"
                with safe_open(checkpoint_file, framework="pt", device=target_device) as f:
                    for key in device_weights[device]:
                        if progress_bar is not None:
                            progress_bar.set_postfix(dev=device, refresh=False)
                            progress_bar.set_description(key)
                        tensors[key] = f.get_tensor(key)
                        if progress_bar is not None:
                            progress_bar.update()
            if progress_bar is not None:
                progress_bar.close()

            return tensors
    else:
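        # Not a safetensors file: fall back to `torch.load`. `weights_only=True`
        # avoids executing arbitrary pickled code from untrusted checkpoints.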
        return torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=True)