def _load_state_dict_into_model()

in optimum/neuron/models/training/modeling_utils.py [0:0]


# Imports assumed from the original module; `get_tensor_model_parallel_attributes` comes from the
# Neuron utilities and returns a parameter's tensor model parallel attributes as a mapping.
from torch import nn


def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
    # Shallow-copy `state_dict` so `_load_from_state_dict` can modify it without affecting the
    # caller's dict; `_metadata` (per-module version info saved by `nn.Module.state_dict`) is not
    # carried over by `copy()`, so it is restored manually.
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
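    # For example, loading nn.Sequential(nn.Linear(2, 2)) visits the root with prefix "" and the
    # child with prefix "0.", so the key "0.weight" is consumed at the child level.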
    def load(module: nn.Module, state_dict, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})

        # Positional arguments for nn.Module._load_from_state_dict:
        # (state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs).
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        # Parameters of this module and its children all start with `prefix`, so we can skip
        # calling `_load_from_state_dict` if the state_dict contains no such keys.
        if any(key.startswith(prefix) for key in state_dict):
            # ** Difference from original _load_state_dict_into_model **
            # We do not add the code related to `deepspeed` here, since we do not support it.

            # ** Difference from original _load_state_dict_into_model **
            # `module._load_from_state_dict` can mutate the parameters of the module, so we must
            # cache the tensor parallel metadata beforehand and restore it afterwards.
            # This maps each direct parameter name to its tensor model parallel attributes.
            tensor_model_parallel_attributes = {
                k: get_tensor_model_parallel_attributes(v) for k, v in module._parameters.items()
            }

            module._load_from_state_dict(*args)

            # Restore the tensor model parallel attributes that loading may have dropped.
            for name, param in module._parameters.items():
                attributes = tensor_model_parallel_attributes[name]
                for attr_name, attr in attributes.items():
                    setattr(param, attr_name, attr)

        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".")

    load(model_to_load, state_dict, prefix=start_prefix)
    # Delete `state_dict` so it can be garbage-collected earlier. Note that `state_dict` here is a
    # copy of the argument, so it is safe to delete.
    del state_dict

    return error_msgs
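

For illustration, a minimal usage sketch (hypothetical, not part of the source file). It assumes
the private helper is importable as shown and that `get_tensor_model_parallel_attributes` returns
an empty mapping for parameters that carry no tensor model parallel metadata, so a plain CPU
module loads cleanly:

import torch
from torch import nn

from optimum.neuron.models.training.modeling_utils import _load_state_dict_into_model

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
# Zeroed weights keyed exactly like the module tree, e.g. "0.weight", "2.bias".
state_dict = {k: torch.zeros_like(v) for k, v in model.state_dict().items()}

# `start_prefix` is "" because the state_dict keys already match the module tree.
error_msgs = _load_state_dict_into_model(model, state_dict, start_prefix="")
if error_msgs:
    raise RuntimeError("Error(s) in loading state_dict:\n" + "\n".join(error_msgs))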