in lib/net_util.py [0:0]
def load_state_dict(state_dict, net):
    """Partially load ``state_dict`` into ``net``, skipping incompatible entries.

    An entry of ``state_dict`` is copied into the network only when:

    * its key also exists in ``net.state_dict()``, and
    * its ``size()`` equals the size of the network's current entry.

    Keys that appear only in ``state_dict`` are silently ignored. Network
    entries that are missing from ``state_dict``, or whose shape differs,
    keep their current values. The top-level name of every such entry is
    printed. The top-level name is the part of the key before the first
    ``'.'``.

    ``net`` is presumably a ``torch.nn.Module``. The code only relies on its
    ``state_dict()`` and ``load_state_dict()`` methods.

    Args:
        state_dict: Mapping from parameter/buffer name to tensor, e.g. a
            loaded checkpoint.
        net: The network to load into. It is updated via
            ``net.load_state_dict``.

    Returns:
        The same ``net`` object, after loading.
    """
    merged = net.state_dict()

    # Pretrained entries whose key and shape both match the model's entry.
    # Each key is checked against the model's *original* value, because
    # ``merged`` is not updated until the dict is fully built.
    compatible = {
        name: tensor
        for name, tensor in state_dict.items()
        if name in merged and tensor.size() == merged[name].size()
    }
    # Updating existing keys keeps the model's key order intact.
    merged.update(compatible)

    # Every model entry that did not come from the checkpoint keeps its
    # current value. Report these by top-level name.
    skipped_prefixes = {name.partition('.')[0] for name in merged if name not in compatible}
    print('not initialized', sorted(skipped_prefixes))

    # ``merged`` holds every model key, so a strict load succeeds.
    net.load_state_dict(merged)
    return net