torchbenchmark/util/backends/jit.py (11 lines of code) (raw):
import torch
from typing import Tuple
def enable_jit(
    model: torch.nn.Module,
    example_inputs: Tuple[torch.Tensor, ...],
    test: str,
    optimize_for_inference: bool = True,
) -> torch.jit.ScriptModule:
    """Convert ``model`` to TorchScript, optionally optimizing it for inference.

    Args:
        model: The eager-mode module to script.
        example_inputs: One tuple of positional inputs for ``model``. It is
            forwarded to the scripting call as a single-element list so
            TorchScript can use it to infer argument types.
        test: Benchmark mode. Only ``"eval"`` enables the inference
            optimization pass.
        optimize_for_inference: When True and ``test == "eval"``, run
            ``torch.jit.optimize_for_inference`` on the scripted module.

    Returns:
        The scripted (and possibly inference-optimized) module.

    Raises:
        AssertionError: If the result is not a ``torch.jit.ScriptModule``.
    """
    # Older PyTorch builds expose example-input-driven scripting as the
    # private ``_script_pdt``; newer ones accept ``example_inputs`` directly
    # in ``torch.jit.script``.
    if hasattr(torch.jit, '_script_pdt'):
        model = torch.jit._script_pdt(model, example_inputs=[example_inputs, ])
    else:
        model = torch.jit.script(model, example_inputs=[example_inputs, ])
    # The inference pass is only applied when benchmarking evaluation.
    if test == "eval" and optimize_for_inference:
        model = torch.jit.optimize_for_inference(model)
    # Post-condition: both code paths must yield a ScriptModule.
    assert isinstance(model, torch.jit.ScriptModule)
    return model