optimum_benchmark/benchmark/base.py (52 lines of code) (raw):
from dataclasses import dataclass
from logging import getLogger
from typing import TYPE_CHECKING, Type
from hydra.utils import get_class
from ..backends.config import BackendConfig
from ..hub_utils import PushToHubMixin, classproperty
from ..launchers import LauncherConfig
from ..scenarios import ScenarioConfig
from .config import BenchmarkConfig
from .report import BenchmarkReport
if TYPE_CHECKING:
from ..backends.base import Backend
from ..launchers.base import Launcher
from ..scenarios.base import Scenario
# Module-level logger named "benchmark" (not referenced in the code shown in this section).
LOGGER = getLogger("benchmark")
@dataclass
class Benchmark(PushToHubMixin):
    """A benchmark configuration paired with the report it produced.

    Both fields may be passed either as already-built objects or as plain
    dicts. Dicts are converted in ``__post_init__`` through the matching
    ``from_dict`` constructor (``BenchmarkConfig.from_dict`` /
    ``BenchmarkReport.from_dict``).
    """

    config: BenchmarkConfig
    report: BenchmarkReport

    def __post_init__(self):
        # Normalize both fields so the rest of the code can rely on real objects.
        self.config = self._coerce_field(self.config, BenchmarkConfig, "config")
        self.report = self._coerce_field(self.report, BenchmarkReport, "report")

    @staticmethod
    def _coerce_field(value, expected_cls, field_name):
        """Return ``value`` as an instance of ``expected_cls``.

        Args:
            value: either a ``dict`` (converted with ``expected_cls.from_dict``)
                or an instance of ``expected_cls`` (returned unchanged).
            expected_cls: the class the field is required to end up as.
            field_name: name of the field, used only in the error message.

        Raises:
            ValueError: if ``value`` is neither a dict nor an ``expected_cls``
                instance.
        """
        if isinstance(value, dict):
            return expected_cls.from_dict(value)
        if not isinstance(value, expected_cls):
            # ValueError (rather than TypeError) is kept for backward compatibility
            # with callers that may already catch it; the message prefix is also kept.
            raise ValueError(
                f"{field_name} must be either a dict or a {expected_cls.__name__} instance, "
                f"got {type(value).__name__}"
            )
        return value

    @staticmethod
    def launch(config: BenchmarkConfig):
        """Run a benchmark through the launcher described by ``config.launcher``.

        The launcher class is resolved from ``config.launcher._target_`` with
        ``hydra.utils.get_class`` and instantiated with the launcher config.
        It is then given ``Benchmark.run`` as its worker, with ``config`` as the
        worker's only argument. How and where the worker is executed is up to
        the launcher implementation. The resulting report is logged and/or
        printed when ``config.log_report`` / ``config.print_report`` are set.

        Args:
            config: the full benchmark configuration.

        Returns:
            Whatever ``launcher.launch`` returns (the benchmark report).
        """
        # Allocate requested launcher
        launcher_config: LauncherConfig = config.launcher
        launcher_factory: Type[Launcher] = get_class(launcher_config._target_)
        launcher: Launcher = launcher_factory(launcher_config)

        # Launch the benchmark using the launcher
        report = launcher.launch(worker=Benchmark.run, worker_args=[config])

        if config.log_report:
            report.log()

        if config.print_report:
            report.print()

        return report

    @staticmethod
    def run(config: BenchmarkConfig):
        """Run the configured scenario against the configured backend.

        The backend and scenario classes are resolved from
        ``config.backend._target_`` and ``config.scenario._target_`` with
        ``hydra.utils.get_class``. Each is instantiated with its own sub-config,
        and then ``scenario.run(backend)`` is called.

        Args:
            config: the full benchmark configuration.

        Returns:
            The value returned by ``scenario.run(backend)`` (the report).
        """
        # Allocate requested backend
        backend_config: BackendConfig = config.backend
        backend_factory: Type[Backend] = get_class(backend_config._target_)
        backend: Backend = backend_factory(backend_config)

        # Allocate requested scenario
        scenario_config: ScenarioConfig = config.scenario
        scenario_factory: Type[Scenario] = get_class(scenario_config._target_)
        scenario: Scenario = scenario_factory(scenario_config)

        # Run the scenario using the backend
        report = scenario.run(backend)

        return report

    @classproperty
    def default_filename(cls) -> str:
        """File name ``"benchmark.json"``.

        Presumably used by ``PushToHubMixin`` as the default save/load file
        name; confirm against the mixin.
        """
        return "benchmark.json"