def get_cifar10_model_config()

in tensorflow_benchmark/tf_cnn_benchmarks/models/model_config.py [0:0]


def get_cifar10_model_config(model):
  """Build the Cifar10 network configuration object for a model name.

  Args:
    model: Name of the model, e.g. 'alexnet', 'resnet56', 'resnet56_v2' or
      'densenet100_k24'.

  Returns:
    An instance of AlexnetCifar10Model, ResnetCifar10Model or
    DensenetCifar10Model, depending on `model`.

  Raises:
    KeyError: If `model` is not one of the supported names.
  """
  # Each ResNet base name is also accepted with a '_v2' suffix; both variants
  # share the same per-stage layer-count tuple.
  resnet_layer_counts = (
      ('resnet20', (3, 3, 3)),
      ('resnet32', (5, 5, 5)),
      ('resnet44', (7, 7, 7)),
      ('resnet56', (9, 9, 9)),
      ('resnet110', (18, 18, 18)),
  )
  # (name, layer-count tuple, k); k matches the '_k<N>' suffix of the name.
  densenet_specs = (
      ('densenet40_k12', (12, 12, 12), 12),
      ('densenet100_k12', (32, 32, 32), 12),
      ('densenet100_k24', (32, 32, 32), 24),
  )

  if model == 'alexnet':
    return alexnet_model.AlexnetCifar10Model()

  for base_name, layer_counts in resnet_layer_counts:
    if model == base_name or model == base_name + '_v2':
      return resnet_model.ResnetCifar10Model(model, layer_counts)

  for name, layer_counts, k in densenet_specs:
    if model == name:
      return densenet_model.DensenetCifar10Model(model, layer_counts, k)

  # No table entry matched: report the unsupported name to the caller.
  raise KeyError('Invalid model name \'%s\' for Cifar10 DataSet.' % model)