def load_weight()

in pyhanabi/utils.py [0:0]


from collections import OrderedDict

import torch


def load_weight(model, weight_file, device):
    state_dict = torch.load(weight_file, map_location=device)
    source_state_dict = OrderedDict()
    target_state_dict = model.state_dict()

    # Parameters the model expects but the checkpoint lacks: warn and fall
    # back to the model's current (e.g. randomly initialized) values.
    for k, v in target_state_dict.items():
        if k not in state_dict:
            print("warning: %s not loaded" % k)
            state_dict[k] = v

    # Parameters in the checkpoint that the model does not use are dropped;
    # everything else is copied into the state dict that gets loaded.
    for k in state_dict:
        if k not in target_state_dict:
            print("removing: %s not used" % k)
        else:
            source_state_dict[k] = state_dict[k]

    model.load_state_dict(source_state_dict)
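
A minimal usage sketch: the stand-in network and the checkpoint filename below are illustrative placeholders, not names from the pyhanabi codebase. The point is only that load_weight tolerates a checkpoint whose keys partially mismatch the model's state dict, warning about missing keys and dropping unused ones.

import torch
import torch.nn as nn

# Hypothetical stand-in model; the real repo builds its own network classes.
model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 5))
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# "model.pthw" is a hypothetical checkpoint path for illustration.
load_weight(model, "model.pthw", device)
model.to(device)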