torchbenchmark/util/e2emodel.py (11 lines of code) (raw):

from typing import Optional, List class E2EBenchmarkModel(): """ A base class for adding models for all e2e models. """ def __init__(self, test: str, batch_size: Optional[int]=None, extra_args: List[str]=[]): self.test = test assert self.test == "train" or self.test == "eval", f"Test must be 'train' or 'eval', but get {self.test}. Please submit a bug report." self.batch_size = batch_size if not self.batch_size: self.batch_size = self.DEFAULT_TRAIN_BSIZE if test == "train" else self.DEFAULT_EVAL_BSIZE # If the model doesn't implement test or eval test # its DEFAULT_TRAIN_BSIZE or DEFAULT_EVAL_BSIZE will still be None if not self.batch_size: raise NotImplementedError(f"Test {test} is not implemented.") self.extra_args = extra_args