tools/incremental_test/batch.py:

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import logging
import traceback
from abc import ABC, ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List

from .environment import Environment
from .runner import (
    ProfileLogs,
    ResultComparison,
    benchmark_server,
    compare_server_to_full,
)
from .specification import Specification


LOG: logging.Logger = logging.getLogger(__name__)


# Structured payload produced by `to_logger_sample` for downstream logging.
@dataclass
class Sample:
    integers: Dict[str, int]
    normals: Dict[str, str]


# Outcome of running a single test or benchmark specification.
class RunnerResult(ABC):
    _input: Specification

    def __init__(self, input: Specification) -> None:
        self._input = input

    @property
    def input(self) -> Specification:
        return self._input

    @abstractmethod
    def get_status(self) -> str:
        raise NotImplementedError

    @abstractmethod
    def to_json(self, dont_show_discrepancy: bool) -> Dict[str, Any]:
        raise NotImplementedError

    @abstractmethod
    def to_logger_sample(self) -> Sample:
        raise NotImplementedError


# The run raised an exception; the formatted traceback is preserved.
class ExceptionalRunnerResult(RunnerResult):
    _trace: str

    def __init__(self, input: Specification, trace: str) -> None:
        super().__init__(input)
        self._trace = trace

    def get_status(self) -> str:
        return "exception"

    def to_json(self, dont_show_discrepancy: bool) -> Dict[str, Any]:
        return {"status": self.get_status(), "trace": self._trace}

    def to_logger_sample(self) -> Sample:
        return Sample(
            normals={
                "status": self.get_status(),
                "input": json.dumps(self.input.to_json()),
                "exception": self._trace,
            },
            integers={},
        )


# Base class for runs where the incremental-vs-full comparison completed.
class FinishedRunnerResult(RunnerResult, metaclass=ABCMeta):
    _output: ResultComparison

    def __init__(self, input: Specification, output: ResultComparison) -> None:
        super().__init__(input)
        self._output = output

    def to_logger_sample(self) -> Sample:
        full_check_time = self._output.profile_logs.full_check_time()
        incremental_check_time = (
            self._output.profile_logs.total_incremental_check_time()
        )
        return Sample(
            normals={
                "status": self.get_status(),
                "input": json.dumps(self.input.to_json()),
            },
            integers={
                "full_check_time": full_check_time,
                "incremental_check_time": incremental_check_time,
            },
        )


class PassedRunnerResult(FinishedRunnerResult):
    def get_status(self) -> str:
        return "pass"

    def to_json(self, dont_show_discrepancy: bool) -> Dict[str, Any]:
        # Don't bother to include the input specification in the result if the
        # test passes.
        return {
            "status": self.get_status(),
            "output": self._output.to_json(dont_show_discrepancy),
        }


class FailedRunnerResult(FinishedRunnerResult):
    def get_status(self) -> str:
        return "fail"

    def to_json(self, dont_show_discrepancy: bool) -> Dict[str, Any]:
        return {
            "status": self.get_status(),
            "input": self.input.to_json(),
            "output": self._output.to_json(dont_show_discrepancy),
        }


# Result of a benchmark-only run: records profiling data, no pass/fail verdict.
class BenchmarkResult(RunnerResult):
    _profile_logs: ProfileLogs

    def __init__(self, input: Specification, profile_logs: ProfileLogs) -> None:
        super().__init__(input)
        self._profile_logs = profile_logs

    def get_status(self) -> str:
        return "benchmark"

    def to_json(self, dont_show_discrepancy: bool) -> Dict[str, Any]:
        return {
            "status": self.get_status(),
            "input": self.input.to_json(),
            "time": self._profile_logs.total_incremental_check_time(),
            "profile_logs": self._profile_logs.to_json(),
        }

    def to_logger_sample(self) -> Sample:
        incremental_check_time = self._profile_logs.total_incremental_check_time()
        return Sample(
            normals={
                "status": self.get_status(),
                "input": json.dumps(self.input.to_json()),
            },
            integers={"incremental_check_time": incremental_check_time},
        )

    def profile_logs(self) -> ProfileLogs:
        return self._profile_logs


def run_single_test(environment: Environment, input: Specification) -> RunnerResult:
    try:
        LOG.info(f"Running test on state '{input.old_state}' vs '{input.new_state}'")
        output = compare_server_to_full(environment, input)
        # A run passes when the incremental check shows no discrepancy against
        # the full check.
        if output.discrepancy is None:
            result = PassedRunnerResult(input, output)
        else:
            result = FailedRunnerResult(input, output)
    except Exception:
        result = ExceptionalRunnerResult(input, traceback.format_exc())
    LOG.info(f"Test finished with status = {result.get_status()}")
    return result


def run_batch_test(
    environment: Environment, inputs: Iterable[Specification]
) -> List[RunnerResult]:
    return [run_single_test(environment, input) for input in inputs]


def run_single_benchmark(
    environment: Environment, input: Specification
) -> RunnerResult:
    try:
        LOG.info(
            f"Running benchmark on state '{input.old_state}' vs '{input.new_state}'"
        )
        output = benchmark_server(environment, input)
        result = BenchmarkResult(input, output)
    except Exception:
        result = ExceptionalRunnerResult(input, traceback.format_exc())
    LOG.info(f"Test finished with status = {result.get_status()}")
    return result


def run_batch_benchmark(
    environment: Environment, inputs: Iterable[Specification]
) -> List[RunnerResult]:
    return [run_single_benchmark(environment, input) for input in inputs]