in utils.py [0:0]
def load_pretrained_linear_weights(linear_classifier, model_name, patch_size):
    """Load the reference linear-classifier checkpoint, if one exists for this backbone.

    Known configurations are ViT-S and ViT-B with patch size 16 or 8, plus
    ResNet-50 (patch size is ignored for ResNet-50). For these, the checkpoint
    is fetched with ``torch.hub.load_state_dict_from_url`` from
    ``https://dl.fbaipublicfiles.com/dino/``. Its ``"state_dict"`` entry is
    loaded into ``linear_classifier`` with ``strict=True``, so any key mismatch
    raises. For any other configuration the classifier is left untouched
    (keeping whatever initial weights it already has) and a message is printed.

    Args:
        linear_classifier: module receiving the weights via ``load_state_dict``.
        model_name: backbone architecture name, e.g. ``"vit_small"``.
        patch_size: ViT patch size; not consulted for ``"resnet50"``.

    Returns:
        None.
    """
    # Each entry is (architecture, required patch size or None, checkpoint stem).
    # None means patch_size is not checked. The stem names both the hub
    # directory and the weights file.
    known_checkpoints = (
        ("vit_small", 16, "dino_deitsmall16"),
        ("vit_small", 8, "dino_deitsmall8"),
        ("vit_base", 16, "dino_vitbase16"),
        ("vit_base", 8, "dino_vitbase8"),
        ("resnet50", None, "dino_resnet50"),
    )

    stem = None
    for arch, required_patch, candidate in known_checkpoints:
        if model_name == arch and (required_patch is None or patch_size == required_patch):
            stem = candidate
            break

    if stem is None:
        # No published checkpoint for this configuration.
        print("We use random linear weights.")
        return

    print("We load the reference pretrained linear weights.")
    relative_path = f"{stem}_pretrain/{stem}_linearweights.pth"
    checkpoint = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + relative_path)
    linear_classifier.load_state_dict(checkpoint["state_dict"], strict=True)