in svhn_based_exp/utils.py [0:0]
def get_model(model, num_classes, linear_base=None, device="cuda"):
    """Build a network from a name string and move it to ``device``.

    Supported names (hidden widths ``H1, H2, ...`` are an optional
    hyphen-separated integer suffix):
        * ``"convnet"`` / ``"convnet-H1-H2..."``: ``ConvNet(num_classes)``;
          with a suffix, its ``fc1`` layer is replaced by an ``MLP`` with the
          same in/out features and the given hidden widths.
        * ``"resnet18"`` / ``"resnet18-H1-..."``:
          ``torchvision.models.resnet18(num_classes=...)``; with a suffix, its
          ``fc`` layer is replaced the same way.
        * ``"linear"``: ``nn.Linear(linear_base, num_classes)``.
        * ``"MLP"`` / ``"MLP-H1-..."``: ``MLP(linear_base, num_classes, [H1, ...])``.

    Note that ``convnet``/``resnet18``/``MLP`` are matched by prefix
    (``str.startswith``), so e.g. ``"resnet18x"`` also selects ResNet-18.

    Args:
        model: Model name string as described above.
        num_classes: Number of output units of the final layer.
        linear_base: Input feature count; required for ``"linear"`` and
            ``"MLP"``, ignored otherwise.
        device: Target passed to ``Module.to``.

    Returns:
        The constructed module, moved to ``device``.

    Raises:
        ValueError: If ``model`` is not a recognised name, if
            ``linear_base`` is missing for ``"linear"``/``"MLP"``, or if a
            hidden-width token is not an integer (raised by ``int``).
    """
    # Everything after the first "-" is a list of hidden-layer widths,
    # e.g. "convnet-256-128" -> ["256", "128"]. Split once; parse per branch
    # so an unknown name is reported as such rather than as an int() failure.
    width_tokens = model.split("-")[1:]

    if model.startswith("convnet"):
        net = ConvNet(num_classes=num_classes)
        if width_tokens:
            hidden_layers = [int(tok) for tok in width_tokens]
            # Swap the head for an MLP that keeps the original in/out sizes.
            net.fc1 = MLP(net.fc1.in_features, net.fc1.out_features, hidden_layers)
    elif model.startswith("resnet18"):
        net = torchvision.models.resnet18(num_classes=num_classes)
        if width_tokens:
            hidden_layers = [int(tok) for tok in width_tokens]
            net.fc = MLP(net.fc.in_features, net.fc.out_features, hidden_layers)
    elif model == "linear":
        if linear_base is None:
            raise ValueError("model 'linear' requires linear_base (input feature count)")
        net = nn.Linear(in_features=linear_base, out_features=num_classes)
    elif model.startswith("MLP"):
        if linear_base is None:
            raise ValueError("model 'MLP' requires linear_base (input feature count)")
        hidden_layers = [int(tok) for tok in width_tokens]
        net = MLP(linear_base, num_classes, hidden_layers)
    else:
        # Previously this fell through and crashed with UnboundLocalError
        # on `net`; fail with an explicit message instead.
        raise ValueError(
            f"Unknown model {model!r}; expected 'convnet[-H...]', "
            "'resnet18[-H...]', 'linear' or 'MLP[-H...]'"
        )
    return net.to(device)