def get_upstream_model()

in svhn_based_exp/utils.py [0:0]


def get_upstream_model(upstream_setting, linear_base=None, save_dir="./checkpoint/", device="cuda"):
    """Build an upstream network and load its weights from a saved checkpoint.

    Args:
        upstream_setting: Iterable of exactly six items, in this order:
            ``(upstream_task, upstream_para_to_vary_model, upstream_model,
            upstream_seed, repre, model_specify)``.

            * ``upstream_task`` is passed to ``get_num_classes`` and is part
              of the checkpoint file name.
            * ``upstream_para_to_vary_model`` and ``upstream_seed`` are only
              used to build the checkpoint file name.
            * ``upstream_model`` is passed to ``get_model`` and is part of
              the checkpoint file name.
            * ``repre``: when equal to ``"feat"`` the model's ``fc1``
              (model names starting with ``"convnet"``) or ``fc`` (all
              other models) attribute is replaced by ``nn.Identity()``.
              Any other value leaves the network untouched.
            * ``model_specify``: ``"last"`` loads ``checkpoint['net']``;
              ``"best"`` loads ``checkpoint['best_state']['net']``.
        linear_base: Forwarded unchanged to ``get_model``.
        save_dir: Directory that contains the ``checkpoint/`` sub-folder
            holding the ``.pth`` file.
        device: Forwarded to ``get_model`` and used as ``map_location`` when
            reading the checkpoint.

    Returns:
        The network returned by ``get_model`` with the selected checkpoint
        weights loaded and, for ``repre == "feat"``, its ``fc1``/``fc``
        layer replaced by an identity.

    Raises:
        SystemExit: If ``model_specify`` is neither ``"last"`` nor ``"best"``.
            The process exit status is non-zero.
    """
    (upstream_task, upstream_para_to_vary_model, upstream_model, upstream_seed, repre, model_specify) = tuple(upstream_setting)

    # Fail fast: reject a bad specification before building the model or
    # touching the disk. SystemExit with a message exits with status 1, so the
    # failure is visible to calling scripts (the message goes to stderr).
    if model_specify not in ("last", "best"):
        raise SystemExit(f"invalid upstream model specification: {model_specify!r}")

    upstream_num_classes = get_num_classes(upstream_task)
    upstream_net = get_model(upstream_model, upstream_num_classes, device=device, linear_base=linear_base)

    checkpoint_path = f'{save_dir}/checkpoint/svhn_{upstream_model}_task_{upstream_task}_upstream_setting_None_para_to_vary_model_{upstream_para_to_vary_model}_seed_{upstream_seed}.pth'
    # map_location lets a checkpoint saved from GPU tensors be read on a host
    # without CUDA when the caller asks for device="cpu".
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    if model_specify == "last":
        state_dict = checkpoint['net']
    else:
        state_dict = checkpoint['best_state']['net']
    # strict=False: keys that do not match between checkpoint and model are
    # skipped instead of raising.
    upstream_net.load_state_dict(state_dict, strict=False)

    if repre == "feat":
        # Replace the selected linear layer with a pass-through; presumably
        # this makes the network emit features instead of class scores.
        if upstream_model.startswith("convnet"):
            upstream_net.fc1 = nn.Identity()
        else:
            upstream_net.fc = nn.Identity()
    return upstream_net