in src/inference_pytorch_neo.py [0:0]
def model_fn(model_dir):
    """Load a Neo-compiled TorchScript model and run a warm-up inference.

    SageMaker inference entry point. Inside the Neo runtime context it loads
    the compiled artifact ``compiled.pt`` from ``model_dir``, moves it to GPU
    when available, then performs one warm-up forward pass using the sample
    input pickled alongside the model (``sample_input.pkl``).

    Args:
        model_dir: Directory containing ``compiled.pt`` and ``sample_input.pkl``.

    Returns:
        The loaded (and warmed-up) TorchScript model, on the selected device.
    """
    logger.info('model_fn')
    with torch.neo.config(model_dir=model_dir, neo_runtime=True):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # The compiled model is saved as "compiled.pt"
        model = torch.jit.load(os.path.join(model_dir, 'compiled.pt'))
        model = model.to(device)
        # Serving is inference-only: eval mode and no autograd bookkeeping
        # during the warm-up pass.
        model.eval()
        # It is recommended to run warm-up inference during model load.
        # NOTE: sample_input.pkl is a trusted artifact packaged with the
        # model; pickle must never be used on untrusted input.
        sample_input_path = os.path.join(model_dir, 'sample_input.pkl')
        with open(sample_input_path, 'rb') as input_file:
            model_input = pickle.load(input_file)
        with torch.no_grad():
            if torch.is_tensor(model_input):
                model_input = model_input.to(device)
                model(model_input)
            elif isinstance(model_input, tuple):
                # BUGFIX: the previous version filtered out non-tensor
                # elements, silently dropping positional arguments. Keep
                # every element; move only the tensors to the device.
                model_input = tuple(
                    inp.to(device) if torch.is_tensor(inp) else inp
                    for inp in model_input)
                model(*model_input)
            else:
                logger.error(
                    "Only supports a torch tensor or a tuple of torch tensors")
    return model