torchbenchmark/util/backends/torchdynamo.py (15 lines of code) (raw):
"""
Support for TorchDynamo (https://github.com/facebookresearch/torchdynamo) backends.
"""
import argparse
import functools
from typing import List
def parse_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', dyamo_args: List[str]) -> argparse.Namespace:
    """Parse the TorchDynamo-specific command-line options.

    Args:
        model: The benchmark model being configured; it is not consulted here.
        dyamo_args: Command-line tokens to interpret as TorchDynamo options.

    Returns:
        A namespace whose ``torchdynamo`` attribute is the selected backend
        name, or ``None`` when ``--torchdynamo`` was not supplied.

    Note:
        ``argparse`` raises ``SystemExit`` on unrecognized tokens or on a
        backend name outside ``torchdynamo.list_backends()``.
    """
    # Imported lazily: torchdynamo is only needed once this function runs.
    import torchdynamo

    dynamo_parser = argparse.ArgumentParser()
    backend_choices = torchdynamo.list_backends()
    dynamo_parser.add_argument(
        "--torchdynamo",
        choices=backend_choices,
        help="Specify torchdynamo backends",
    )
    return dynamo_parser.parse_args(dyamo_args)
def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', args: argparse.Namespace):
    """Register the TorchDynamo backend chosen in ``args`` on ``model``.

    A zero-argument factory equivalent to
    ``torchdynamo.optimize(args.torchdynamo)`` is handed to
    ``model.add_context``. Afterwards ``torchdynamo.reset()`` is called.

    Args:
        model: Benchmark model exposing ``add_context``. Presumably the
            registered factory is entered around the model's runs; confirm
            against ``BenchmarkModel.add_context``.
        args: Namespace produced by ``parse_torchdynamo_args``. Only
            ``args.torchdynamo`` is read.
    """
    import torchdynamo

    optimize_factory = functools.partial(torchdynamo.optimize, args.torchdynamo)
    model.add_context(optimize_factory)
    # NOTE(review): presumably clears state left by earlier torchdynamo use so
    # the newly registered backend starts fresh; confirm intent.
    torchdynamo.reset()