def _zero3_load_from_state_dict()

in modules/SwissArmyTransformer/sat/training/model_io.py


import itertools
import warnings

import torch
# _EXTRA_STATE_KEY_SUFFIX is torch's private '_extra_state' key suffix used below.
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX


def _zero3_load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                                missing_keys, unexpected_keys, error_msgs):
    r"""Copy parameters and buffers from :attr:`state_dict` into only this module, but not its descendants.

    This is called on every submodule
    in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
    module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
    For state dicts without metadata, :attr:`local_metadata` is empty.
    Subclasses can achieve class-specific backward compatible loading using
    the version number at `local_metadata.get("version", None)`.
    Additionally, :attr:`local_metadata` can also contain the key
    `assign_to_params_buffers` that indicates whether keys should be
    assigned their corresponding tensor in the state_dict.

    .. note::
        :attr:`state_dict` is not the same object as the input
        :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
        it can be modified.

    Args:
        state_dict (dict): a dict containing parameters and
            persistent buffers.
        prefix (str): the prefix for parameters and buffers used in this
            module
        local_metadata (dict): a dict containing the metadata for this
            module
        strict (bool): whether to strictly enforce that the keys in
            :attr:`state_dict` with :attr:`prefix` match the names of
            parameters and buffers in this module
        missing_keys (list of str): if ``strict=True``, add missing keys to
            this list
        unexpected_keys (list of str): if ``strict=True``, add unexpected
            keys to this list
        error_msgs (list of str): error messages should be added to this
            list, and will be reported together in
            :meth:`~torch.nn.Module.load_state_dict`
    """
    for hook in self._load_state_dict_pre_hooks.values():
        hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

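    # Gather this module's own parameters and persistent buffers; descendants are
    # handled by their own _load_from_state_dict calls.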
    persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
    local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
    local_state = {k: v for k, v in local_name_params if v is not None}
    assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)

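    # For each local tensor, look up "<prefix><name>" in the checkpoint and validate
    # it before assigning or copying.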
    for name, param in local_state.items():
        key = prefix + name
        if key in state_dict:
            input_param = state_dict[key]
            if not torch.overrides.is_tensor_like(input_param):
                error_msgs.append(f'While copying the parameter named "{key}", '
                                  'expected torch.Tensor or Tensor-like object from checkpoint but '
                                  f'received {type(input_param)}')
                continue

            # Avoid copying uninitialized parameters into non-lazy modules: they
            # don't have the lazy-module hook to run these checks, and accessing
            # .shape on an uninitialized parameter would error.
            is_param_lazy = torch.nn.parameter.is_lazy(param)
            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
            if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
                input_param = input_param[0]

            if not is_param_lazy and input_param.shape != param.shape:
                # local shape should match the one in checkpoint
                error_msgs.append(f'size mismatch for {key}: copying a param with shape '
                                  f'{input_param.shape} from checkpoint, the shape in '
                                  f'current model is {param.shape}.')
                continue

            if param.is_meta and not input_param.is_meta and not assign_to_params_buffers:
                warnings.warn(f'for {key}: copying from a non-meta parameter in the checkpoint to a meta '
                              'parameter in the current model, which is a no-op. (Did you mean to '
                              'pass `assign=True` to assign items in the state dictionary to their '
                              'corresponding key in the module instead of copying them in place?)')

            try:
                with torch.no_grad():
                    if assign_to_params_buffers:
                        # Shape checks are already done above
                        if (isinstance(param, torch.nn.Parameter) and
                                not isinstance(input_param, torch.nn.Parameter)):
                            setattr(self, name, torch.nn.Parameter(input_param))
                        else:
                            setattr(self, name, input_param)
                    else:
                        if hasattr(param, 'ds_id'):
                            # DeepSpeed ZeRO-3 partitioned parameter: copy the
                            # checkpoint tensor into its local ds_tensor storage
                            # rather than into the param itself.
                            param.ds_tensor.copy_(input_param)
                        else:
                            param.copy_(input_param)
            except Exception as ex:
                error_msgs.append(f'While copying the parameter named "{key}", '
                                  f'whose dimensions in the model are {param.size()} and '
                                  f'whose dimensions in the checkpoint are {input_param.size()}, '
                                  f'an exception occurred: {ex.args}.')
        elif strict:
            missing_keys.append(key)

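    # Handle custom extra state: only call set_extra_state() when the subclass
    # actually overrides it.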
    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
    if getattr(self.__class__, "set_extra_state", torch.nn.Module.set_extra_state) is not torch.nn.Module.set_extra_state:
        if extra_state_key in state_dict:
            self.set_extra_state(state_dict[extra_state_key])
        elif strict:
            missing_keys.append(extra_state_key)
    elif strict and (extra_state_key in state_dict):
        unexpected_keys.append(extra_state_key)

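    # With strict=True, any checkpoint key under this prefix that names neither a
    # child module nor a local parameter/buffer is reported as unexpected.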
    if strict:
        for key in state_dict.keys():
            if key.startswith(prefix) and key != extra_state_key:
                input_name = key[len(prefix):]
                input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
                if input_name not in self._modules and input_name not in local_state:
                    unexpected_keys.append(key)
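

Below is a minimal, hypothetical sketch of how such an override could be wired in: it temporarily swaps torch.nn.Module._load_from_state_dict for _zero3_load_from_state_dict while load_state_dict walks the module tree. The helper name load_with_zero3_override is an illustration only; the actual integration in model_io.py may patch or dispatch differently.

def load_with_zero3_override(module, state_dict, strict=True):
    # Hypothetical helper (not the actual wiring in model_io.py): monkey-patch the
    # per-module loader so DeepSpeed ZeRO-3 partitioned parameters (those carrying
    # a ds_tensor) are filled in place, then restore the original method.
    original = torch.nn.Module._load_from_state_dict
    torch.nn.Module._load_from_state_dict = _zero3_load_from_state_dict
    try:
        return module.load_state_dict(state_dict, strict=strict)
    finally:
        torch.nn.Module._load_from_state_dict = original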