tools/upgrade/configuration.py (238 lines of code) (raw):

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import logging
import subprocess
from logging import Logger
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence

from . import UserError
from .errors import Errors
from .filesystem import get_filesystem


LOG: Logger = logging.getLogger(__name__)


class Configuration:
    """In-memory view of a `.pyre_configuration` / `.pyre_configuration.local` file.

    The JSON contents are read once at construction time. A handful of known
    fields are exposed as attributes; `get_contents` folds the (possibly
    modified) attributes back into the original JSON dictionary, and `write`
    persists that dictionary to the path the configuration was loaded from.
    """

    # Attributes that `get_contents` writes back into the JSON dictionary.
    # Each name is both the attribute name on this class and the JSON key.
    _PERSISTED_FIELDS: Sequence[str] = (
        "targets",
        "source_directories",
        "version",
        "pysa_version",
        "strict",
        "use_buck_builder",
        "use_buck_source_database",
        "use_command_v2",
        "ignore_all_errors",
    )

    def __init__(
        self, path: Path, json_contents: Optional[Dict[str, Any]] = None
    ) -> None:
        """Load a configuration.

        Args:
            path: Location of the configuration file. Its file name decides
                whether the configuration is considered local, and its parent
                directory becomes `root`.
            json_contents: Already-parsed contents of the file. When omitted,
                the file at `path` is opened and parsed as JSON.
        """
        if json_contents is None:
            with open(path, "r") as configuration_file:
                json_contents = json.load(configuration_file)
        self._path: Path = path
        self.is_local: bool = path.name == ".pyre_configuration.local"
        self.root: str = str(path.parent)
        self.original_contents: Dict[str, Any] = json_contents

        # Configuration fields
        self.strict: Optional[bool] = json_contents.get("strict")
        self.targets: Optional[List[str]] = json_contents.get("targets")
        self.source_directories: Optional[List[str]] = json_contents.get(
            "source_directories"
        )
        self.version: Optional[str] = json_contents.get("version")
        self.pysa_version: Optional[str] = json_contents.get("pysa_version")
        self.use_buck_builder: Optional[bool] = json_contents.get("use_buck_builder")
        self.use_buck_source_database: Optional[bool] = json_contents.get(
            "use_buck_source_database"
        )
        self.use_command_v2: Optional[bool] = json_contents.get("use_command_v2")
        self.ignore_all_errors: Optional[List[str]] = json_contents.get(
            "ignore_all_errors"
        )

    def get_contents(self) -> Dict[str, Any]:
        """Return the JSON dictionary with the current attribute values applied.

        Assumption: The field names in this class match the key names in the
        configuration.

        For every persisted field, a non-`None` attribute overwrites the key
        and a `None` attribute removes the key. Keys that are not modelled by
        this class are left as they were loaded.

        Note that the dictionary returned is `self.original_contents` itself,
        updated in place, not a copy.
        """
        contents: Dict[str, Any] = self.original_contents
        for key in self._PERSISTED_FIELDS:
            value = getattr(self, key)
            if value is not None:
                contents[key] = value
            else:
                # A field reset to None must disappear from the file.
                contents.pop(key, None)
        return contents

    @staticmethod
    def find_parent_file(
        filename: str, directory: Optional[Path] = None
    ) -> Optional[Path]:
        """Search `directory` and each of its ancestors for a file `filename`.

        Args:
            filename: Name of the file to look for.
            directory: Directory to start from; defaults to the current
                working directory.

        Returns:
            The path of the first match, starting at `directory` and moving
            up towards (and including) the topmost ancestor, or `None` when
            no ancestor contains such a file.
        """
        start = Path.cwd() if directory is None else directory
        # `Path.parents` stops at the filesystem root for absolute paths and
        # at `.` for relative ones, so the walk always terminates.
        for candidate_directory in (start, *start.parents):
            candidate = candidate_directory / filename
            if candidate.is_file():
                return candidate
        return None

    @staticmethod
    def find_project_configuration(directory: Optional[Path] = None) -> Path:
        """Locate the nearest `.pyre_configuration` at or above `directory`.

        Raises:
            UserError: If no such file exists in `directory` or its ancestors.
        """
        path = Configuration.find_parent_file(".pyre_configuration", directory)
        if path is None:
            raise UserError("No root with a `.pyre_configuration` found.")
        return path

    @staticmethod
    def find_local_configuration(directory: Optional[Path] = None) -> Optional[Path]:
        """Locate the nearest `.pyre_configuration.local` at or above `directory`.

        Returns `None` when there is none.
        """
        return Configuration.find_parent_file(".pyre_configuration.local", directory)

    @staticmethod
    def gather_local_configuration_paths(directory: str) -> Sequence[Path]:
        """List local configuration paths reported by the filesystem helper.

        The search itself is delegated to `get_filesystem().list`; each entry
        it returns is wrapped in a `Path`.
        """
        return [
            Path(path)
            for path in get_filesystem().list(
                directory, patterns=[r"**\.pyre_configuration.local"]
            )
        ]

    @staticmethod
    def gather_local_configurations() -> List["Configuration"]:
        """Load every local configuration found under the current directory.

        Files whose contents are not valid JSON are logged and skipped.
        """
        LOG.info("Finding configurations...")
        configuration_paths = Configuration.gather_local_configuration_paths(".")
        if not configuration_paths:
            LOG.info("No projects with local configurations found.")
            return []
        configurations = []
        for configuration_path in configuration_paths:
            with open(configuration_path) as configuration_file:
                try:
                    configuration = Configuration(
                        configuration_path, json.load(configuration_file)
                    )
                    configurations.append(configuration)
                except json.decoder.JSONDecodeError:
                    LOG.error(
                        "Configuration at `%s` is invalid, skipping.",
                        configuration_path,
                    )
        LOG.info(
            "Found %d local configuration%s.",
            len(configurations),
            "s" if len(configurations) != 1 else "",
        )
        return configurations

    def get_path(self) -> Path:
        """Return the path of the configuration file."""
        return self._path

    def get_directory(self) -> Path:
        """Return the directory containing the configuration file."""
        return self._path.parent

    def write(self) -> None:
        """Write the current contents back to the configuration file.

        The output is JSON with sorted keys, two-space indentation and a
        trailing newline.
        """
        with open(self._path, "w") as configuration_file:
            json.dump(self.get_contents(), configuration_file, sort_keys=True, indent=2)
            configuration_file.write("\n")

    def remove_version(self) -> None:
        """Clear the `version` field, logging when there is nothing to clear."""
        if not self.version:
            LOG.info("Version not found in configuration.")
            return
        self.version = None

    def set_version(self, version: str) -> None:
        """Set the `version` field."""
        self.version = version

    def set_pysa_version(self, pysa_version: str) -> None:
        """Set the `pysa_version` field."""
        self.pysa_version = pysa_version

    def enable_source_database_buck_builder(self) -> None:
        """Turn on both `use_buck_builder` and `use_buck_source_database`."""
        self.use_buck_builder = True
        self.use_buck_source_database = True

    def enable_new_server(self) -> None:
        """Turn on `use_command_v2`."""
        self.use_command_v2 = True

    def add_strict(self) -> None:
        """Set `strict` to true, logging when it already is."""
        if self.strict:
            LOG.info("Configuration is already strict.")
            return
        self.strict = True

    def add_targets(self, targets: List[str]) -> None:
        """Append `targets` to the configured targets.

        When no targets are configured yet (missing or empty), the
        configuration takes a copy of `targets` so that later additions do
        not modify the caller's list.
        """
        existing_targets = self.targets
        if existing_targets:
            existing_targets.extend(targets)
        else:
            self.targets = list(targets)

    def deduplicate_targets(self) -> None:
        """Drop targets that are already covered by other configured targets.

        Targets ending in `/...` or `:` are expanded with `buck query`; such a
        target is dropped when every target it expands to has already been
        seen. Plain targets are dropped when they have already been seen.
        Glob (`/...`) targets are processed before all others, each group in
        sorted order. A target whose query fails is kept. The surviving
        targets are stored in the order of their first appearance in the
        original list.
        """
        targets = self.targets
        if not targets:
            return
        glob_targets = {target for target in targets if target.endswith("/...")}
        non_glob_targets = {
            target for target in targets if not target.endswith("/...")
        }
        all_targets = sorted(glob_targets) + sorted(non_glob_targets)
        deduplicated_targets = []
        expanded_targets = set()
        for target in all_targets:
            if target.endswith(("/...", ":")):
                try:
                    expanded = (
                        subprocess.check_output(["buck", "query", target])
                        .decode()
                        .strip()
                        .split("\n")
                    )
                    if not all(
                        expanded_target in expanded_targets
                        for expanded_target in expanded
                    ):
                        expanded_targets.update(expanded)
                        deduplicated_targets.append(target)
                except subprocess.CalledProcessError as error:
                    LOG.warning("Failed to query target: %s\n%s", target, str(error))
                    deduplicated_targets.append(target)
            elif target not in expanded_targets:
                expanded_targets.add(target)
                deduplicated_targets.append(target)
        # Restore the original relative order using the index of each
        # target's first occurrence, computed once up front.
        first_position: Dict[str, int] = {}
        for position, target in enumerate(targets):
            first_position.setdefault(target, position)
        deduplicated_targets.sort(key=first_position.__getitem__)
        self.targets = deduplicated_targets

    def run_pyre(
        self,
        arguments: List[str],
        description: str,
        should_clean: bool,
        command_input: Optional[str],
        stderr_flag: "subprocess._FILE" = subprocess.PIPE,
    ) -> Optional["subprocess.CompletedProcess[str]"]:
        """Run `pyre` with `arguments`, optionally after a `buck clean`.

        Args:
            arguments: Command-line arguments placed after `pyre`.
            description: Message logged at info level just before pyre runs.
            should_clean: Run `buck clean` (200 second timeout) first.
            command_input: Text sent to pyre's standard input, if any.
            stderr_flag: Destination for pyre's standard error.

        Returns:
            The completed pyre process with text-mode captured stdout, or
            `None` when `buck clean` timed out.
        """
        # NOTE(review): `subprocess.call`, and `subprocess.run` without
        # `check=True`, do not raise `CalledProcessError`, so the two handlers
        # for it below are not expected to trigger. They are kept as-is; do
        # not add `check=True` to the pyre invocation without confirming how
        # callers treat a non-zero pyre exit status.
        if should_clean:
            try:
                # If building targets, run clean or space may run out on device!
                LOG.info("Running `buck clean`...")
                subprocess.call(["buck", "clean"], timeout=200)
            except subprocess.TimeoutExpired:
                LOG.warning("Buck timed out. Try running `buck kill` before retrying.")
                return None
            except subprocess.CalledProcessError as error:
                LOG.warning("Error calling `buck clean`: %s", str(error))
                return None
        try:
            LOG.info("%s", description)
            return subprocess.run(
                ["pyre", *arguments],
                stdout=subprocess.PIPE,
                stderr=stderr_flag,
                text=True,
                input=command_input,
            )
        except subprocess.CalledProcessError as error:
            LOG.warning("Error calling pyre: %s", str(error))
            return None

    def get_errors(
        self,
        only_fix_error_code: Optional[int] = None,
        should_clean: bool = True,
        command_input: Optional[str] = None,
        strict: bool = False,
    ) -> Errors:
        """Run `pyre check` for this configuration and parse its JSON output.

        Args:
            only_fix_error_code: Forwarded unchanged to `Errors.from_json`.
            should_clean: Allow a `buck clean` before the check; the clean
                only happens when the configuration also has targets.
            command_input: Text sent to pyre's standard input, if any.
            strict: Pass `--strict` to pyre.

        Returns:
            The parsed errors, or `Errors.empty()` when pyre produced no
            result or no standard output.

        Raises:
            UserError: Re-raised from `Errors.from_json` after logging pyre's
                stdout and stderr.
        """
        local_root_arguments = (
            ["--local-configuration", self.root] if self.is_local else []
        )
        strict_arguments = ["--strict"] if strict else []
        arguments = [*strict_arguments, *local_root_arguments, "--output=json", "check"]
        pyre_output = self.run_pyre(
            arguments=arguments,
            description=f"Checking `{self.root}`...",
            should_clean=self.targets is not None and should_clean,
            command_input=command_input,
        )
        if not pyre_output:
            return Errors.empty()

        stdout = pyre_output.stdout
        if stdout is None:
            return Errors.empty()
        stdout = stdout.strip()
        try:
            errors = Errors.from_json(stdout, only_fix_error_code)
        except UserError:
            LOG.info("Error when parsing Pyre error output.")
            LOG.info(f"Pyre stdout: {stdout}\nPyre stderr: {pyre_output.stderr}")
            raise
        LOG.info("Found %d error%s.", len(errors), "s" if len(errors) != 1 else "")
        return errors