in vision/run_weak_strong.py [0:0]
def get_model(name):
    """Construct a supported backbone by name and prepare it for inference.

    Args:
        name: One of ``"alexnet"``, ``"resnet50_dino"`` or ``"vitb8_dino"``.

    Returns:
        The constructed model, moved to CUDA, switched to eval mode and
        wrapped in ``nn.DataParallel``.

    Raises:
        ValueError: If ``name`` does not match any supported identifier.
    """
    # Each constructor is wrapped in a lambda so its global name is only
    # resolved (and the model only built) when that entry is selected.
    candidates = (
        ("alexnet", lambda: alexnet()),
        ("resnet50_dino", lambda: resnet50_dino()),
        ("vitb8_dino", lambda: vitb8_dino()),
    )
    for key, build in candidates:
        if name == key:
            break
    else:
        raise ValueError(f"Unknown model {name}")

    net = build()
    # DataParallel expects the wrapped module's parameters on the first CUDA
    # device, so move the model there before wrapping it.
    net.cuda()
    net.eval()
    return nn.DataParallel(net)