def get_model()

in vision/run_weak_strong.py [0:0]


def get_model(name):
    """Construct the backbone identified by ``name`` and prepare it for inference.

    The chosen constructor is called with no arguments. The resulting module is
    moved to the GPU with ``.cuda()`` and switched to eval mode. It is then
    wrapped in ``nn.DataParallel``.

    Args:
        name: Model identifier. It must be exactly one of ``"alexnet"``,
            ``"resnet50_dino"`` or ``"vitb8_dino"``.

    Returns:
        The constructed module wrapped in ``nn.DataParallel``.

    Raises:
        ValueError: If ``name`` matches none of the supported identifiers.
    """

    def _prepare(net):
        # Shared post-construction steps. The module is moved to CUDA before
        # wrapping, and the wrapper is what gets returned.
        net.cuda()
        net.eval()
        return nn.DataParallel(net)

    # Each constructor is looked up only when its branch is taken, so unknown
    # names fail with ValueError before any model is built.
    if name == "alexnet":
        return _prepare(alexnet())
    if name == "resnet50_dino":
        return _prepare(resnet50_dino())
    if name == "vitb8_dino":
        return _prepare(vitb8_dino())
    raise ValueError(f"Unknown model {name}")