in tensorflow_benchmark/tf_cnn_benchmarks/models/model_config.py [0:0]
def get_model_config(model, dataset):
  """Build the model configuration object for a model name and dataset.

  CIFAR-10 lookups are delegated wholesale to get_cifar10_model_config; any
  other dataset is resolved against the model table below.

  Args:
    model: model name string, e.g. 'vgg16' or 'resnet50_v2'.
    dataset: dataset object; only its `name` attribute is read here.

  Returns:
    A newly constructed model configuration instance.

  Raises:
    KeyError: if `model` is not a recognised name for a non-CIFAR-10 dataset.
  """
  if dataset.name == 'cifar10':
    return get_cifar10_model_config(model)

  # Each entry pairs the accepted names with a zero-argument factory so that
  # only the selected model is ever instantiated. The ResNet variants pass the
  # requested name through along with their per-stage block counts.
  factories = (
      (('vgg11',), lambda: vgg_model.Vgg11Model()),
      (('vgg16',), lambda: vgg_model.Vgg16Model()),
      (('vgg19',), lambda: vgg_model.Vgg19Model()),
      (('lenet',), lambda: lenet_model.Lenet5Model()),
      (('googlenet',), lambda: googlenet_model.GooglenetModel()),
      (('overfeat',), lambda: overfeat_model.OverfeatModel()),
      (('alexnet',), lambda: alexnet_model.AlexnetModel()),
      (('trivial',), lambda: trivial_model.TrivialModel()),
      (('inception3',), lambda: inception_model.Inceptionv3Model()),
      (('inception4',), lambda: inception_model.Inceptionv4Model()),
      (('resnet50', 'resnet50_v2'),
       lambda: resnet_model.ResnetModel(model, (3, 4, 6, 3))),
      (('resnet101', 'resnet101_v2'),
       lambda: resnet_model.ResnetModel(model, (3, 4, 23, 3))),
      (('resnet152', 'resnet152_v2'),
       lambda: resnet_model.ResnetModel(model, (3, 8, 36, 3))),
  )

  for accepted_names, build in factories:
    if model in accepted_names:
      return build()

  raise KeyError('Invalid model name \'%s\' for dataset \'%s\'' %
                 (model, dataset.name))