def get_model()

in svhn_based_exp/utils.py [0:0]


def _parse_hidden_layers(model):
    """Return the hidden-layer widths encoded after the first "-" in *model* (may be empty)."""
    return [int(num_h) for num_h in model.split("-")[1:]]


def get_model(model, num_classes, linear_base=None, device="cuda"):
    """Build a network from a model spec string and move it to *device*.

    Supported specs (hidden widths are "-"-separated integers):
      * ``"convnet"`` or ``"convnet-H1-H2..."``: ``ConvNet(num_classes=...)``.
        When hidden widths are given, its ``fc1`` layer is replaced by an
        ``MLP`` with the same input/output sizes and those hidden layers.
      * ``"resnet18"`` or ``"resnet18-H1-..."``: torchvision ResNet-18. When
        hidden widths are given, its ``fc`` head is replaced by an ``MLP``.
      * ``"linear"``: ``nn.Linear(linear_base, num_classes)``.
      * ``"MLP"`` or ``"MLP-H1-..."``: ``MLP(linear_base, num_classes, hidden)``.

    Args:
        model: The model spec string described above.
        num_classes: Number of output classes.
        linear_base: Input feature count for the ``"linear"`` and ``"MLP"`` specs.
        device: Device passed to ``.to()`` on the constructed module.

    Returns:
        The constructed module after ``.to(device)``.

    Raises:
        ValueError: If *model* matches no supported spec, if ``"linear"`` is
            requested without *linear_base*, or if a hidden width is not an int.
    """
    if model.startswith("convnet"):
        net = ConvNet(num_classes=num_classes)
        hidden_layers = _parse_hidden_layers(model)
        if hidden_layers:
            net.fc1 = MLP(net.fc1.in_features, net.fc1.out_features, hidden_layers)
    elif model.startswith("resnet18"):
        net = torchvision.models.resnet18(num_classes=num_classes)
        hidden_layers = _parse_hidden_layers(model)
        if hidden_layers:
            net.fc = MLP(net.fc.in_features, net.fc.out_features, hidden_layers)
    elif model == "linear":
        # nn.Linear cannot be built without an input size; fail with a clear message.
        if linear_base is None:
            raise ValueError('model "linear" requires linear_base (input feature count)')
        net = nn.Linear(in_features=linear_base, out_features=num_classes)
    elif model.startswith("MLP"):
        # NOTE(review): MLP likely also needs a non-None linear_base; confirm
        # against MLP's definition before adding the same check here.
        net = MLP(linear_base, num_classes, _parse_hidden_layers(model))
    else:
        # Previously fell through and crashed with UnboundLocalError on `net`.
        raise ValueError(
            f"Unknown model spec {model!r}; expected 'convnet[-H...]', "
            "'resnet18[-H...]', 'linear' or 'MLP[-H...]'"
        )
    return net.to(device)