def get_model_config()

in tensorflow_benchmark/tf_cnn_benchmarks/models/model_config.py [0:0]


def get_model_config(model, dataset):
  """Map model name to model network configuration.

  Datasets named 'cifar10' are delegated to get_cifar10_model_config();
  every other dataset is served from the lookup table below.

  Args:
    model: name of the network, e.g. 'vgg16', 'inception3', 'resnet50_v2'.
    dataset: dataset object; only its `name` attribute is read here.

  Returns:
    A newly constructed model object for the requested name.

  Raises:
    KeyError: if `model` is not a recognized name for a non-CIFAR-10 dataset.
  """
  if dataset.name == 'cifar10':
    return get_cifar10_model_config(model)

  # Zero-argument factories: wrapping each constructor in a lambda keeps
  # construction (and the module attribute lookup) deferred, so only the
  # selected model is ever built.
  factories = {
      'vgg11': lambda: vgg_model.Vgg11Model(),
      'vgg16': lambda: vgg_model.Vgg16Model(),
      'vgg19': lambda: vgg_model.Vgg19Model(),
      'lenet': lambda: lenet_model.Lenet5Model(),
      'googlenet': lambda: googlenet_model.GooglenetModel(),
      'overfeat': lambda: overfeat_model.OverfeatModel(),
      'alexnet': lambda: alexnet_model.AlexnetModel(),
      'trivial': lambda: trivial_model.TrivialModel(),
      'inception3': lambda: inception_model.Inceptionv3Model(),
      'inception4': lambda: inception_model.Inceptionv4Model(),
  }

  # ResNet variants share one class and differ only in per-stage block
  # counts; the requested name (v1 or _v2) is forwarded unchanged.
  resnet_stage_blocks = (
      (('resnet50', 'resnet50_v2'), (3, 4, 6, 3)),
      (('resnet101', 'resnet101_v2'), (3, 4, 23, 3)),
      (('resnet152', 'resnet152_v2'), (3, 8, 36, 3)),
  )
  for aliases, blocks in resnet_stage_blocks:
    for alias in aliases:
      factories[alias] = (
          lambda stages=blocks: resnet_model.ResnetModel(model, stages))

  build = factories.get(model)
  if build is None:
    raise KeyError('Invalid model name \'%s\' for dataset \'%s\'' %
                   (model, dataset.name))
  return build()