torchbenchmark/util/extra_args.py (94 lines of code) (raw):

import argparse
from typing import List, Optional, Tuple

from torchbenchmark.util.backends.fx2trt import enable_fx2trt
from torchbenchmark.util.backends.jit import enable_jit
from torchbenchmark.util.backends.torch_trt import enable_torchtrt


def enable_opt_args(opt_args: argparse.Namespace) -> bool:
    "Check if any of the optimizations is enabled."
    # Any truthy value in the namespace counts as an enabled optimization.
    return any(vars(opt_args).values())


def add_bool_arg(parser: argparse.ArgumentParser, name: str, default_value: bool = True):
    """Register a mutually exclusive ``--<name>`` / ``--no-<name>`` flag pair.

    Both flags write to the destination ``name``; when neither is passed the
    value is ``default_value``.
    """
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument('--' + name, dest=name, action='store_true')
    group.add_argument('--no-' + name, dest=name, action='store_false')
    parser.set_defaults(**{name: default_value})


def is_timm_model(model: 'torchbenchmark.util.model.BenchmarkModel') -> bool:
    """Return True if ``model`` has a truthy ``TIMM_MODEL`` attribute."""
    return bool(getattr(model, 'TIMM_MODEL', False))


def is_torchvision_model(model: 'torchbenchmark.util.model.BenchmarkModel') -> bool:
    """Return True if ``model`` has a truthy ``TORCHVISION_MODEL`` attribute."""
    return bool(getattr(model, 'TORCHVISION_MODEL', False))


def is_hf_model(model: 'torchbenchmark.util.model.BenchmarkModel') -> bool:
    """Return True if ``model`` has a truthy ``HF_MODEL`` attribute."""
    return bool(getattr(model, 'HF_MODEL', False))


def get_hf_maxlength(model: 'torchbenchmark.util.model.BenchmarkModel') -> Optional[int]:
    """Return ``model.max_length`` for HF models, otherwise None."""
    return model.max_length if is_hf_model(model) else None


def _is_cuda_eval(model: 'torchbenchmark.util.model.BenchmarkModel') -> bool:
    """Return True if ``model`` is configured for an inference test on CUDA."""
    return model.test == 'eval' and model.device == 'cuda'


def _supports_fp16_half(model: 'torchbenchmark.util.model.BenchmarkModel') -> bool:
    """Return True if fp16 half mode is allowed for ``model``.

    Half mode is restricted to torchvision, HF and timm models running CUDA
    inference tests. The condition is shared by check_fp16 and
    get_fp16_default so the two cannot drift apart.
    """
    is_supported_family = is_torchvision_model(model) or is_hf_model(model) or is_timm_model(model)
    return is_supported_family and _is_cuda_eval(model)


def check_fp16(model: 'torchbenchmark.util.model.BenchmarkModel', fp16: str) -> bool:
    """Return whether the requested ``--fp16`` mode is valid for ``model``.

    Args:
        model: the benchmark model being configured.
        fp16: one of "no", "half", "amp".

    Returns:
        "half": True only for torchvision/HF/timm CUDA inference tests.
        "amp": True for any CUDA inference test, or any model that defines
        its own ``enable_amp`` hook.
        Anything else (e.g. "no"): always True.
    """
    if fp16 == "half":
        return _supports_fp16_half(model)
    if fp16 == "amp":
        return _is_cuda_eval(model) or hasattr(model, "enable_amp")
    return True


# torchvision, HF and timm models use fp16 half mode by default for CUDA
# inference tests; everything else uses fp32 ("no").
def get_fp16_default(model: 'torchbenchmark.util.model.BenchmarkModel') -> str:
    """Return the default ``--fp16`` value for ``model``: "half" or "no"."""
    return "half" if _supports_fp16_half(model) else "no"


def parse_decoration_args(model: 'torchbenchmark.util.model.BenchmarkModel',
                          extra_args: List[str]) -> Tuple[argparse.Namespace, List[str]]:
    """Parse the decoration (fp16) flags out of ``extra_args``.

    Returns:
        ``(dargs, opt_args)``: the parsed decoration namespace, plus the
        remaining, unrecognized arguments, which are left for parse_opt_args.

    Raises:
        NotImplementedError: if the selected fp16 mode is not supported by
        ``model`` (see check_fp16).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--fp16", choices=["no", "half", "amp"], default=get_fp16_default(model),
                        help="enable fp16 modes from: no fp16, half, or amp")
    dargs, opt_args = parser.parse_known_args(extra_args)
    if not check_fp16(model, dargs.fp16):
        # Keep this message in sync with the rules in check_fp16.
        raise NotImplementedError(f"fp16 value: {dargs.fp16}, fp16 (amp mode) is only supported by CUDA inference tests "
                                  f"or models that implement 'enable_amp', "
                                  f"fp16 (half mode) is only supported by torchvision, huggingface, "
                                  f"and timm CUDA inference tests.")
    return (dargs, opt_args)


def apply_decoration_args(model: 'torchbenchmark.util.model.BenchmarkModel', dargs: argparse.Namespace):
    """Apply the fp16 mode selected in ``dargs`` to ``model``.

    "half" calls ``model.enable_fp16_half()``. "amp" calls the model's own
    ``enable_amp`` hook if present; otherwise it adds a CUDA autocast(float16)
    context. "no" (or an empty value) leaves the model untouched.
    """
    if not dargs.fp16 or dargs.fp16 == "no":
        return
    if dargs.fp16 == "half":
        assert hasattr(model, "enable_fp16_half"), "Model doesn't have method 'enable_fp16_half'. Please report a bug. "
        model.enable_fp16_half()
    elif dargs.fp16 == "amp":
        # model can handle amp if it has 'enable_amp' callback function
        if hasattr(model, "enable_amp"):
            model.enable_amp()
        else:
            import torch
            model.add_context(lambda: torch.cuda.amp.autocast(dtype=torch.float16))
    else:
        # Unreachable when dargs comes from parse_decoration_args (argparse
        # restricts the choices). Raise explicitly rather than `assert False`
        # so the check survives `python -O`.
        raise AssertionError(f"Get invalid fp16 value: {dargs.fp16}. Please report a bug.")
# Dispatch arguments based on model type
def parse_opt_args(model: 'torchbenchmark.util.model.BenchmarkModel', opt_args: List[str]) -> argparse.Namespace:
    """Parse optimization flags (--fx2trt, --fuser, --torch_trt) for ``model``.

    The returned namespace also carries ``jit`` (copied from ``model.jit``).
    For torchvision models it also carries ``cudagraph=False``.

    Raises:
        NotImplementedError: if a fuser is requested on CPU, or if TensorRT
            (fx2trt / torch_trt) is requested outside a CUDA inference test.
        SystemExit: on unrecognized flags (argparse ``parse_args`` behavior).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--fx2trt", action='store_true', help="enable fx2trt")
    parser.add_argument("--fuser", type=str, default="", choices=["fuser0", "fuser1", "fuser2"], help="enable fuser")
    parser.add_argument("--torch_trt", action='store_true', help="enable torch_tensorrt")
    args = parser.parse_args(opt_args)
    args.jit = model.jit
    if model.device == "cpu" and args.fuser:
        raise NotImplementedError("Fuser only works with GPU.")
    if not (model.device == "cuda" and model.test == "eval"):
        if args.fx2trt or args.torch_trt:
            raise NotImplementedError("TensorRT only works for CUDA inference tests.")
    if getattr(model, 'TORCHVISION_MODEL', False):
        args.cudagraph = False
    return args


def apply_opt_args(model: 'torchbenchmark.util.model.BenchmarkModel', args: argparse.Namespace):
    """Apply the optimizations selected in ``args`` to ``model``.

    They are applied in this order: fuser context, JIT, fx2trt, torch_trt.

    Raises:
        NotImplementedError: if fx2trt and JIT are both requested. This is
            checked before any change is made to ``model``, so a rejected
            configuration does not leave the model half-modified.
    """
    # Validate incompatible combinations up front, before adding contexts or
    # JIT-compiling the module.
    if args.fx2trt and args.jit:
        raise NotImplementedError("fx2trt with JIT is not available.")
    if args.fuser:
        import torch
        fuser = args.fuser  # bind the value now rather than late-binding `args`
        model.add_context(lambda: torch.jit.fuser(fuser))
    if args.jit:
        # model can handle jit code themselves through the 'jit_callback' callback function
        if hasattr(model, 'jit_callback'):
            model.jit_callback()
        else:
            # if model doesn't have customized jit code, use the default jit script code
            module, example_inputs = model.get_module()
            model.set_module(enable_jit(model=module, example_inputs=example_inputs, test=model.test))
    if args.fx2trt:
        module, example_inputs = model.get_module()
        fp16 = model.dargs.fp16 != "no"
        model.set_module(enable_fx2trt(model.batch_size, fp16=fp16, model=module, example_inputs=example_inputs,
                                       is_hf_model=is_hf_model(model), hf_max_length=get_hf_maxlength(model)))
    if args.torch_trt:
        module, example_inputs = model.get_module()
        precision = 'fp16' if model.dargs.fp16 != "no" else 'fp32'
        model.set_module(enable_torchtrt(precision=precision, model=module, example_inputs=example_inputs))