def get_model_func()

in benchmarks/horovod-resnet/train_imagenet_resnet_hvd.py [0:0]


def get_model_func(model_name):
    """Return the inference function for the model named *model_name*.

    Only ResNet variants are supported. The name must be ``"resnet"``
    immediately followed by the layer count written as decimal digits,
    e.g. ``"resnet50"``.

    Args:
        model_name: Model identifier such as ``"resnet50"``.

    Returns:
        A callable ``fn(images, *args, **kwargs)`` that forwards to
        ``inference_resnet_v1(images, nlayer, *args, **kwargs)``, with the
        parsed layer count inserted as the second positional argument.

    Raises:
        ValueError: If *model_name* is not ``"resnet"`` followed by a
            non-empty run of decimal digits.
    """
    prefix = "resnet"
    if model_name.startswith(prefix):
        suffix = model_name[len(prefix):]
        # Validate explicitly instead of relying on int(): int() would accept
        # "resnet-50" (negative depth) or "resnet 50" (whitespace), and would
        # raise an unrelated message for "resnet" / "resnetabc".
        if suffix.isdecimal():
            nlayer = int(suffix)

            def model_func(images, *args, **kwargs):
                # inference_resnet_v1 is looked up in module scope at call
                # time; presumably defined/imported elsewhere in this file.
                return inference_resnet_v1(images, nlayer, *args, **kwargs)

            return model_func
    raise ValueError("Invalid model type: %s" % model_name)