torchbenchmark/util/backends/fx2trt.py (22 lines of code) (raw):
import torch
from typing import Tuple, Optional
def enable_fx2trt(max_batch_size: int, fp16: bool, model: torch.nn.Module,
                  example_inputs: Tuple[torch.Tensor, ...],
                  is_hf_model: bool = False,
                  hf_max_length: Optional[int] = None) -> torch.nn.Module:
    """Lower ``model`` with fx2trt's ``lower_to_trt`` and return the result.

    Args:
        max_batch_size: Forwarded to ``lower_to_trt`` as ``max_batch_size``.
            For HuggingFace models it is also the batch size used when
            symbolically tracing the model.
        fp16: Forwarded to ``lower_to_trt`` as ``fp16_mode``.
        model: The module to lower.
        example_inputs: Sample inputs forwarded to ``lower_to_trt`` as ``input``.
        is_hf_model: If True, ``model`` is first traced with
            ``transformers.utils.fx.symbolic_trace`` and the traced module is
            lowered with explicit batch dimension and a larger workspace.
        hf_max_length: Passed as ``sequence_length`` to the HuggingFace tracer;
            unused when ``is_hf_model`` is False.
            NOTE(review): may be None by default — confirm the tracer accepts
            ``sequence_length=None`` for the transformers version in use.

    Returns:
        The module returned by ``lower_to_trt``.
    """
    # Deferred import: fx2trt_oss is only required when this backend is used.
    from fx2trt_oss.fx.lower import lower_to_trt

    if is_hf_model:
        # HuggingFace models are traced with transformers' own FX tracer
        # before lowering, rather than being handed to lower_to_trt directly.
        from transformers.utils.fx import symbolic_trace as hf_symbolic_trace
        traced_model = hf_symbolic_trace(
            model,
            batch_size=max_batch_size,
            sequence_length=hf_max_length,
        )
        # 20 << 30 == 20 * 2**30, i.e. 20 GiB when interpreted as bytes.
        hf_max_workspace_size = 20 << 30
        # Only the HF path sets explicit_batch_dimension and a custom workspace
        # size; the generic path below relies on lower_to_trt's defaults.
        return lower_to_trt(
            module=traced_model,
            input=example_inputs,
            max_batch_size=max_batch_size,
            fp16_mode=fp16,
            explicit_batch_dimension=True,
            max_workspace_size=hf_max_workspace_size,
        )

    return lower_to_trt(
        module=model,
        input=example_inputs,
        max_batch_size=max_batch_size,
        fp16_mode=fp16,
    )