torchbenchmark/util/framework/huggingface/args.py (23 lines of code) (raw):

import argparse
import torch
from torchbenchmark.util.model import BenchmarkModel
from typing import List, Dict, Tuple


def add_bool_arg(parser: argparse.ArgumentParser, name: str, default_value: bool = True):
    """Register a mutually exclusive ``--<name>`` / ``--no-<name>`` flag pair.

    Both flags write to the same destination ``name``: ``--<name>`` stores
    ``True`` and ``--no-<name>`` stores ``False``. When neither flag is given,
    the destination takes ``default_value``.
    """
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument('--' + name, dest=name, action='store_true')
    group.add_argument('--no-' + name, dest=name, action='store_false')
    # Give the shared destination one explicit default; otherwise the implicit
    # default of the first-registered action (store_true -> False) would apply.
    parser.set_defaults(**{name: default_value})


def parse_args(model: BenchmarkModel, extra_args: List[str]) -> argparse.Namespace:
    """Parse HuggingFace-specific benchmark options from ``extra_args``.

    Recognized options:
        --eval_fp16 / --no-eval_fp16: run inference in half precision
            (enabled by default).

    The returned namespace also carries ``device`` and ``jit`` copied from
    ``model``. ``eval_fp16`` is forced to ``False`` when the device is CPU,
    regardless of the flags passed.

    Raises:
        SystemExit: raised by argparse when ``extra_args`` contains unknown
            or conflicting options.
    """
    parser = argparse.ArgumentParser()
    # by default, enable half precision for inference
    add_bool_arg(parser, "eval_fp16", default_value=True)
    args = parser.parse_args(extra_args)
    args.device = model.device
    args.jit = model.jit
    # Disable fp16 when the device is CPU. str() lets this match both the
    # plain string "cpu" and a torch.device("cpu").
    if str(args.device) == "cpu":
        args.eval_fp16 = False
    return args


def apply_args(model: BenchmarkModel, args: argparse.Namespace):
    """Apply parsed options to ``model`` in place.

    When ``args.eval_fp16`` is set, ``model.model`` and
    ``model.example_inputs`` are replaced by their fp16-converted versions.
    """
    if args.eval_fp16:
        model.model, model.example_inputs = enable_eval_fp16(model.model, model.example_inputs)


def enable_eval_fp16(
    model: torch.nn.Module, example_input: Dict[str, torch.Tensor]
) -> Tuple[torch.nn.Module, Dict[str, torch.Tensor]]:
    """Convert ``model`` and its example inputs for fp16 inference.

    ``model.half()`` converts the module's floating-point parameters and
    buffers in place and returns the same module.

    Only floating-point tensors in ``example_input`` are cast to ``torch.half``.
    Integer tensors, such as token ``input_ids``, are left untouched. They are
    embedding indices, which must stay integral: ``nn.Embedding`` rejects
    floating-point indices, and float16 cannot represent ids above 2048
    exactly. Every entry of ``example_input`` is preserved, not just
    ``input_ids``.
    """
    half_inputs = {
        key: (
            value.half()
            if isinstance(value, torch.Tensor) and value.is_floating_point()
            else value
        )
        for key, value in example_input.items()
    }
    return model.half(), half_inputs