in benchmarks/horovod-resnet/train_imagenet_resnet_hvd.py [0:0]
def get_model_func(model_name):
    """Return a model-building function for the named architecture.

    Only ResNet variants are recognised. The name must be ``"resnet"``
    followed by the layer count written as plain decimal digits
    (e.g. ``"resnet50"``).

    Args:
        model_name: Architecture identifier such as ``"resnet50"``.

    Returns:
        A callable ``f(images, *args, **kwargs)`` that forwards to
        ``inference_resnet_v1(images, nlayer, *args, **kwargs)``. The parsed
        layer count is inserted as the second positional argument.

    Raises:
        ValueError: If ``model_name`` is not ``"resnet"`` followed by a
            decimal layer count.
    """
    prefix = "resnet"
    if model_name.startswith(prefix):
        depth = model_name[len(prefix):]
        # Validate before int(). Otherwise "resnet" and "resnetfoo" leak
        # int()'s parse error. Also, int() accepts signs, spaces and
        # underscores, so "resnet-50" would silently give a negative depth.
        if depth.isdecimal():
            nlayer = int(depth)

            def model_func(images, *args, **kwargs):
                # inference_resnet_v1 is resolved at call time, as before.
                return inference_resnet_v1(images, nlayer, *args, **kwargs)

            return model_func
    raise ValueError("Invalid model type: %s" % model_name)