in tensorflow_benchmark/tf_cnn_benchmarks/models/model_config.py [0:0]
def get_cifar10_model_config(model):
  """Build the network configuration object for a CIFAR-10 model name.

  Args:
    model: model name, e.g. 'alexnet', 'resnet56', 'resnet56_v2' or
      'densenet100_k24'.

  Returns:
    The configuration object created by the matching model class
    (AlexnetCifar10Model, ResnetCifar10Model or DensenetCifar10Model).

  Raises:
    KeyError: if `model` is not one of the CIFAR-10 names handled here.
  """
  # ResNet variants all use ResnetCifar10Model and differ only in the
  # per-stage layer tuple. A name and its '_v2' twin share the same tuple;
  # the full requested name (suffix included) is passed to the constructor.
  resnet_layer_counts = (
      ('resnet20', (3, 3, 3)),
      ('resnet32', (5, 5, 5)),
      ('resnet44', (7, 7, 7)),
      ('resnet56', (9, 9, 9)),
      ('resnet110', (18, 18, 18)),
  )
  # DenseNet variants: (name, per-block layer tuple, trailing int argument).
  # The trailing int matches the '_kNN' suffix of the name (presumably the
  # growth rate -- confirm against DensenetCifar10Model).
  densenet_specs = (
      ('densenet40_k12', (12, 12, 12), 12),
      ('densenet100_k12', (32, 32, 32), 12),
      ('densenet100_k24', (32, 32, 32), 24),
  )

  if model == 'alexnet':
    return alexnet_model.AlexnetCifar10Model()

  for base_name, layer_counts in resnet_layer_counts:
    if model in (base_name, base_name + '_v2'):
      return resnet_model.ResnetCifar10Model(model, layer_counts)

  for spec_name, block_layers, k in densenet_specs:
    if model == spec_name:
      return densenet_model.DensenetCifar10Model(model, block_layers, k)

  raise KeyError("Invalid model name '%s' for Cifar10 DataSet." % model)