def custom_load_state_dict()

in src/nanotron/optim/base.py
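
For orientation, these are roughly the imports the snippet below depends on; StateDict and _process_value_according_to_param_policy are defined elsewhere in base.py and are only referenced here. This list is an editor's reconstruction under those assumptions, not the module's actual import block.

from collections import defaultdict
from collections.abc import Iterable
from copy import deepcopy
from itertools import chain
from typing import Any, DefaultDict, Dict, Optional, Union

import torch

# StateDict and _process_value_according_to_param_policy come from elsewhere in
# src/nanotron/optim/base.py and are not reproduced here.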


def custom_load_state_dict(self, state_dict: StateDict, map_location: Optional[Union[str, torch.device]] = None) -> None:
    r"""Loads the optimizer state.

    Args:
        state_dict (dict): optimizer state. Should be an object returned
            from a call to :meth:`state_dict`.
        map_location (str or torch.device, optional): Device onto which to load the
            optimizer states. If None, states are loaded onto the same device as
            their corresponding parameters. Default: None.
    """

    # shallow copy, to be consistent with module API
    state_dict = state_dict.copy()

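    # Give any registered load_state_dict pre-hooks a chance to inspect or rewrite the incoming state dict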
    for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
        hook_result = pre_hook(self, state_dict)
        if hook_result is not None:
            state_dict = hook_result

    # Validate the state_dict
    groups = self.param_groups

    # Deepcopy as we write into saved_groups later to update state
    saved_groups = deepcopy(state_dict["param_groups"])

    if len(groups) != len(saved_groups):
        raise ValueError("loaded state dict has a different number of " "parameter groups")
    param_lens = (len(g["params"]) for g in groups)
    saved_lens = (len(g["params"]) for g in saved_groups)
    if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
        raise ValueError(
            "loaded state dict contains a parameter group that doesn't match the size of the optimizer's group"
        )

    # Update the state
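    # Map each saved parameter id to the optimizer's current parameter, pairing the groups element-wise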
    id_map = dict(
        zip(chain.from_iterable(g["params"] for g in saved_groups), chain.from_iterable(g["params"] for g in groups))
    )

    def _cast(param, value, param_id=None, param_groups=None, key=None):
        r"""Make a deep copy of value, casting all tensors to device of param."""
        if isinstance(value, torch.Tensor):
            return _process_value_according_to_param_policy(param, value, param_id, param_groups, map_location, key)
        elif isinstance(value, dict):
            return {k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()}
        elif isinstance(value, Iterable):
            return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value)
        else:
            return value

    # Copy state assigned to params (and cast tensors to appropriate types).
    # State that is not assigned to params is copied as is (needed for
    # backward compatibility).
    state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
    for k, v in state_dict["state"].items():
        if k in id_map:
            param = id_map[k]
            state[param] = _cast(param, v, param_id=k, param_groups=state_dict["param_groups"])
        else:
            state[k] = v

    # Update parameter groups, setting their 'params' value
    def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]:
        new_group["params"] = group["params"]
        return new_group

    param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
    self.__setstate__({"state": state, "param_groups": param_groups})

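    # Run any registered load_state_dict post-hooks now that state and param_groups have been restored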
    for post_hook in self._optimizer_load_state_dict_post_hooks.values():
        post_hook(self)