tools/incremental_test/runner.py (250 lines of code) (raw):
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import logging
import os
import tempfile
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from pathlib import Path
from time import sleep
from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple, overload
from typing_extensions import Final, Literal
from .environment import Environment
from .specification import Specification
# Module-level logger used for this module's progress and debug messages.
LOG: logging.Logger = logging.getLogger(__name__)
class MalformedPyreOutputException(Exception):
    """Raised when pyre's JSON error output cannot be interpreted."""


@dataclass
class PyreError:
    """One type error as reported in pyre's JSON output."""

    line: int
    column: int
    path: str
    description: str

    @staticmethod
    def from_json(input_json: Dict[str, Any]) -> "PyreError":
        """Build a ``PyreError`` from one entry of pyre's JSON error list.

        ``line`` and ``column`` are converted with ``int()``; ``path`` and
        ``description`` are taken as-is.

        Raises:
            MalformedPyreOutputException: if any of ``line``, ``column``,
                ``path`` or ``description`` is missing. The original
                ``KeyError`` is chained as the cause.
        """
        try:
            return PyreError(
                line=int(input_json["line"]),
                column=int(input_json["column"]),
                path=input_json["path"],
                description=input_json["description"],
            )
        except KeyError as error:
            # str(KeyError("line")) is "'line'" (already quoted), so
            # interpolating the exception object itself produced doubled
            # quotes in the message. Use the raw missing key instead.
            missing_key = error.args[0] if error.args else error
            raise MalformedPyreOutputException(
                f"Cannot interpret pyre output: missing key '{missing_key}'"
            ) from error
class PyreRunner:
    """Drive the pyre client inside a given working directory.

    Every command is assembled as a single command string and executed through
    ``Environment.checked_run`` with ``working_directory`` as its cwd.
    """

    def __init__(
        self,
        environment: Environment,
        specification: Specification,
        working_directory: Path,
    ) -> None:
        self._environment = environment
        self._specification = specification
        self._working_directory = working_directory

        # Base invocation shared by every command. The client, the server
        # binary and typeshed can each be overridden through the environment.
        invocation = self._environment.pyre_client_override or "pyre"
        binary_override = self._environment.pyre_binary_override
        if binary_override:
            invocation += f" --binary {binary_override}"
        typeshed_override = self._environment.typeshed_override
        if typeshed_override:
            invocation += f" --typeshed {typeshed_override}"
        self._pyre_invocation: str = invocation

    def update(self) -> List[Mapping[str, int]]:
        """Apply the new state's update steps and wait for the server after each one.

        After applying the i-th step (0-based), this polls the
        ``incremental_updates`` profile once per second until it holds more
        than i entries.

        NOTE(review): the polling loop has no timeout. If the server never
        records an update, this blocks forever.

        Returns:
            The last ``incremental_updates`` profile fetched, or an empty list
            if there were no update steps.
        """
        incremental_update_logs: List[Mapping[str, int]] = []
        new_state = self._specification.new_state
        updates = new_state.update_steps()
        for expected, update in enumerate(updates):
            update.update(self._environment, self._working_directory)
            while True:
                incremental_update_logs = self.run_profile("incremental_updates")
                if len(incremental_update_logs) > expected:
                    break
                sleep(1)
        return incremental_update_logs

    def _run_error_reporting_command(self, command: str) -> List[PyreError]:
        """Run a pyre command that reports errors as JSON and parse its output.

        Return code 0 means no errors. Return code 1 means stdout holds a JSON
        list of error objects. Any other code is rejected by ``checked_run``.
        """
        output = self._environment.checked_run(
            working_directory=self._working_directory,
            command=command,
            expected_return_codes=(0, 1),
        )
        if output.return_code == 0:
            return []
        return [PyreError.from_json(x) for x in json.loads(output.stdout)]

    def run_check(self) -> List[PyreError]:
        """Run a full ``pyre check`` and return the errors it reports."""
        pyre_check_command = (
            f"{self._pyre_invocation} {self._specification.pyre_check_pyre_options} "
            "--output=json "
            "--noninteractive "
            f"check {self._specification.pyre_check_options}"
        ).rstrip()
        return self._run_error_reporting_command(pyre_check_command)

    def run_start(self) -> Mapping[str, int]:
        """(Re)start the pyre server with profiling enabled and gather cold-start stats.

        Returns:
            The ``cold_start_phases`` profile merged with two extra entries:
            ``heap_size`` (the value of the first sample in
            ``total_shared_memory_size_over_time``) and ``saved_state_size``
            (size in bytes of the server state saved to a temporary file).
        """
        pyre_start_command = (
            # Use `pyre restart` instead of `pyre start` as we want to:
            # - Kill existing servers
            # - Force the initial check to finish
            f"{self._pyre_invocation} {self._specification.pyre_start_pyre_options} "
            "--no-saved-state --enable-profiling "
            f"restart {self._specification.pyre_start_options}"
        ).rstrip()
        self._environment.checked_run(
            working_directory=self._working_directory,
            command=pyre_start_command,
            expected_return_codes=(0, 1),
        )
        cold_start_time_phases = self.run_profile("cold_start_phases")
        shared_memory_over_time = self.run_profile("total_shared_memory_size_over_time")
        # NOTE(review): this takes the first sample as the cold-start heap size.
        # That assumes samples are ordered so the first one reflects the cold
        # start. Confirm against pyre's profiling output.
        _, cold_start_total_memory = shared_memory_over_time[0]
        with tempfile.NamedTemporaryFile() as temporary_file:
            self._environment.checked_run(
                working_directory=self._working_directory,
                command=(
                    f"{self._pyre_invocation} "
                    f"query save_server_state('{temporary_file.name}')"
                ),
            )
            saved_state_size = os.stat(temporary_file.name).st_size
        return {
            **cold_start_time_phases,
            "heap_size": cold_start_total_memory,
            "saved_state_size": saved_state_size,
        }

    def run_stop(self) -> None:
        """Stop the pyre server."""
        # Strip trailing whitespace like every other command builder in this
        # class, so an empty `pyre_stop_options` leaves no dangling space.
        pyre_stop_command = (
            f"{self._pyre_invocation} {self._specification.pyre_stop_pyre_options} "
            f"stop {self._specification.pyre_stop_options}"
        ).rstrip()
        self._environment.checked_run(
            working_directory=self._working_directory,
            command=pyre_stop_command,
        )

    def run_incremental(self) -> List[PyreError]:
        """Run ``pyre incremental`` against the running server and return its errors."""
        pyre_incremental_command = (
            f"{self._pyre_invocation} "
            f"{self._specification.pyre_incremental_pyre_options} "
            "--output=json "
            "--noninteractive "
            f"incremental {self._specification.pyre_incremental_options}"
        ).rstrip()
        return self._run_error_reporting_command(pyre_incremental_command)

    @overload
    def run_profile(
        self, output_kind: Literal["incremental_updates"]
    ) -> List[Mapping[str, int]]:
        ...

    @overload  # noqa T20027161
    def run_profile(
        self, output_kind: Literal["cold_start_phases"]
    ) -> Mapping[str, int]:
        ...

    @overload  # noqa T20027161
    def run_profile(
        self, output_kind: Literal["total_shared_memory_size_over_time"]
    ) -> List[Tuple[str, int]]:
        ...

    def run_profile(self, output_kind: str) -> object:  # noqa T20027161
        """Query ``pyre profile --profile-output=<output_kind>`` and decode its JSON stdout."""
        pyre_profile_command = (
            f"{self._pyre_invocation} profile --profile-output={output_kind}"
        )
        output = self._environment.checked_run(
            working_directory=self._working_directory, command=pyre_profile_command
        )
        return json.loads(output.stdout)
@contextmanager
def _create_pyre_runner(
    environment: Environment, specification: Specification
) -> Iterator["PyreRunner"]:
    """Yield a ``PyreRunner`` rooted in a sandbox of the specification's old state.

    The sandbox comes from ``specification.old_state.activate_sandbox`` and
    stays active until the caller's ``with`` body finishes.
    """
    sandbox = specification.old_state.activate_sandbox(environment)
    with sandbox as sandbox_root:
        runner = PyreRunner(environment, specification, sandbox_root)
        yield runner
@dataclass
class InconsistentOutput:
    """Error lists from a full check and an incremental check that did not match."""

    full_check_output: List[PyreError]
    incremental_check_output: List[PyreError]

    def to_json(self) -> Dict[str, Any]:
        """Serialize both error lists, turning each error into a dict via ``asdict``."""
        full = list(map(asdict, self.full_check_output))
        incremental = list(map(asdict, self.incremental_check_output))
        return {
            "full_check_output": full,
            "incremental_check_output": incremental,
        }
@dataclass
class ProfileLogs:
    """Profiling data gathered from one pyre server session."""

    # One mapping per incremental update; each is expected to carry a
    # "total" entry (read by total_incremental_check_time).
    incremental_update_logs: List[Mapping[str, int]]
    # Cold-start log as returned by PyreRunner.run_start.
    cold_start_log: Mapping[str, int]

    # PyreRunner.run_start merges these byte counts into the cold-start log
    # alongside the timing phases. They are sizes, not durations, so they
    # must not be added into the full-check time. (Unannotated, so dataclass
    # does not treat it as a field.)
    _NON_TIMING_COLD_START_KEYS = frozenset({"heap_size", "saved_state_size"})

    def to_json(self) -> Dict[str, Any]:
        """Return both raw logs in a JSON-serializable dict."""
        return {
            "incremental_update_logs": self.incremental_update_logs,
            "cold_start_log": self.cold_start_log,
        }

    def total_incremental_check_time(self) -> int:
        """Sum the "total" entry of every incremental log, floor-divided by 1000.

        NOTE(review): presumably converts milliseconds to seconds; confirm the
        unit of pyre's profiling output.
        """
        return sum(log["total"] for log in self.incremental_update_logs) // 1000

    def full_check_time(self) -> int:
        """Sum the cold-start timing phases, floor-divided by 1000.

        Memory-size entries (``heap_size``, ``saved_state_size``) are excluded.
        """
        phase_total = sum(
            duration
            for phase, duration in self.cold_start_log.items()
            if phase not in self._NON_TIMING_COLD_START_KEYS
        )
        return phase_total // 1000
@dataclass
class ResultComparison:
    """Outcome of comparing an incremental server run with a full check."""

    # None when both runs reported the same errors.
    discrepancy: Final[Optional[InconsistentOutput]]
    profile_logs: ProfileLogs

    def to_json(self, dont_show_discrepancy: bool = False) -> Dict[str, Any]:
        """Summarize timings and profile logs as a JSON-serializable dict.

        Unless ``dont_show_discrepancy`` is set, a "discrepancy" entry is added.
        It is the string "none" when there is no discrepancy, otherwise the
        serialized ``InconsistentOutput``.
        """
        logs = self.profile_logs
        summary: Dict[str, Any] = {
            "full_check_time": logs.full_check_time(),
            "incremental_check_time": logs.total_incremental_check_time(),
            "profile_logs": logs.to_json(),
        }
        if not dont_show_discrepancy:
            if self.discrepancy is None:
                summary["discrepancy"] = "none"
            else:
                summary["discrepancy"] = self.discrepancy.to_json()
        return summary
def compare_server_to_full(
    environment: Environment, specification: Specification
) -> ResultComparison:
    """Compare pyre's incremental server results with a full ``pyre check``.

    Inside a sandbox of the specification's old state, this:
    1. Starts a server.
    2. Applies the new state's update steps.
    3. Runs ``pyre incremental``.
    4. Stops the server.
    5. Runs ``pyre check``.

    The two error lists are compared with ``==``, so the comparison is
    order-sensitive.

    Returns:
        A ``ResultComparison`` with ``discrepancy`` set to ``None`` when both
        runs agree, plus the collected profiling logs.
    """
    LOG.info("Preparing base repository state...")
    with _create_pyre_runner(environment, specification) as pyre_runner:
        LOG.debug("Starting pyre server...")
        cold_start_log = pyre_runner.run_start()
        # Once the server is up, always stop it, even if updating or the
        # incremental check fails, so a failure does not leave it running.
        # run_start stays outside the try. Presumably a `pyre stop` after a
        # failed start could itself fail and mask the original error; confirm.
        try:
            LOG.debug("Preparing updated repository state...")
            incremental_update_logs = pyre_runner.update()
            LOG.info("Running pyre incremental check...")
            incremental_check_output = pyre_runner.run_incremental()
        finally:
            LOG.debug("Stopping pyre server...")
            pyre_runner.run_stop()
        LOG.info(
            "Pyre incremental check successfully finished (with %d errors).",
            len(incremental_check_output),
        )
        LOG.info("Running pyre full check...")
        full_check_output = pyre_runner.run_check()
        LOG.info(
            "Pyre full check successfully finished (with %d errors).",
            len(full_check_output),
        )
    discrepancy = (
        None
        if incremental_check_output == full_check_output
        else InconsistentOutput(full_check_output, incremental_check_output)
    )
    profile_logs = ProfileLogs(incremental_update_logs, cold_start_log)
    return ResultComparison(discrepancy, profile_logs)
def benchmark_server(
    environment: Environment, specification: Specification
) -> ProfileLogs:
    """Profile a pyre server through a cold start and the new state's updates.

    Inside a sandbox of the specification's old state, this:
    1. Starts a server.
    2. Applies the update steps.
    3. Runs ``pyre incremental``; its errors are only counted for logging.
    4. Stops the server.

    Returns:
        The cold-start log and the incremental update logs.
    """
    LOG.info("Preparing base repository state...")
    with _create_pyre_runner(environment, specification) as pyre_runner:
        LOG.debug("Starting pyre server...")
        cold_start_log = pyre_runner.run_start()
        # Once the server is up, always stop it, even if updating or the
        # incremental check fails, so a failure does not leave it running.
        try:
            LOG.debug("Preparing updated repository state...")
            incremental_update_logs = pyre_runner.update()
            LOG.info("Running pyre incremental check...")
            incremental_check_output = pyre_runner.run_incremental()
        finally:
            LOG.debug("Stopping pyre server...")
            pyre_runner.run_stop()
        LOG.info(
            "Pyre incremental check successfully finished (with %d errors).",
            len(incremental_check_output),
        )
    return ProfileLogs(incremental_update_logs, cold_start_log)