in lmgvp/transfer.py [0:0]
def load_state_dict_to_model(model, state_dict):
    """Initialize a model with parameters in `state_dict` (inplace)
    from a pretrained model with slightly different architecture.

    Only entries whose key also exists in ``model.state_dict()`` are copied.
    Keys that appear only in `state_dict` are skipped silently. Model entries
    that are missing from `state_dict` keep their current values.

    Args:
        model: Torch model
        state_dict: Dictionary containing weight for each layer of the `model`

    Returns:
        input `model` where layer weights have been updated based on `state_dict`

    Raises:
        RuntimeError: raised by ``Tensor.copy_`` when a shared key holds a
            tensor that cannot be broadcast to the model's tensor shape.
    """
    own_state = model.state_dict()
    print("model own state keys:", len(own_state))
    print("state_dict keys:", len(state_dict))
    keys_loaded = 0
    for name, param in state_dict.items():
        if name not in own_state:
            # This layer does not exist in the target architecture; ignore it.
            continue
        if isinstance(param, Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        # Tensors returned by `state_dict()` share storage with the model's
        # parameters/buffers, so this in-place copy updates `model` itself.
        own_state[name].copy_(param)
        keys_loaded += 1
    print("keys loaded into model:", keys_loaded)
    # The original fell off the end and returned None despite documenting
    # that it returns the model.
    return model