in pyhanabi/utils.py [0:0]
def load_weight(model, weight_file, device):
    """Load a checkpoint into ``model``, tolerating key mismatches.

    The checkpoint at ``weight_file`` is read with ``torch.load`` onto
    ``device``. It is then reconciled against ``model.state_dict()``:

    * Keys the model expects but the checkpoint lacks keep the model's
      current values. A ``warning: ... not loaded`` line is printed for each.
    * Keys present in the checkpoint but unknown to the model are skipped.
      A ``removing: ... not used`` line is printed for each.

    The reconciled dict, which has exactly the model's keys, is passed to
    ``model.load_state_dict``.

    Args:
        model: module exposing ``state_dict()`` / ``load_state_dict()``.
        weight_file: anything ``torch.load`` accepts (e.g. a path).
        device: forwarded to ``torch.load`` as ``map_location``.

    Returns:
        None; ``model`` is updated in place.
    """
    # NOTE(review): torch.load deserializes via pickle (subject to the
    # installed torch's weights_only default); only load trusted checkpoints.
    loaded = torch.load(weight_file, map_location=device)
    expected = model.state_dict()

    # Back-fill anything the checkpoint is missing with the model's own
    # tensors so the strict load below does not fail on missing keys.
    missing = [name for name in expected if name not in loaded]
    for name in missing:
        print("warning: %s not loaded" % name)
        loaded[name] = expected[name]

    # Keep only the entries the model knows about; report the rest.
    merged = OrderedDict()
    for name, tensor in loaded.items():
        if name in expected:
            merged[name] = tensor
        else:
            print("removing: %s not used" % name)

    model.load_state_dict(merged)