# model_fn()
# From: src/inference_pytorch_neo.py

def model_fn(model_dir):
    """Load the Neo-compiled TorchScript model for SageMaker inference.

    Args:
        model_dir: Directory containing the Neo artifacts: the compiled
            model ``compiled.pt`` and a pickled warm-up input
            ``sample_input.pkl``.

    Returns:
        The loaded model, moved to GPU when available and warmed up with
        one forward pass on the sample input.
    """
    logger.info('model_fn')
    with torch.neo.config(model_dir=model_dir, neo_runtime=True):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # The compiled model is saved as "compiled.pt"
        model = torch.jit.load(os.path.join(model_dir, 'compiled.pt'))
        model = model.to(device)

        # It is recommended to run warm-up inference during model load.
        # NOTE(review): sample_input.pkl is trusted Neo-generated data;
        # pickle.load on untrusted input would be unsafe.
        sample_input_path = os.path.join(model_dir, 'sample_input.pkl')
        with open(sample_input_path, 'rb') as input_file:
            model_input = pickle.load(input_file)

        # no_grad: the warm-up result is discarded, so skip building an
        # autograd graph for it.
        with torch.no_grad():
            if torch.is_tensor(model_input):
                model(model_input.to(device))
            elif isinstance(model_input, tuple):
                # Move tensors to the device but keep non-tensor elements in
                # place, so the call arity matches the saved sample exactly
                # (the previous filter silently dropped non-tensor args).
                args = tuple(
                    inp.to(device) if torch.is_tensor(inp) else inp
                    for inp in model_input
                )
                model(*args)
            else:
                # Use the module logger (not print) so the message reaches
                # the inference server logs, consistent with logger.info above.
                logger.warning(
                    "Only supports a torch tensor or a tuple of torch tensors")

        return model