in src/inference.py [0:0]
def model_fn(model_dir):
    """Load ``model.pt`` from ``model_dir`` and prepare the model for inference.

    The model directory is logged, the checkpoint is loaded through
    ``load_model``, and the resulting model is moved to CUDA when it is
    available (CPU otherwise). The decode type is then set to ``"greedy"``
    and ``eval()`` is called on the model before it is returned.

    Args:
        model_dir: Directory that contains the ``model.pt`` file.

    Returns:
        The loaded model, placed on the selected device.
    """
    logger.info("In model_fn. Model directory is -")
    logger.info(model_dir)

    # Prefer the GPU when one is visible to torch; fall back to the CPU.
    if torch.cuda.is_available():
        target_device = torch.device("cuda")
    else:
        target_device = torch.device("cpu")

    checkpoint_path = os.path.join(model_dir, "model.pt")
    # load_model hands back a pair; only the first element (the model) is used.
    net, _ = load_model(checkpoint_path, extra_logging=True)

    net.to(target_device)
    net.set_decode_type("greedy")
    net.eval()
    return net