vision/models.py (20 lines of code) (raw):
import torch
import torchvision
class HeadAndEmbedding(torch.nn.Module):
    """Wrap a head module so the forward pass also returns the head's input.

    Calling the wrapper on ``x`` yields the tuple ``(x, head(x))``: the input
    passed through unchanged, followed by the wrapped head's output. Swapping
    a model's final stage for this wrapper exposes the features that stage
    consumes alongside its usual output.
    """

    def __init__(self, head):
        super().__init__()
        # Assigned as an attribute so nn.Module registers it as the submodule
        # "head"; its parameters therefore appear as "head.*" in state dicts.
        self.head = head

    def forward(self, x):
        embedding = x
        prediction = self.head(embedding)
        return embedding, prediction
def _alexnet_replace_fc(model):
    """Wrap ``model.classifier`` in ``HeadAndEmbedding``, in place.

    After this call, the model's forward returns ``(features, logits)``
    instead of only the classifier output. ``features`` is whatever tensor
    the model feeds into ``classifier``. Presumably this is the flattened
    pooled conv features for torchvision AlexNet; confirm against the
    torchvision version in use.

    Args:
        model: A module with a ``classifier`` attribute. It is mutated.

    Returns:
        The same ``model`` object, for call chaining.
    """
    wrapped_head = HeadAndEmbedding(model.classifier)
    model.classifier = wrapped_head
    return model
def resnet50_dino():
    """Load the ``dino_resnet50`` model from the DINO torch.hub repository.

    Uses the ``main`` ref of ``facebookresearch/dino``. ``torch.hub.load``
    fetches the repository (and any weights its entry point downloads)
    unless they are already in the local hub cache.

    Returns:
        The object returned by the hub entry point ``dino_resnet50``.
    """
    hub_repo = "facebookresearch/dino:main"
    return torch.hub.load(hub_repo, "dino_resnet50")
def vitb8_dino():
    """Load the ``dino_vitb8`` model from the DINO torch.hub repository.

    Uses the ``main`` ref of ``facebookresearch/dino``. ``torch.hub.load``
    fetches the repository (and any weights its entry point downloads)
    unless they are already in the local hub cache.

    Returns:
        The object returned by the hub entry point ``dino_vitb8``.
    """
    hub_repo = "facebookresearch/dino:main"
    return torch.hub.load(hub_repo, "dino_vitb8")
def alexnet():
    """Build an ImageNet-pretrained torchvision AlexNet that also returns embeddings.

    The model's ``classifier`` is wrapped by ``_alexnet_replace_fc``, so the
    forward pass returns ``(classifier_input, classifier_output)``.

    Returns:
        The modified ``torchvision.models.AlexNet`` instance.
    """
    weights_enum = getattr(torchvision.models, "AlexNet_Weights", None)
    if weights_enum is None:
        # torchvision < 0.13 has no multi-weight API; only the boolean flag exists.
        model = torchvision.models.alexnet(pretrained=True)
    else:
        # ``pretrained=True`` is deprecated since torchvision 0.13 and maps to
        # IMAGENET1K_V1. Request those same weights explicitly to avoid the
        # deprecation warning and the planned removal of the flag.
        model = torchvision.models.alexnet(weights=weights_enum.IMAGENET1K_V1)
    return _alexnet_replace_fc(model)