api/connection.py (104 lines of code) (raw):
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import enum
import json
import logging
import subprocess
from pathlib import Path
from types import TracebackType
from typing import Any, List, NamedTuple, Optional, Type

from typing_extensions import TypedDict
# Module-level logger; PyreConnection.query_server logs each query at DEBUG.
LOG: logging.Logger = logging.getLogger(__name__)
class ExitCode(enum.IntEnum):
    """Exit codes of the ``pyre`` command that this module inspects.

    Only the subset this module needs is listed here; see
    client/commands/command.py for more exit codes.
    """

    SUCCESS = 0
    FOUND_ERRORS = 1
# NamedTuple is used instead of a dataclass to keep Python 3.6 support
# (dataclasses only arrived in 3.7; this class syntax itself needs 3.6).
class PyreCheckResult(NamedTuple):
    """Outcome of a ``pyre incremental`` or ``pyre restart`` invocation.

    Attributes:
        exit_code: Return code of the ``pyre`` process.
        errors: Parsed stdout of the process. Typed as Optional, although the
            parser in this module always supplies a list.
    """

    exit_code: int
    errors: Optional[List[str]]
# pyre-ignore[33]: We don't have GADT's yet.
class PyreQueryResult(TypedDict):
    """Decoded JSON payload of a successful ``pyre query`` call.

    The payload of ``response`` varies with the query issued, so it is typed
    as ``Any`` (hence the suppression above).
    """

    response: Any
class PyreQueryError(Exception):
    """Raised when a Pyre server cannot be started or queried, or replies badly.

    "Replies badly" covers output that is not valid JSON, an ``"error"`` field,
    or a missing ``"response"`` field.
    """
class PyreConnection:
    """Drive a Pyre server for one project directory through the ``pyre`` CLI.

    Every operation shells out to
    ``pyre --noninteractive <pyre_arguments> <subcommand>`` with
    ``pyre_directory`` as the working directory. The class is also a context
    manager: entering calls ``start_server`` and exiting calls ``stop_server``.
    """

    def __init__(
        self,
        pyre_directory: Optional[Path] = None,
        pyre_arguments: Optional[List[str]] = None,
    ) -> None:
        """Create a connection. No subprocess is started here.

        Args:
            pyre_directory: Directory that ``pyre`` commands run in. Defaults
                to the current working directory at construction time.
            pyre_arguments: Extra flags placed between ``pyre --noninteractive``
                and the subcommand. The list is copied, so ``add_arguments``
                never mutates the caller's list.
        """
        self.pyre_directory: Path = (
            pyre_directory if pyre_directory is not None else Path.cwd()
        )
        # Copy rather than alias, because add_arguments() extends this list in place.
        self.pyre_arguments: List[str] = (
            list(pyre_arguments) if pyre_arguments is not None else []
        )
        # start_server/restart_server set this flag and stop_server clears it.
        # query_server reads it to decide whether to start a server first.
        self.server_initialized: bool = False

    def __enter__(self) -> "PyreConnection":
        """Start (or reuse) the server and return this connection."""
        self.start_server()
        return self

    def __exit__(
        self,
        _type: Optional[Type[BaseException]],
        _value: Optional[BaseException],
        _traceback: Optional[TracebackType],
    ) -> None:
        """Stop the server. Returning None lets exceptions from the body propagate."""
        self.stop_server()
        return None

    def add_arguments(self, *arguments: str) -> None:
        """Append extra flags used by every subsequent ``pyre`` invocation."""
        self.pyre_arguments.extend(arguments)

    def start_server(self) -> PyreCheckResult:
        """Run ``pyre incremental`` and return its parsed result.

        ``incremental`` starts a new server when needed. The connection is
        marked initialized whatever the exit code, so callers should inspect
        ``exit_code`` on the returned result.
        """
        result = subprocess.run(
            ["pyre", "--noninteractive", *self.pyre_arguments, "incremental"],
            stdout=subprocess.PIPE,
            cwd=str(self.pyre_directory),
        )
        self.server_initialized = True
        return _parse_check_output(result)

    def restart_server(self) -> PyreCheckResult:
        """Run ``pyre restart`` and return its parsed result.

        Like ``start_server``, this marks the connection initialized whatever
        the exit code.
        """
        result = _parse_check_output(
            subprocess.run(
                ["pyre", "--noninteractive", *self.pyre_arguments, "restart"],
                stdout=subprocess.PIPE,
                cwd=str(self.pyre_directory),
            )
        )
        self.server_initialized = True
        return result

    def stop_server(self, ignore_errors: bool = False) -> None:
        """Run ``pyre stop`` and mark the connection as not initialized.

        Args:
            ignore_errors: When False (the default), a non-zero exit from
                ``pyre stop`` raises ``subprocess.CalledProcessError``.
        """
        subprocess.run(
            ["pyre", "--noninteractive", *self.pyre_arguments, "stop"],
            check=not ignore_errors,
            cwd=str(self.pyre_directory),
        )
        # Without this reset, a later query_server() would skip start_server()
        # and query the server that was just stopped.
        self.server_initialized = False

    @staticmethod
    def _validate_query_response(response: str) -> PyreQueryResult:
        """Parse the stdout of ``pyre query`` and check its shape.

        Args:
            response: Raw decoded stdout of ``pyre query``.

        Returns:
            The decoded JSON object, which is guaranteed to have a
            ``"response"`` key.

        Raises:
            PyreQueryError: if ``response`` is not valid JSON, is not a JSON
                object, contains an ``"error"`` field, or lacks ``"response"``.
        """
        try:
            parsed = json.loads(response)
        except json.decoder.JSONDecodeError as decode_error:
            raise PyreQueryError(f"`{response}` is not valid JSON.") from decode_error
        # On a str/list the `in` checks below would do a substring or element
        # test, and on a number they raise TypeError. So reject any non-object here.
        if not isinstance(parsed, dict):
            raise PyreQueryError(
                "The server response is invalid: It is not a JSON object. "
                f"Response: `{parsed}`."
            )
        if "error" in parsed:
            raise PyreQueryError(parsed["error"])
        if "response" not in parsed:
            raise PyreQueryError(
                'The server response is invalid: It does not contain an "error" or '
                f'"response" field. Response: `{parsed}`.'
            )
        return parsed

    def query_server(self, query: str) -> PyreQueryResult:
        """Run ``pyre query <query>`` and return the validated JSON response.

        If this connection has not started a server yet, it calls
        ``start_server`` first.

        Raises:
            PyreQueryError: if starting the server exits with a code other than
                SUCCESS or FOUND_ERRORS, if ``pyre query`` exits non-zero, or if
                its output is rejected by ``_validate_query_response``.
        """
        if not self.server_initialized:
            start_result = self.start_server()
            if start_result.exit_code not in (ExitCode.SUCCESS, ExitCode.FOUND_ERRORS):
                # start_server() sets the flag unconditionally. Undo that so
                # the next query retries the start instead of skipping it.
                self.server_initialized = False
                raise PyreQueryError(
                    "Error while starting a pyre server, Pyre exited with a code of "
                    f"{start_result.exit_code}."
                )
        LOG.debug("Running query: `pyre query '%s'`", query)
        # The query is a single argv element (no shell), so it is never
        # interpreted by a shell.
        completed = subprocess.run(
            ["pyre", "--noninteractive", *self.pyre_arguments, "query", query],
            stdout=subprocess.PIPE,
            cwd=str(self.pyre_directory),
        )
        if completed.returncode != 0:
            raise PyreQueryError(
                "Error while running query, Pyre exited with a code of "
                f"{completed.returncode}."
            )
        return self._validate_query_response(completed.stdout.decode())
def _parse_check_output(
    completed_process: "subprocess.CompletedProcess[bytes]",
) -> PyreCheckResult:
    """Convert a finished ``pyre`` process into a ``PyreCheckResult``.

    ``completed_process.stdout`` must hold captured bytes; the callers in this
    module pass ``stdout=subprocess.PIPE``. The bytes are decoded with the
    ``bytes.decode`` default (UTF-8) and split on runs of whitespace.
    """
    # NOTE(review): str.split() breaks on any whitespace, so a multi-word
    # error line turns into several entries. Confirm whether splitlines() was
    # intended before treating ``errors`` as one entry per error.
    decoded_stdout = completed_process.stdout.decode()
    tokens = decoded_stdout.split()
    return PyreCheckResult(
        exit_code=completed_process.returncode,
        errors=tokens,
    )