In optimum/neuron/models/training/modeling_utils.py:
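
# Import needed by this excerpt. `get_tensor_model_parallel_attributes` is assumed to be
# imported or defined elsewhere in the same file.
from torch import nn
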
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
    # Copy `state_dict` so that `_load_from_state_dict` can modify it.
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []
    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants,
    # so we need to apply the function recursively.
    def load(module: nn.Module, state_dict, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        # Parameters of the module and its children all start with `prefix`, so we can exit
        # early if there are no matching keys in this state_dict.
        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
            # ** Difference from the original _load_state_dict_into_model **
            # We do not add the code related to `deepspeed` here, since we do not support it.
            # ** Difference from the original _load_state_dict_into_model **
            # `module._load_from_state_dict` can mutate the parameters of the module, so we must
            # cache the tensor model parallel metadata before loading and restore it afterwards.
            tensor_model_parallel_attributes = {
                k: get_tensor_model_parallel_attributes(v) for k, v in module._parameters.items()
            }
            module._load_from_state_dict(*args)
            # Restoring the tensor model parallel attributes.
            for name, param in module._parameters.items():
                attributes = tensor_model_parallel_attributes[name]
                for attr_name, attr in attributes.items():
                    setattr(param, attr_name, attr)

        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".")
    load(model_to_load, state_dict, prefix=start_prefix)
    # Delete `state_dict` so that it can be garbage-collected earlier. Note that `state_dict`
    # is a copy of the argument, so it is safe to delete it.
    del state_dict

    return error_msgs
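
The helper `get_tensor_model_parallel_attributes` is not shown in this excerpt. As a rough idea of what it does, here is a minimal sketch, assuming the Megatron-style convention in which sharded parameters carry attributes such as `tensor_model_parallel`, `partition_dim`, and `partition_stride`. These attribute names are illustrative assumptions, not the actual optimum-neuron implementation.

# Hypothetical sketch, NOT the actual optimum-neuron helper.
# Assumes Megatron-style tensor parallel attributes on sharded parameters.
_TP_ATTRIBUTE_NAMES = ("tensor_model_parallel", "partition_dim", "partition_stride")

def get_tensor_model_parallel_attributes(param):
    # Collect whichever tensor parallel attributes are present on the parameter, so they
    # can be re-applied after `_load_from_state_dict` has run.
    if param is None:
        return {}
    return {name: getattr(param, name) for name in _TP_ATTRIBUTE_NAMES if hasattr(param, name)}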
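
A runnable toy usage of the loader, for illustration only: with a plain (non-parallel) module, the cached tensor model parallel attribute dicts are simply empty, so the function degrades to a recursive state-dict load.

from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
state_dict = model.state_dict()

# `start_prefix` is "" here because the checkpoint keys match the module tree directly;
# it would be e.g. "model." if the keys carried an extra top-level prefix.
error_msgs = _load_state_dict_into_model(model, state_dict, start_prefix="")
if error_msgs:
    raise RuntimeError("Error(s) in loading state_dict:\n\t" + "\n\t".join(error_msgs))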