in train/model.py [0:0]
def load_state(self, state_dict, strict=False) -> bool:
    """Load ``state_dict`` into ``self.net``, optionally tolerating mismatches.

    Args:
        state_dict: Mapping of entry name -> tensor-like value. Each value
            needs a ``.shape`` and must be accepted by ``Tensor.copy_``.
            NOTE(review): presumably produced by a torch ``state_dict()``.
        strict: If True, delegate entirely to
            ``self.net.load_state_dict``. For a ``torch.nn.Module`` that
            raises on missing/unexpected keys or shape mismatches. If False,
            do a partial load: copy only the entries whose name exists in
            ``self.net``'s state dict AND whose shape matches exactly.
            Unknown source names and mis-shaped entries are skipped.

    Returns:
        True if every entry of ``self.net``'s state dict was filled; strict
        mode always returns True if ``load_state_dict`` did not raise.
        False if any target entry was left untouched. In that case the
        names of those entries are logged as a warning.
    """
    if strict:
        self.net.load_state_dict(state_dict=state_dict)
        return True

    # Customized partial load. Build the target mapping ONCE: calling
    # ``self.net.state_dict()`` per key rebuilds the whole dict every time,
    # which made this loop quadratic in the number of entries. For torch
    # modules the returned tensors share storage with the live
    # parameters/buffers, so an in-place ``copy_`` updates the net itself.
    target = self.net.state_dict()
    loaded = set()
    for name, param in state_dict.items():
        if name not in target:
            continue  # source entry the net doesn't have: ignored
        dst = target[name]
        if param.shape != dst.shape:
            continue  # shape mismatch: leave the target entry as-is
        # Shapes are equal, so no reshape is needed before copying.
        dst.copy_(param)
        loaded.add(name)

    # Report target entries that were not filled, in the net's own order.
    missing = [name for name in target if name not in loaded]
    if missing:
        logging.warning(">> Failed to load: {}".format(missing))
        return False
    return True