# model_fn() — defined in src/inference.py


def model_fn(model_dir):
    """Load the serialized model for inference.

    NOTE(review): the name/signature matches the SageMaker PyTorch
    ``model_fn`` entry-point convention — confirm against the deployment
    config.

    Args:
        model_dir: Directory containing the serialized model file
            ``model.pt``.

    Returns:
        The loaded model, moved to GPU when available, set to greedy
        decoding, and switched to eval mode.
    """
    # One lazy %-formatted call (the two previous eager calls split the
    # message and its value across log records).
    logger.info("In model_fn. Model directory is - %s", model_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # load_model is a project helper returning (model, model_args);
    # model_args is unused here, hence the underscore name.
    model, _model_args = load_model(os.path.join(model_dir, "model.pt"),
                                    extra_logging=True)
    model.to(device)
    model.set_decode_type("greedy")  # deterministic decoding for serving
    model.eval()  # disable dropout / batch-norm updates
    return model