client/commands/infer.py (812 lines of code) (raw):

# Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import contextlib import dataclasses import functools import itertools import json import logging import multiprocessing import re import shutil import subprocess import sys import traceback from pathlib import Path from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, TypeVar, Union import dataclasses_json import libcst from libcst.codemod import CodemodContext from .. import command_arguments, configuration as configuration_module, log from ..libcst_vendored_visitors import ApplyTypeAnnotationsVisitor from . import commands, remote_logging, backend_arguments, start LOG: logging.Logger = logging.getLogger(__name__) @dataclasses.dataclass(frozen=True) class Arguments: """ Data structure for configuration options the backend infer command can recognize. Need to keep in sync with `source/command/inferCommand.ml` """ base_arguments: backend_arguments.BaseArguments ignore_infer: Sequence[str] = dataclasses.field(default_factory=list) paths_to_modify: Optional[Set[Path]] = None def serialize(self) -> Dict[str, Any]: return { **self.base_arguments.serialize(), "ignore_infer": self.ignore_infer, **( {} if self.paths_to_modify is None else {"paths_to_modify": [str(path) for path in self.paths_to_modify]} ), } @dataclasses_json.dataclass_json( letter_case=dataclasses_json.LetterCase.CAMEL, undefined=dataclasses_json.Undefined.EXCLUDE, ) @dataclasses.dataclass(frozen=True) class RawAnnotationLocation: qualifier: str path: str line: int @dataclasses.dataclass(frozen=True) class RawAnnotation: name: str location: RawAnnotationLocation @dataclasses_json.dataclass_json( letter_case=dataclasses_json.LetterCase.CAMEL, undefined=dataclasses_json.Undefined.EXCLUDE, ) @dataclasses.dataclass(frozen=True) class RawGlobalAnnotation(RawAnnotation): annotation: str @dataclasses_json.dataclass_json( letter_case=dataclasses_json.LetterCase.CAMEL, undefined=dataclasses_json.Undefined.EXCLUDE, ) @dataclasses.dataclass(frozen=True) class RawAttributeAnnotation(RawAnnotation): parent: str annotation: str @dataclasses_json.dataclass_json( letter_case=dataclasses_json.LetterCase.CAMEL, undefined=dataclasses_json.Undefined.EXCLUDE, ) @dataclasses.dataclass(frozen=True) class RawParameter: name: str index: int annotation: Optional[str] = None value: Optional[str] = None @dataclasses_json.dataclass_json( letter_case=dataclasses_json.LetterCase.CAMEL, undefined=dataclasses_json.Undefined.EXCLUDE, ) @dataclasses.dataclass(frozen=True) class RawDefineAnnotation(RawAnnotation): parent: Optional[str] = None return_: Optional[str] = dataclasses.field( metadata=dataclasses_json.config(field_name="return"), default=None ) parameters: List[RawParameter] = dataclasses.field(default_factory=list) is_async: bool = dataclasses.field( metadata=dataclasses_json.config(field_name="async"), default=False ) TAnnotation = TypeVar("TAnnotation", bound=RawAnnotation) @dataclasses_json.dataclass_json( letter_case=dataclasses_json.LetterCase.CAMEL, undefined=dataclasses_json.Undefined.EXCLUDE, ) @dataclasses.dataclass(frozen=True) class RawInferOutput: global_annotations: List[RawGlobalAnnotation] = dataclasses.field( metadata=dataclasses_json.config(field_name="globals"), default_factory=list ) attribute_annotations: List[RawAttributeAnnotation] = dataclasses.field( metadata=dataclasses_json.config(field_name="attributes"), default_factory=list ) define_annotations: List[RawDefineAnnotation] = dataclasses.field( metadata=dataclasses_json.config(field_name="defines"), default_factory=list ) class ParsingError(Exception): pass @staticmethod def create_from_string(input: str) -> "RawInferOutput": try: # pyre-fixme[16]: Pyre doesn't understand `dataclasses_json` return RawInferOutput.schema().loads(input) except ( TypeError, KeyError, ValueError, dataclasses_json.mm.ValidationError, ) as error: raise RawInferOutput.ParsingError(str(error)) from error @staticmethod def create_from_json(input: Dict[str, object]) -> "RawInferOutput": return RawInferOutput.create_from_string(json.dumps(input)) def qualifiers_by_path(self) -> Dict[str, str]: return { annotation.location.path: annotation.location.qualifier for annotation in itertools.chain( self.global_annotations, self.attribute_annotations, self.define_annotations, ) } def split_by_path(self) -> "Dict[str, RawInferOutputForPath]": def create_index( annotations: Sequence[TAnnotation], ) -> Dict[str, List[TAnnotation]]: result: Dict[str, List[TAnnotation]] = {} for annotation in annotations: key = annotation.location.path result.setdefault(key, []).append(annotation) return result qualifiers_by_path = self.qualifiers_by_path() global_annotation_index = create_index(self.global_annotations) attribute_annotation_index = create_index(self.attribute_annotations) define_annotation_index = create_index(self.define_annotations) return { path: RawInferOutputForPath( global_annotations=global_annotation_index.get(path, []), attribute_annotations=attribute_annotation_index.get(path, []), define_annotations=define_annotation_index.get(path, []), qualifier=qualifiers_by_path[path], ) for path in global_annotation_index.keys() | attribute_annotation_index.keys() | define_annotation_index.keys() } @dataclasses_json.dataclass_json( letter_case=dataclasses_json.LetterCase.CAMEL, undefined=dataclasses_json.Undefined.EXCLUDE, ) @dataclasses.dataclass(frozen=True) class RawInferOutputForPath: qualifier: str global_annotations: List[RawGlobalAnnotation] = dataclasses.field( metadata=dataclasses_json.config(field_name="globals"), default_factory=list ) attribute_annotations: List[RawAttributeAnnotation] = dataclasses.field( metadata=dataclasses_json.config(field_name="attributes"), default_factory=list ) define_annotations: List[RawDefineAnnotation] = dataclasses.field( metadata=dataclasses_json.config(field_name="defines"), default_factory=list ) @staticmethod def create_from_json(input: Dict[str, object]) -> "RawInferOutputForPath": # pyre-fixme[16]: Pyre doesn't understand `dataclasses_json` return RawInferOutputForPath.schema().loads(json.dumps(input)) def _sanitize_name(name: str) -> str: """The last part of the access path is the function/attribute name""" return name.split(".")[-1] @functools.lru_cache() def empty_module() -> libcst.Module: return libcst.parse_module("") def code_for_node(node: libcst.CSTNode) -> str: return empty_module().code_for_node(node) class AnnotationFixer(libcst.CSTTransformer): def __init__(self, qualifier: str) -> None: """ AnnotationFixer sanitizes annotations. There are two transformations we apply: (1) Strip any prefix matching `prefix` from names. This is because the pyre backend always uses fully-qualified names, but in our stubs we should not include the prefix for names coming from this module. (2) Convert `Pathlike` annotations, which come from pyre in a special-cased form that isn't a correct annotation, to a quoted `"os.Pathlike[_]"`. Note: we eventually will need to either have a proper protocol in the backend for generating python-readable types, or extend (2) to handle various other special cases where pyre outputs types that are invalid in python. """ super().__init__() self.qualifier = qualifier def leave_Attribute( self, original_node: libcst.Attribute, updated_node: libcst.Attribute, ) -> Union[libcst.Attribute, libcst.Name]: """ Note: in order to avoid complex reverse-name-matching, we're effectively operating at the top level of attributes, by using only the `original_node`. This means the transform we're performing cannot be done concurrently with a transform that has to be done incrementally. """ value = code_for_node(original_node.value) if value == self.qualifier and libcst.matchers.matches( original_node.attr, libcst.matchers.Name() ): return libcst.ensure_type(original_node.attr, libcst.Name) return original_node def leave_Subscript( self, original_node: libcst.Subscript, updated_node: Union[libcst.Subscript, libcst.SimpleString], ) -> Union[libcst.Subscript, libcst.SimpleString]: if libcst.matchers.matches( original_node.value, libcst.matchers.Name("PathLike") ): name_node = libcst.Attribute( value=libcst.Name( value="os", lpar=[], rpar=[], ), attr=libcst.Name(value="PathLike"), ) node_as_string = code_for_node(updated_node.with_changes(value=name_node)) updated_node = libcst.SimpleString(f"'{node_as_string}'") return updated_node @staticmethod def sanitize( annotation: str, qualifier: str, quote_annotations: bool = False, dequalify_all: bool = False, runtime_defined: bool = True, ) -> str: """ Transform raw annotations in an attempt to reduce incorrectly-imported annotations in generated code. TODO(T93381000): Handle qualification in a principled way and remove this: all of these transforms are attempts to hack simple fixes to the problem of us not actually understanding qualified types and existing imports. (1) If `quote_annotations` is set to True, then spit out a quoted raw annotation (with fully-qualified names). The resulting generated code will not technically be PEP 484 compliant because it will use fully qualified type names without import, but pyre can understand this and it's useful for pysa to avoid adding import lines that change traces. (2) If `dequalify_all` is set, then remove all qualifiers from any top-level type (by top-level I mean outside the outermost brackets). For example, convert `sqlalchemy.sql.schema.Column[typing.Optional[int]]` into `Column[typing.Optional[int]]`. (3) Fix PathLike annotations: convert all bare `PathLike` uses to `'os.PathLike'`; the ocaml side of pyre currently spits out an unqualified type here which is incorrect, and quoting makes the use of os safer given that we don't handle imports correctly yet. """ if quote_annotations: return f'"{annotation}"' if dequalify_all: match = re.fullmatch(r"([^.]*?\.)*?([^.]+)(\[.*\])", annotation) if match is not None: annotation = f"{match.group(2)}{match.group(3)}" try: tree = libcst.parse_module(annotation) annotation = tree.visit( AnnotationFixer( qualifier=qualifier, ) ).code except libcst._exceptions.ParserSyntaxError: pass if not runtime_defined: return f'"{annotation}"' return annotation @dataclasses.dataclass(frozen=True) class StubGenerationOptions: annotate_attributes: bool = False use_future_annotations: bool = False quote_annotations: bool = False simple_annotations: bool = False dequalify: bool = False debug_infer: bool = False def __post__init__(self) -> None: if self.quote_annotations and (self.use_future_annotations or self.dequalify): raise ValueError( "You should not mix the `quote_annotations` option, which causes pyre " "to generate quoted and qualified annotations, with the " "`use_future_annotations` or `dequalify` options." ) @dataclasses.dataclass(frozen=True) class TypeAnnotation: annotation: Optional[str] qualifier: str options: StubGenerationOptions runtime_defined: bool @staticmethod def from_raw( annotation: Optional[str], options: StubGenerationOptions, qualifier: str, runtime_defined: bool = True, ) -> "TypeAnnotation": return TypeAnnotation( annotation=annotation, qualifier=qualifier, options=options, runtime_defined=runtime_defined, ) @staticmethod def is_simple(sanitized_annotation: str) -> bool: """ This is a flexible definition that should expand if/when our ability to handle annotations or imports without manual adjustment improves. """ return len(sanitized_annotation.split(".")) == 1 def to_stub(self, prefix: str = "") -> str: if self.annotation is None: return "" else: sanitized = AnnotationFixer.sanitize( annotation=self.annotation, qualifier=self.qualifier, quote_annotations=self.options.quote_annotations, dequalify_all=self.options.dequalify, runtime_defined=self.runtime_defined, ) if self.options.simple_annotations and not TypeAnnotation.is_simple( sanitized ): return "" return prefix + sanitized @property def missing(self) -> bool: return self.annotation is None @dataclasses.dataclass(frozen=True) class Parameter: name: str annotation: TypeAnnotation value: Optional[str] def to_stub(self) -> str: delimiter = "=" if self.annotation.missing else " = " value = f"{delimiter}{self.value}" if self.value else "" return f"{self.name}{self.annotation.to_stub(prefix=': ')}{value}" @dataclasses.dataclass(frozen=True) class FunctionAnnotation: name: str return_annotation: TypeAnnotation parameters: Sequence[Parameter] is_async: bool def to_stub(self) -> str: name = _sanitize_name(self.name) async_ = "async " if self.is_async else "" parameters = ", ".join(parameter.to_stub() for parameter in self.parameters) return_ = self.return_annotation.to_stub(prefix=" -> ") return f"{async_}def {name}({parameters}){return_}: ..." @dataclasses.dataclass(frozen=True) class MethodAnnotation(FunctionAnnotation): parent: str @dataclasses.dataclass(frozen=True) class FieldAnnotation: name: str annotation: TypeAnnotation def __post_init__(self) -> None: if self.annotation.missing: raise RuntimeError(f"Illegal missing FieldAnnotation for {self.name}") def to_stub(self) -> str: name = _sanitize_name(self.name) return f"{name}: {self.annotation.to_stub()} = ..." @dataclasses.dataclass(frozen=True) class GlobalAnnotation(FieldAnnotation): pass @dataclasses.dataclass(frozen=True) class AttributeAnnotation(FieldAnnotation): parent: str @dataclasses.dataclass(frozen=True) class ModuleAnnotations: path: str options: StubGenerationOptions globals_: List[GlobalAnnotation] = dataclasses.field(default_factory=list) attributes: List[AttributeAnnotation] = dataclasses.field(default_factory=list) functions: List[FunctionAnnotation] = dataclasses.field(default_factory=list) methods: List[MethodAnnotation] = dataclasses.field(default_factory=list) @staticmethod def from_infer_output( path: str, infer_output: RawInferOutputForPath, options: StubGenerationOptions, ) -> "ModuleAnnotations": def type_annotation( annotation: Optional[str], parent_class: Optional[str] = None ) -> TypeAnnotation: return TypeAnnotation.from_raw( annotation, qualifier=infer_output.qualifier, options=options, runtime_defined=parent_class != annotation if parent_class else True, ) return ModuleAnnotations( path=path, globals_=[ GlobalAnnotation( name=global_.name, annotation=type_annotation(global_.annotation) ) for global_ in infer_output.global_annotations ], attributes=[ AttributeAnnotation( parent=attribute.parent, name=attribute.name, annotation=type_annotation(attribute.annotation, attribute.parent), ) for attribute in infer_output.attribute_annotations ] if options.annotate_attributes else [], functions=[ FunctionAnnotation( name=define.name, return_annotation=type_annotation(define.return_), parameters=[ Parameter( name=parameter.name, annotation=type_annotation(parameter.annotation), value=parameter.value, ) for parameter in define.parameters ], is_async=define.is_async, ) for define in infer_output.define_annotations if define.parent is None ], methods=[ MethodAnnotation( parent=define.parent, name=define.name, return_annotation=type_annotation(define.return_, define.parent), parameters=[ Parameter( name=parameter.name, annotation=type_annotation( parameter.annotation, define.parent ), value=parameter.value, ) for parameter in define.parameters ], is_async=define.is_async, ) for define in infer_output.define_annotations if define.parent is not None ], options=options, ) def is_empty(self) -> bool: return ( len(self.globals_) + len(self.attributes) + len(self.functions) + len(self.methods) ) == 0 @staticmethod def _indent(stub: str) -> str: return " " + stub.replace("\n", "\n ") def _relativize(self, parent: str) -> Sequence[str]: path = ( str(self.path).split(".", 1)[0].replace("/", ".").replace(".__init__", "") ) return parent.replace(path, "", 1).strip(".").split(".") @property def classes(self) -> Dict[str, List[Union[AttributeAnnotation, MethodAnnotation]]]: """ Find all classes with attributes or methods to annotate. Anything in nested classes is currently ignored, e.g.: ``` class X: class Y: [ALL OF THIS IS IGNORED] ``` """ classes: Dict[str, List[Union[AttributeAnnotation, MethodAnnotation]]] = {} nested_class_count = 0 for annotation in [*self.attributes, *self.methods]: parent = self._relativize(annotation.parent) if len(parent) == 1: classes.setdefault(parent[0], []).append(annotation) else: nested_class_count += 1 if nested_class_count > 0: LOG.warning( f"In file {self.path}, ignored {nested_class_count} nested classes" ) return classes def _class_stub( self, classname: str, annotations: Sequence[Union[AttributeAnnotation, MethodAnnotation]], ) -> str: body = "\n".join( self._indent(annotation.to_stub()) for annotation in annotations ) return f"class {classname}:\n{body}\n" def to_stubs(self) -> str: """ Output annotation information as a stub file. """ return "\n".join( [ *(global_.to_stub() for global_ in self.globals_), *(function.to_stub() for function in self.functions), *( self._class_stub(classname, annotations) for classname, annotations in self.classes.items() ), "", ] ) def stubs_path(self, directory: Path) -> Path: return (directory / self.path).with_suffix(".pyi") def write_stubs(self, type_directory: Path) -> None: path = self.stubs_path(type_directory) path.parent.mkdir(parents=True, exist_ok=True) path.write_text(self.to_stubs()) @dataclasses.dataclass class AnnotateModuleInPlace: full_stub_path: str full_code_path: str options: StubGenerationOptions @staticmethod def _annotated_code( code_path: str, stub: str, code: str, options: StubGenerationOptions, ) -> Optional[str]: """ Merge inferred annotations from stubs with source code to get annotated code. """ if "@" "generated" in code: LOG.warning(f"Skipping generated file {code_path}") return context = CodemodContext() ApplyTypeAnnotationsVisitor.store_stub_in_context( context=context, stub=libcst.parse_module(stub), use_future_annotations=options.use_future_annotations, ) modified_tree = ApplyTypeAnnotationsVisitor(context).transform_module( libcst.parse_module(code) ) return modified_tree.code @staticmethod def annotate_code( stub_path: str, code_path: str, options: StubGenerationOptions, ) -> None: "Merge a stub file of inferred annotations with a code file in place." try: stub = Path(stub_path).read_text() code = Path(code_path).read_text() annotated_code = AnnotateModuleInPlace._annotated_code( code_path=code_path, stub=stub, code=code, options=options, ) if annotated_code is not None: Path(code_path).write_text(annotated_code) LOG.info(f"Annotated {code_path}") except Exception as error: LOG.warning(f"Failed to annotate {code_path}") if options.debug_infer: LOG.warning(f"\tError: {error}") def run(self) -> None: return self.annotate_code( stub_path=self.full_stub_path, code_path=self.full_code_path, options=self.options, ) @staticmethod def run_task(task: "AnnotateModuleInPlace") -> None: "Wrap `run` in a static method to use with multiprocessing" return task.run() def create_infer_arguments( configuration: configuration_module.Configuration, infer_arguments: command_arguments.InferArguments, ) -> Arguments: """ Translate client configurations to backend check configurations. This API is not pure since it needs to access filesystem to filter out nonexistent directories. It is idempotent though, since it does not alter any filesystem state. """ source_paths = backend_arguments.get_source_path_for_check(configuration) profiling_output = ( backend_arguments.get_profiling_log_path(Path(configuration.log_directory)) if infer_arguments.enable_profiling else None ) memory_profiling_output = ( backend_arguments.get_profiling_log_path(Path(configuration.log_directory)) if infer_arguments.enable_memory_profiling else None ) logger = configuration.logger remote_logging = ( backend_arguments.RemoteLogging( logger=logger, identifier=infer_arguments.log_identifier or "" ) if logger is not None else None ) return Arguments( base_arguments=backend_arguments.BaseArguments( log_path=configuration.log_directory, global_root=configuration.project_root, checked_directory_allowlist=backend_arguments.get_checked_directory_allowlist( configuration, source_paths ), checked_directory_blocklist=( configuration.get_existent_ignore_all_errors_paths() ), debug=infer_arguments.debug_infer, excludes=configuration.excludes, extensions=configuration.get_valid_extension_suffixes(), relative_local_root=configuration.relative_local_root, memory_profiling_output=memory_profiling_output, number_of_workers=configuration.get_number_of_workers(), parallel=not infer_arguments.sequential, profiling_output=profiling_output, python_version=configuration.get_python_version(), shared_memory=configuration.shared_memory, remote_logging=remote_logging, search_paths=configuration.expand_and_get_existent_search_paths(), source_paths=source_paths, ), ignore_infer=configuration.get_existent_ignore_infer_paths(), paths_to_modify=infer_arguments.paths_to_modify, ) @contextlib.contextmanager def create_infer_arguments_and_cleanup( configuration: configuration_module.Configuration, infer_arguments: command_arguments.InferArguments, ) -> Iterator[Arguments]: arguments = create_infer_arguments(configuration, infer_arguments) try: yield arguments finally: # It is safe to clean up source paths after infer command since # any created artifact directory won't be reused by other commands. arguments.base_arguments.source_paths.cleanup() def _check_working_directory( working_directory: Path, global_root: Path, relative_local_root: Optional[str] ) -> None: candidate_locations: List[str] = [] if working_directory == global_root: return candidate_locations.append(f"`{global_root}` with `--local-configuration` set") if relative_local_root is not None: local_root = global_root / relative_local_root if working_directory == local_root: return candidate_locations.append(f"`{local_root}`") valid_locations = " or from ".join(candidate_locations) raise ValueError( f"Infer must run from {valid_locations}. " f"Cannot run from current working directory `{working_directory}`." ) def _run_infer_command_get_output(command: Sequence[str]) -> str: with backend_arguments.backend_log_file(prefix="pyre_infer") as log_file: with start.background_logging(Path(log_file.name)): result = subprocess.run( command, stdout=subprocess.PIPE, stderr=log_file.file, universal_newlines=True, ) return_code = result.returncode # Interpretation of the return code needs to be kept in sync with # `source/command/inferCommand.ml`. if return_code == 0: return result.stdout elif return_code == 1: raise commands.ClientException( message="Pyre encountered an internal failure", exit_code=commands.ExitCode.FAILURE, ) elif return_code == 2: raise commands.ClientException( message="Pyre encountered a failure within buck.", exit_code=commands.ExitCode.BUCK_INTERNAL_ERROR, ) elif return_code == 3: raise commands.ClientException( message="Pyre encountered an error when building the buck targets.", exit_code=commands.ExitCode.BUCK_USER_ERROR, ) else: raise commands.ClientException( message=( "Infer command exited with unexpected return code: " f"{return_code}." ), exit_code=commands.ExitCode.FAILURE, ) def _get_infer_command_output( configuration: configuration_module.Configuration, infer_arguments: command_arguments.InferArguments, ) -> str: binary_location = configuration.get_binary_respecting_override() if binary_location is None: raise configuration_module.InvalidConfiguration( "Cannot locate a Pyre binary to run." ) with create_infer_arguments_and_cleanup( configuration, infer_arguments ) as arguments: with backend_arguments.temporary_argument_file(arguments) as argument_file_path: infer_command = [binary_location, "newinfer", str(argument_file_path)] return _run_infer_command_get_output(command=infer_command) def _load_output( configuration: configuration_module.Configuration, infer_arguments: command_arguments.InferArguments, ) -> str: if infer_arguments.read_stdin: return sys.stdin.read() else: return _get_infer_command_output(configuration, infer_arguments) def _relativize_path(path: str, against: Path) -> Optional[str]: given_path = Path(path) return ( None if against not in given_path.parents else str(given_path.relative_to(against)) ) def create_module_annotations( infer_output: RawInferOutput, base_path: Path, options: StubGenerationOptions ) -> List[ModuleAnnotations]: infer_output_relativized: Dict[Optional[str], RawInferOutputForPath] = { _relativize_path(path, against=base_path): data for path, data in infer_output.split_by_path().items() } return [ ModuleAnnotations.from_infer_output( path=path, infer_output=infer_output_for_path, options=options, ) for path, infer_output_for_path in infer_output_relativized.items() if path is not None ] def _print_inferences( infer_output: RawInferOutput, module_annotations: Sequence[ModuleAnnotations] ) -> None: LOG.log(log.SUCCESS, "Raw Infer Outputs:") # pyre-ignore[16]: Pyre does not understand `dataclasses_json` LOG.log(log.SUCCESS, json.dumps(infer_output.to_dict(), indent=2)) LOG.log(log.SUCCESS, "Generated Stubs:") LOG.log( log.SUCCESS, "\n\n".join( f"*{module.path}*\n{module.to_stubs()}" for module in module_annotations ), ) def _get_type_directory(log_directory: Path) -> Path: return log_directory / "types" def _write_stubs( type_directory: Path, module_annotations: Sequence[ModuleAnnotations] ) -> None: if type_directory.exists(): LOG.log(log.SUCCESS, f"Deleting {type_directory}") shutil.rmtree(type_directory) type_directory.mkdir(parents=True, exist_ok=True) LOG.log(log.SUCCESS, f"Outputting inferred stubs to {type_directory}...") for module in module_annotations: module.write_stubs(type_directory=type_directory) def should_annotate_in_place( path: Path, paths_to_modify: Optional[Set[Path]], ) -> bool: return ( True if paths_to_modify is None else any(path in paths_to_modify for path in (path, *path.parents)) ) def _annotate_in_place( working_directory: Path, type_directory: Path, paths_to_modify: Optional[Set[Path]], options: StubGenerationOptions, number_of_workers: int, ) -> None: tasks: List[AnnotateModuleInPlace] = [] for full_stub_path in type_directory.rglob("*.pyi"): relative_stub_path = full_stub_path.relative_to(type_directory) relative_code_path = relative_stub_path.with_suffix(".py") full_code_path = working_directory / relative_code_path if should_annotate_in_place(full_code_path, paths_to_modify): tasks.append( AnnotateModuleInPlace( full_stub_path=str(full_stub_path), full_code_path=str(full_code_path), options=options, ) ) with multiprocessing.Pool(number_of_workers) as pool: for _ in pool.imap_unordered(AnnotateModuleInPlace.run_task, tasks): pass def run_infer( configuration: configuration_module.Configuration, infer_arguments: command_arguments.InferArguments, ) -> commands.ExitCode: working_directory = infer_arguments.working_directory _check_working_directory( working_directory=working_directory, global_root=Path(configuration.project_root), relative_local_root=configuration.relative_local_root, ) type_directory = _get_type_directory(Path(configuration.log_directory)) in_place = infer_arguments.in_place options = StubGenerationOptions( annotate_attributes=infer_arguments.annotate_attributes, use_future_annotations=infer_arguments.use_future_annotations, dequalify=infer_arguments.dequalify, quote_annotations=infer_arguments.quote_annotations, simple_annotations=infer_arguments.simple_annotations, ) if infer_arguments.annotate_from_existing_stubs: if not in_place: raise ValueError( "`--annotate-from-existing-stubs` cannot be used without the" " `--in-place` flag" ) _annotate_in_place( working_directory=working_directory, type_directory=type_directory, paths_to_modify=infer_arguments.paths_to_modify, options=options, number_of_workers=configuration.get_number_of_workers(), ) else: infer_output = RawInferOutput.create_from_json( json.loads(_load_output(configuration, infer_arguments))[0] ) module_annotations = create_module_annotations( infer_output=infer_output, base_path=working_directory, options=options, ) if infer_arguments.print_only: _print_inferences(infer_output, module_annotations) else: _write_stubs(type_directory, module_annotations) if in_place: _annotate_in_place( working_directory=working_directory, type_directory=type_directory, paths_to_modify=infer_arguments.paths_to_modify, options=options, number_of_workers=configuration.get_number_of_workers(), ) return commands.ExitCode.SUCCESS @remote_logging.log_usage(command_name="infer") def run( configuration: configuration_module.Configuration, infer_arguments: command_arguments.InferArguments, ) -> commands.ExitCode: try: return run_infer(configuration, infer_arguments) except commands.ClientException: raise except Exception as error: traceback.print_exc(file=sys.stderr) raise commands.ClientException( f"Exception occurred during Pyre infer: {error}" ) from error