def load_state_dict()

in src/accelerate/utils/modeling.py [0:0]


def _resolve_target_device(device):
    """
    Translate a device-map entry into the device identifier handed to safetensors.

    Integer entries are rewritten to `"npu:<index>"` when `is_npu_available()` returns true, or to `"hpu"` when
    `is_hpu_available()` returns true; every other value is returned unchanged.
    """
    if isinstance(device, int):
        if is_npu_available():
            return f"npu:{device}"
        if is_hpu_available():
            return "hpu"
    return device


def load_state_dict(checkpoint_file, device_map=None):
    """
    Load a checkpoint from a given file. If the checkpoint is in the safetensors format and a device map is passed, the
    weights can be fast-loaded directly on the GPU.

    Args:
        checkpoint_file (`str`): The path to the checkpoint to load.
        device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
            A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
            name, once a given module name is inside, every submodule of it will be sent to the same device. Weights
            mapped to `"disk"`, and weights not covered by any entry, are loaded on CPU.

    Returns:
        For a `.safetensors` file, a dictionary mapping each weight name to its tensor, loaded on the device selected
        by `device_map` (or the `safe_load_file` default when no map is given). For any other file, whatever
        `torch.load` returns with `map_location` set to CPU and `weights_only=True`.

    Raises:
        OSError: If the safetensors metadata `format` entry is not one of `"pt"`, `"tf"` or `"flax"`.
        ValueError: If the safetensors archive was saved in the `"tf"` or `"flax"` format.
    """
    if not checkpoint_file.endswith(".safetensors"):
        return torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=True)

    with safe_open(checkpoint_file, framework="pt") as f:
        metadata = f.metadata()
        # Materialize the names: they are iterated several times after the file handle is closed.
        weight_names = list(f.keys())

    if metadata is None:
        logger.warning(
            f"The safetensors archive passed at {checkpoint_file} does not contain metadata. "
            "Make sure to save your model with the `save_pretrained` method. Defaulting to 'pt' metadata."
        )
        metadata = {"format": "pt"}

    if metadata.get("format") not in ["pt", "tf", "flax"]:
        raise OSError(
            f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
            "you save your model with the `save_pretrained` method."
        )
    elif metadata["format"] != "pt":
        raise ValueError(f"The checkpoint passed was saved with {metadata['format']}, we need a the pt format.")

    if device_map is None:
        return safe_load_file(checkpoint_file)

    unique_devices = set(device_map.values())

    # if we only have one device we can load everything directly
    if len(unique_devices) == 1:
        (device,) = unique_devices
        # "disk" is an offload marker rather than a device tensors can be loaded on; the multi-device path below
        # excludes it and falls back to CPU, so do the same here instead of forwarding it to safetensors.
        target_device = "cpu" if device == "disk" else _resolve_target_device(device)
        return safe_load_file(checkpoint_file, device=target_device)

    devices = list(unique_devices - {"disk"})
    # cpu device should always exist as fallback option
    if "cpu" not in devices:
        devices.append("cpu")

    # For each device, get the weights that go there
    device_weights = {device: [] for device in devices}
    for module_name, device in device_map.items():
        if device in device_weights:
            prefix = module_name + "."
            device_weights[device].extend(k for k in weight_names if k == module_name or k.startswith(prefix))

    # all weights that haven't defined a device should be loaded on CPU.
    # Build the set of already-assigned names once rather than re-concatenating every list for each weight.
    assigned = set()
    for names in device_weights.values():
        assigned.update(names)
    device_weights["cpu"].extend([k for k in weight_names if k not in assigned])

    tensors = {}
    if is_tqdm_available():
        progress_bar = tqdm(
            main_process_only=False,
            total=sum(len(names) for names in device_weights.values()),
            unit="w",
            smoothing=0,
            leave=False,
        )
    else:
        progress_bar = None

    try:
        for device in devices:
            # Re-open the archive once per device so its tensors are materialized directly on that device.
            with safe_open(checkpoint_file, framework="pt", device=_resolve_target_device(device)) as f:
                for key in device_weights[device]:
                    if progress_bar is not None:
                        progress_bar.set_postfix(dev=device, refresh=False)
                        progress_bar.set_description(key)
                    tensors[key] = f.get_tensor(key)
                    if progress_bar is not None:
                        progress_bar.update()
    finally:
        # Close the bar even when opening the archive or reading a tensor fails.
        if progress_bar is not None:
            progress_bar.close()

    return tensors