# torchbenchmark/util/backends/torch_trt.py
import torch
from typing import Tuple
def enable_torchtrt(
    precision: str,
    model: torch.nn.Module,
    example_inputs: Tuple[torch.Tensor, ...],
) -> torch.nn.Module:
    """Compile ``model`` with Torch-TensorRT at the requested precision.

    Args:
        precision: Either ``"fp16"`` or ``"fp32"``.
        model: The module to compile.
        example_inputs: Example input tensors; one TensorRT ``Input`` spec is
            derived from the shape of each.

    Returns:
        The Torch-TensorRT compiled module.

    Raises:
        NotImplementedError: If ``precision`` is not ``"fp16"`` or ``"fp32"``.
    """
    # Imported lazily so this module can be loaded even when torch_tensorrt
    # is not installed (it is only needed when this backend is enabled).
    import torch_tensorrt

    if precision == "fp16":
        torchtrt_dtype = torch_tensorrt.dtype.half
        torch_dtype = torch.half
    elif precision == "fp32":
        torchtrt_dtype = torch_tensorrt.dtype.float
        torch_dtype = torch.float32
    else:
        raise NotImplementedError("torch_tensorrt only supports fp32 or fp16 precision")
    # Build one Input spec per example input. The previous code used only
    # example_inputs[0], which silently dropped specs for multi-input models.
    trt_inputs = [
        torch_tensorrt.Input(shape=inp.shape, dtype=torch_dtype)
        for inp in example_inputs
    ]
    # enabled_precisions expects a set of allowed dtypes, not a bare dtype.
    return torch_tensorrt.compile(
        model, inputs=trt_inputs, enabled_precisions={torchtrt_dtype}
    )