torchbenchmark/util/e2emodel.py (11 lines of code) (raw):
from typing import Optional, List
class E2EBenchmarkModel():
    """
    A base class for adding models for all e2e models.

    Subclasses are expected to provide ``DEFAULT_TRAIN_BSIZE`` and/or
    ``DEFAULT_EVAL_BSIZE`` class attributes. A test mode whose default is
    missing or falsy is treated as not implemented by that model.
    """
    def __init__(self, test: str, batch_size: Optional[int] = None,
                 extra_args: Optional[List[str]] = None):
        """Validate the test mode, resolve the batch size and store extra args.

        Args:
            test: Either ``"train"`` or ``"eval"``.
            batch_size: Explicit batch size. A falsy value (``None`` or ``0``)
                falls back to ``DEFAULT_TRAIN_BSIZE`` when ``test == "train"``
                and to ``DEFAULT_EVAL_BSIZE`` otherwise.
            extra_args: Additional model arguments. When omitted, a fresh
                empty list is created for this instance.

        Raises:
            AssertionError: If ``test`` is neither ``"train"`` nor ``"eval"``.
            NotImplementedError: If no usable batch size exists for ``test``.
        """
        self.test = test
        # Raised explicitly rather than via ``assert`` so the check still runs
        # under ``python -O``; AssertionError is kept for existing callers.
        if self.test not in ("train", "eval"):
            raise AssertionError(f"Test must be 'train' or 'eval', but get {self.test}. Please submit a bug report.")
        self.batch_size = batch_size
        if not self.batch_size:
            default_attr = "DEFAULT_TRAIN_BSIZE" if test == "train" else "DEFAULT_EVAL_BSIZE"
            # If the model doesn't implement this test, its default may be None
            # or not defined at all; treat both cases as "not implemented"
            # instead of leaking an AttributeError.
            self.batch_size = getattr(self, default_attr, None)
            if not self.batch_size:
                raise NotImplementedError(f"Test {test} is not implemented.")
        # A None default (instead of a shared ``[]``) keeps instances from
        # sharing and mutating one list object.
        self.extra_args = [] if extra_args is None else extra_args