torchbenchmark/util/extra_args.py
import argparse
from typing import List, Optional, Tuple
from torchbenchmark.util.backends.fx2trt import enable_fx2trt
from torchbenchmark.util.backends.jit import enable_jit
from torchbenchmark.util.backends.torch_trt import enable_torchtrt
def enable_opt_args(opt_args: argparse.Namespace) -> bool:
"Check if any of the optimizations is enabled."
opt_args_dict = vars(opt_args)
for k in opt_args_dict:
if opt_args_dict[k]:
return True
return False
def add_bool_arg(parser: argparse.ArgumentParser, name: str, default_value: bool=True):
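    """Add a paired '--<name>' / '--no-<name>' boolean flag to the parser."""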
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument('--' + name, dest=name, action='store_true')
group.add_argument('--no-' + name, dest=name, action='store_false')
parser.set_defaults(**{name: default_value})
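# Helpers to detect the model family (timm, torchvision, HuggingFace) and query family-specific attributes.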
def is_timm_model(model: 'torchbenchmark.util.model.BenchmarkModel') -> bool:
return hasattr(model, 'TIMM_MODEL') and model.TIMM_MODEL
def is_torchvision_model(model: 'torchbenchmark.util.model.BenchmarkModel') -> bool:
return hasattr(model, 'TORCHVISION_MODEL') and model.TORCHVISION_MODEL
def is_hf_model(model: 'torchbenchmark.util.model.BenchmarkModel') -> bool:
return hasattr(model, 'HF_MODEL') and model.HF_MODEL
def get_hf_maxlength(model: 'torchbenchmark.util.model.BenchmarkModel') -> Optional[int]:
return model.max_length if is_hf_model(model) else None
def check_fp16(model: 'torchbenchmark.util.model.BenchmarkModel', fp16: str) -> bool:
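    """Check whether the requested fp16 mode is supported for this model, test, and device."""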
if fp16 == "half":
return (is_torchvision_model(model) or is_hf_model(model) or is_timm_model(model)) and model.test == 'eval' and model.device == 'cuda'
if fp16 == "amp":
is_cuda_eval_test = (model.test == 'eval' and model.device == 'cuda')
support_amp = hasattr(model, "enable_amp")
return is_cuda_eval_test or support_amp
return True
# Torchvision, HuggingFace, and timm models use fp16 half mode by default for CUDA inference; others use fp32.
def get_fp16_default(model: 'torchbenchmark.util.model.BenchmarkModel') -> str:
if (is_torchvision_model(model) or is_hf_model(model) or is_timm_model(model)) and model.test == 'eval' and model.device == 'cuda':
return "half"
return "no"
def parse_decoration_args(model: 'torchbenchmark.util.model.BenchmarkModel', extra_args: List[str]) -> Tuple[argparse.Namespace, List[str]]:
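    """Parse the decoration arguments (currently only --fp16) from extra_args and return (dargs, remaining_args)."""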
parser = argparse.ArgumentParser()
parser.add_argument("--fp16", choices=["no", "half", "amp"], default=get_fp16_default(model), help="enable fp16 modes from: no fp16, half, or amp")
dargs, opt_args = parser.parse_known_args(extra_args)
if not check_fp16(model, dargs.fp16):
raise NotImplementedError(f"fp16 value: {dargs.fp16}, fp16 (amp mode) is only supported by CUDA inference tests, "
f"fp16 (half mode) is only supported by torchvision CUDA inference tests.")
return (dargs, opt_args)
def apply_decoration_args(model: 'torchbenchmark.util.model.BenchmarkModel', dargs: argparse.Namespace):
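    """Apply the parsed decoration arguments to the model, e.g. switching it to the requested fp16 mode."""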
    if dargs.fp16 and dargs.fp16 != "no":
if dargs.fp16 == "half":
assert hasattr(model, "enable_fp16_half"), "Model doesn't have method 'enable_fp16_half'. Please report a bug. "
model.enable_fp16_half()
elif dargs.fp16 == "amp":
            # the model handles amp itself if it defines an 'enable_amp' callback function
if hasattr(model, "enable_amp"):
model.enable_amp()
else:
import torch
model.add_context(lambda: torch.cuda.amp.autocast(dtype=torch.float16))
else:
assert False, f"Get invalid fp16 value: {dargs.fp16}. Please report a bug."
# Parse the optimization arguments; which optimizations are available depends on the model's device and test type.
def parse_opt_args(model: 'torchbenchmark.util.model.BenchmarkModel', opt_args: List[str]) -> argparse.Namespace:
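    """Parse the optimization arguments (fx2trt, fuser, torch_trt) and validate them against the model."""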
parser = argparse.ArgumentParser()
parser.add_argument("--fx2trt", action='store_true', help="enable fx2trt")
parser.add_argument("--fuser", type=str, default="", choices=["fuser0", "fuser1", "fuser2"], help="enable fuser")
parser.add_argument("--torch_trt", action='store_true', help="enable torch_tensorrt")
args = parser.parse_args(opt_args)
args.jit = model.jit
if model.device == "cpu" and args.fuser:
raise NotImplementedError("Fuser only works with GPU.")
if not (model.device == "cuda" and model.test == "eval"):
if args.fx2trt or args.torch_trt:
raise NotImplementedError("TensorRT only works for CUDA inference tests.")
    if is_torchvision_model(model):
        args.cudagraph = False
return args
def apply_opt_args(model: 'torchbenchmark.util.model.BenchmarkModel', args: argparse.Namespace):
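    """Apply the parsed optimization arguments to the model (fuser context, JIT, fx2trt, or torch_tensorrt)."""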
if args.fuser:
import torch
model.add_context(lambda: torch.jit.fuser(args.fuser))
if args.jit:
        # the model can handle JIT itself through the 'jit_callback' callback function
if hasattr(model, 'jit_callback'):
model.jit_callback()
        else:
            # if the model doesn't have customized JIT code, use the default JIT script path
            module, example_inputs = model.get_module()
            model.set_module(enable_jit(model=module, example_inputs=example_inputs, test=model.test))
if args.fx2trt:
if args.jit:
raise NotImplementedError("fx2trt with JIT is not available.")
        module, example_inputs = model.get_module()
        fp16 = model.dargs.fp16 != "no"
        model.set_module(enable_fx2trt(model.batch_size, fp16=fp16, model=module, example_inputs=example_inputs,
                                       is_hf_model=is_hf_model(model), hf_max_length=get_hf_maxlength(model)))
if args.torch_trt:
        module, example_inputs = model.get_module()
        precision = 'fp16' if model.dargs.fp16 != "no" else 'fp32'
        model.set_module(enable_torchtrt(precision=precision, model=module, example_inputs=example_inputs))
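# Typical call sequence (an illustrative sketch, not part of this module; `model` stands for
# a BenchmarkModel instance and `extra_args` for its unparsed extra CLI arguments):
#
#     dargs, opt_args = parse_decoration_args(model, extra_args)
#     apply_decoration_args(model, dargs)
#     opt = parse_opt_args(model, opt_args)
#     apply_opt_args(model, opt)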