in svhn_based_exp/utils.py [0:0]
def get_upstream_model(upstream_setting, linear_base=None, save_dir="./checkpoint/", device="cuda"):
    """Build an upstream network and load its saved SVHN checkpoint weights.

    Args:
        upstream_setting: 6-element sequence unpacked as
            ``(upstream_task, upstream_para_to_vary_model, upstream_model,
            upstream_seed, repre, model_specify)``.
        linear_base: Forwarded unchanged to ``get_model``.
        save_dir: Root directory; the checkpoint is read from
            ``{save_dir}/checkpoint/svhn_{upstream_model}_task_{upstream_task}
            _upstream_setting_None_para_to_vary_model_{upstream_para_to_vary_model}
            _seed_{upstream_seed}.pth``.
        device: Passed to ``get_model`` and used as ``map_location`` for
            ``torch.load`` so a checkpoint saved on one device (e.g. CUDA)
            can be loaded on another (e.g. CPU).

    Returns:
        The network built by ``get_model`` with weights from
        ``checkpoint['net']`` when ``model_specify == "last"`` or from
        ``checkpoint['best_state']['net']`` when ``model_specify == "best"``.
        If ``repre == "feat"``, the final layer (``fc1`` for models whose name
        starts with ``"convnet"``, ``fc`` otherwise) is replaced with
        ``nn.Identity`` so the network returns the features feeding that layer.

    Raises:
        ValueError: If ``upstream_setting`` does not have exactly six items, or
            if ``model_specify`` is neither ``"last"`` nor ``"best"``.
    """
    (upstream_task, upstream_para_to_vary_model, upstream_model,
     upstream_seed, repre, model_specify) = tuple(upstream_setting)

    # Validate before building the model or touching the disk so a typo fails
    # fast. The old code printed and called exit(0), which signalled *success*
    # to the calling shell/scheduler.
    if model_specify not in ("last", "best"):
        raise ValueError(
            f"invalid upstream model specification: {model_specify!r} "
            "(expected 'last' or 'best')"
        )

    upstream_num_classes = get_num_classes(upstream_task)
    upstream_net = get_model(upstream_model, upstream_num_classes, device=device, linear_base=linear_base)

    checkpoint_path = (
        f'{save_dir}/checkpoint/svhn_{upstream_model}_task_{upstream_task}'
        f'_upstream_setting_None_para_to_vary_model_{upstream_para_to_vary_model}'
        f'_seed_{upstream_seed}.pth'
    )
    # map_location lets a GPU-saved checkpoint load on the requested device.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    if model_specify == "last":
        state_dict = checkpoint['net']
    else:  # "best" (validated above)
        state_dict = checkpoint['best_state']['net']
    # strict=False: missing/unexpected keys are silently ignored, as before.
    upstream_net.load_state_dict(state_dict, strict=False)

    if repre == "feat":
        # Drop the classification head so the network outputs features.
        if upstream_model.startswith("convnet"):
            upstream_net.fc1 = nn.Identity()
        else:
            upstream_net.fc = nn.Identity()
    return upstream_net