client/commands/infer.py (812 lines of code) (raw):
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import dataclasses
import functools
import itertools
import json
import logging
import multiprocessing
import re
import shutil
import subprocess
import sys
import traceback
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, TypeVar, Union
import dataclasses_json
import libcst
from libcst.codemod import CodemodContext
from .. import command_arguments, configuration as configuration_module, log
from ..libcst_vendored_visitors import ApplyTypeAnnotationsVisitor
from . import commands, remote_logging, backend_arguments, start
# Module-level logger used throughout the infer command.
LOG: logging.Logger = logging.getLogger(__name__)
@dataclasses.dataclass(frozen=True)
class Arguments:
    """
    Options understood by the backend infer command.

    The serialized layout must stay in sync with
    `source/command/inferCommand.ml`.
    """

    base_arguments: backend_arguments.BaseArguments
    ignore_infer: Sequence[str] = dataclasses.field(default_factory=list)
    paths_to_modify: Optional[Set[Path]] = None

    def serialize(self) -> Dict[str, Any]:
        """
        Build the JSON-compatible payload for the backend.

        The payload contains the serialized base arguments, followed by
        `ignore_infer`. `paths_to_modify` is added as a list of strings only
        when it was provided.
        """
        payload: Dict[str, Any] = dict(self.base_arguments.serialize())
        payload["ignore_infer"] = self.ignore_infer
        if self.paths_to_modify is not None:
            payload["paths_to_modify"] = [
                str(modified_path) for modified_path in self.paths_to_modify
            ]
        return payload
@dataclasses_json.dataclass_json(
    letter_case=dataclasses_json.LetterCase.CAMEL,
    undefined=dataclasses_json.Undefined.EXCLUDE,
)
@dataclasses.dataclass(frozen=True)
class RawAnnotationLocation:
    """
    Where an inferred annotation was found, as reported by the backend.

    Decoded from camelCase JSON keys; unknown keys are ignored (EXCLUDE).
    """

    # Dotted module qualifier for `path`; used to dequalify names that
    # belong to this same module.
    qualifier: str
    # Source file path as emitted by the backend.
    path: str
    # Line number as emitted by the backend.
    line: int
@dataclasses.dataclass(frozen=True)
class RawAnnotation:
    """
    Fields shared by every kind of raw annotation emitted by the backend.

    This base class is not decorated with `dataclass_json` itself; its
    concrete subclasses are.
    """

    # Dotted access path of the annotated entity; only the last component is
    # used when rendering stubs.
    name: str
    location: RawAnnotationLocation
@dataclasses_json.dataclass_json(
    letter_case=dataclasses_json.LetterCase.CAMEL,
    undefined=dataclasses_json.Undefined.EXCLUDE,
)
@dataclasses.dataclass(frozen=True)
class RawGlobalAnnotation(RawAnnotation):
    """Inferred type of a module-level global."""

    # Inferred type, as source text.
    annotation: str
@dataclasses_json.dataclass_json(
    letter_case=dataclasses_json.LetterCase.CAMEL,
    undefined=dataclasses_json.Undefined.EXCLUDE,
)
@dataclasses.dataclass(frozen=True)
class RawAttributeAnnotation(RawAnnotation):
    """Inferred type of a class attribute."""

    # Qualified name of the class that owns the attribute.
    parent: str
    # Inferred type, as source text.
    annotation: str
@dataclasses_json.dataclass_json(
    letter_case=dataclasses_json.LetterCase.CAMEL,
    undefined=dataclasses_json.Undefined.EXCLUDE,
)
@dataclasses.dataclass(frozen=True)
class RawParameter:
    """One parameter of an inferred function signature."""

    name: str
    # Position of the parameter, as reported by the backend.
    index: int
    # Inferred type as source text, or None when nothing was inferred.
    annotation: Optional[str] = None
    # Default value as source text, or None when the parameter has none.
    value: Optional[str] = None
@dataclasses_json.dataclass_json(
    letter_case=dataclasses_json.LetterCase.CAMEL,
    undefined=dataclasses_json.Undefined.EXCLUDE,
)
@dataclasses.dataclass(frozen=True)
class RawDefineAnnotation(RawAnnotation):
    """
    Inferred signature of a function or method.

    The JSON keys `return` and `async` are Python keywords, so they are
    mapped onto `return_` and `is_async` through explicit field names.
    """

    # Qualified owning class for methods; None for module-level functions.
    parent: Optional[str] = None
    # Inferred return type as source text, or None when nothing was inferred.
    return_: Optional[str] = dataclasses.field(
        metadata=dataclasses_json.config(field_name="return"), default=None
    )
    parameters: List[RawParameter] = dataclasses.field(default_factory=list)
    is_async: bool = dataclasses.field(
        metadata=dataclasses_json.config(field_name="async"), default=False
    )
# Type variable bound to `RawAnnotation`; lets helpers that group annotations
# keep the concrete annotation subtype in their signatures.
TAnnotation = TypeVar("TAnnotation", bound=RawAnnotation)
@dataclasses_json.dataclass_json(
    letter_case=dataclasses_json.LetterCase.CAMEL,
    undefined=dataclasses_json.Undefined.EXCLUDE,
)
@dataclasses.dataclass(frozen=True)
class RawInferOutput:
    """
    Decoded output of the backend infer command, spanning every analyzed file.

    The JSON keys are `globals`, `attributes` and `defines`.
    """

    global_annotations: List[RawGlobalAnnotation] = dataclasses.field(
        metadata=dataclasses_json.config(field_name="globals"), default_factory=list
    )
    attribute_annotations: List[RawAttributeAnnotation] = dataclasses.field(
        metadata=dataclasses_json.config(field_name="attributes"), default_factory=list
    )
    define_annotations: List[RawDefineAnnotation] = dataclasses.field(
        metadata=dataclasses_json.config(field_name="defines"), default_factory=list
    )

    class ParsingError(Exception):
        """Raised when the backend output does not fit the expected schema."""

    @staticmethod
    def create_from_string(input: str) -> "RawInferOutput":
        """
        Decode a JSON document into a `RawInferOutput`.

        Raises:
            RawInferOutput.ParsingError: if the payload is malformed or does
                not match the schema; the original error is chained.
        """
        try:
            # pyre-fixme[16]: Pyre doesn't understand `dataclasses_json`
            return RawInferOutput.schema().loads(input)
        except (
            TypeError,
            KeyError,
            ValueError,
            dataclasses_json.mm.ValidationError,
        ) as error:
            raise RawInferOutput.ParsingError(str(error)) from error

    @staticmethod
    def create_from_json(input: Dict[str, object]) -> "RawInferOutput":
        """Decode an already-parsed JSON object by re-serializing it first."""
        serialized = json.dumps(input)
        return RawInferOutput.create_from_string(serialized)

    def qualifiers_by_path(self) -> Dict[str, str]:
        """Map every source path to its module qualifier (last one seen wins)."""
        qualifiers: Dict[str, str] = {}
        for group in (
            self.global_annotations,
            self.attribute_annotations,
            self.define_annotations,
        ):
            for annotation in group:
                location = annotation.location
                qualifiers[location.path] = location.qualifier
        return qualifiers

    def split_by_path(self) -> "Dict[str, RawInferOutputForPath]":
        """Regroup all annotations into one `RawInferOutputForPath` per path."""

        def group_by_path(
            annotations: Sequence[TAnnotation],
        ) -> Dict[str, List[TAnnotation]]:
            grouped: Dict[str, List[TAnnotation]] = {}
            for annotation in annotations:
                source_path = annotation.location.path
                if source_path not in grouped:
                    grouped[source_path] = []
                grouped[source_path].append(annotation)
            return grouped

        qualifiers = self.qualifiers_by_path()
        globals_by_path = group_by_path(self.global_annotations)
        attributes_by_path = group_by_path(self.attribute_annotations)
        defines_by_path = group_by_path(self.define_annotations)
        every_path = (
            globals_by_path.keys() | attributes_by_path.keys() | defines_by_path.keys()
        )

        per_path: "Dict[str, RawInferOutputForPath]" = {}
        for source_path in every_path:
            per_path[source_path] = RawInferOutputForPath(
                qualifier=qualifiers[source_path],
                global_annotations=globals_by_path.get(source_path, []),
                attribute_annotations=attributes_by_path.get(source_path, []),
                define_annotations=defines_by_path.get(source_path, []),
            )
        return per_path
@dataclasses_json.dataclass_json(
    letter_case=dataclasses_json.LetterCase.CAMEL,
    undefined=dataclasses_json.Undefined.EXCLUDE,
)
@dataclasses.dataclass(frozen=True)
class RawInferOutputForPath:
    """Backend infer output restricted to a single source file."""

    # Module qualifier shared by every annotation of this file.
    qualifier: str
    global_annotations: List[RawGlobalAnnotation] = dataclasses.field(
        metadata=dataclasses_json.config(field_name="globals"), default_factory=list
    )
    attribute_annotations: List[RawAttributeAnnotation] = dataclasses.field(
        metadata=dataclasses_json.config(field_name="attributes"), default_factory=list
    )
    define_annotations: List[RawDefineAnnotation] = dataclasses.field(
        metadata=dataclasses_json.config(field_name="defines"), default_factory=list
    )

    @staticmethod
    def create_from_json(input: Dict[str, object]) -> "RawInferOutputForPath":
        """Decode an already-parsed JSON object into a `RawInferOutputForPath`."""
        # pyre-fixme[16]: Pyre doesn't understand `dataclasses_json`
        schema = RawInferOutputForPath.schema()
        return schema.loads(json.dumps(input))
def _sanitize_name(name: str) -> str:
"""The last part of the access path is the function/attribute name"""
return name.split(".")[-1]
@functools.lru_cache(maxsize=1)
def empty_module() -> libcst.Module:
    """
    Return a parsed empty module, built once and then reused.

    It serves as the formatting context for rendering standalone CST nodes.
    """
    return libcst.parse_module("")
def code_for_node(node: libcst.CSTNode) -> str:
    """Render a single CST node to source text with default module formatting."""
    formatting_module = empty_module()
    return formatting_module.code_for_node(node)
class AnnotationFixer(libcst.CSTTransformer):
    def __init__(self, qualifier: str) -> None:
        """
        AnnotationFixer sanitizes annotations.

        There are two transformations we apply:

        (1) Strip the module `qualifier` from names qualified with it, e.g.
        `<qualifier>.Foo` becomes `Foo`. This is because the pyre backend
        always uses fully-qualified names, but in our stubs we should not
        include the prefix for names coming from this module.

        (2) Convert `PathLike` annotations, which come from pyre in a
        special-cased form that isn't a correct annotation, to a quoted
        `'os.PathLike[_]'`.

        Note: we eventually will need to either have a proper protocol in the
        backend for generating python-readable types, or extend (2) to handle
        various other special cases where pyre outputs types that are invalid
        in python.

        Args:
            qualifier: dotted name of the module the annotations belong to.
        """
        super().__init__()
        self.qualifier = qualifier

    def leave_Attribute(
        self,
        original_node: libcst.Attribute,
        updated_node: libcst.Attribute,
    ) -> Union[libcst.Attribute, libcst.Name]:
        """
        Replace `<qualifier>.Name` with the bare `Name`.

        Note: in order to avoid complex reverse-name-matching, we're
        effectively operating at the top level of attributes, by using only
        the `original_node`. This means the transform we're performing cannot
        be done concurrently with a transform that has to be done
        incrementally.
        """
        value = code_for_node(original_node.value)
        if value == self.qualifier and libcst.matchers.matches(
            original_node.attr, libcst.matchers.Name()
        ):
            return libcst.ensure_type(original_node.attr, libcst.Name)
        return original_node

    def leave_Subscript(
        self,
        original_node: libcst.Subscript,
        updated_node: Union[libcst.Subscript, libcst.SimpleString],
    ) -> Union[libcst.Subscript, libcst.SimpleString]:
        """
        Turn a bare `PathLike[...]` subscript into the single-quoted string
        `'os.PathLike[...]'`. All other subscripts are returned unchanged.
        """
        if libcst.matchers.matches(
            original_node.value, libcst.matchers.Name("PathLike")
        ):
            name_node = libcst.Attribute(
                value=libcst.Name(
                    value="os",
                    lpar=[],
                    rpar=[],
                ),
                attr=libcst.Name(value="PathLike"),
            )
            node_as_string = code_for_node(updated_node.with_changes(value=name_node))
            updated_node = libcst.SimpleString(f"'{node_as_string}'")
        return updated_node

    @staticmethod
    def sanitize(
        annotation: str,
        qualifier: str,
        quote_annotations: bool = False,
        dequalify_all: bool = False,
        runtime_defined: bool = True,
    ) -> str:
        """
        Transform raw annotations in an attempt to reduce incorrectly-imported
        annotations in generated code.

        TODO(T93381000): Handle qualification in a principled way and remove
        this: all of these transforms are attempts to hack simple fixes to the
        problem of us not actually understanding qualified types and existing
        imports.

        (1) If `quote_annotations` is set to True, then spit out a quoted
        raw annotation (with fully-qualified names) and skip every other
        transform. The resulting generated code will not technically be
        PEP 484 compliant because it will use fully qualified type names
        without import, but pyre can understand this and it's useful for pysa
        to avoid adding import lines that change traces.

        (2) If `dequalify_all` is set, then remove all qualifiers from the
        top-level type (outside the outermost brackets) of a subscripted
        annotation. For example, convert
        `sqlalchemy.sql.schema.Column[typing.Optional[int]]` into
        `Column[typing.Optional[int]]`. Annotations without a trailing
        `[...]` (e.g. `a.b.C`) are not matched and keep their qualifiers.

        (3) Apply `AnnotationFixer`: strip this module's `qualifier` and
        convert bare `PathLike` uses to `'os.PathLike'`; the ocaml side of
        pyre currently spits out an unqualified type here which is incorrect,
        and quoting makes the use of os safer given that we don't handle
        imports correctly yet. Annotations libcst cannot parse are left as-is.

        Finally, if `runtime_defined` is False the result is wrapped in
        double quotes.
        """
        if quote_annotations:
            return f'"{annotation}"'
        if dequalify_all:
            match = re.fullmatch(r"([^.]*?\.)*?([^.]+)(\[.*\])", annotation)
            if match is not None:
                annotation = f"{match.group(2)}{match.group(3)}"
        try:
            tree = libcst.parse_module(annotation)
            annotation = tree.visit(
                AnnotationFixer(
                    qualifier=qualifier,
                )
            ).code
        except libcst.ParserSyntaxError:
            # Not valid Python: keep the (possibly dequalified) text unchanged.
            # `libcst.ParserSyntaxError` is the public export of the class that
            # previously was referenced through the private `_exceptions` module.
            pass
        if not runtime_defined:
            return f'"{annotation}"'
        return annotation
@dataclasses.dataclass(frozen=True)
class StubGenerationOptions:
    """Options controlling how inferred annotations are rendered and applied."""

    # Also emit stubs for class attributes.
    annotate_attributes: bool = False
    # Forwarded to ApplyTypeAnnotationsVisitor when merging stubs into code.
    use_future_annotations: bool = False
    # Emit every annotation as a quoted, fully-qualified string.
    quote_annotations: bool = False
    # Drop annotations whose sanitized form contains a dot.
    simple_annotations: bool = False
    # Strip qualifiers from the outermost type of subscripted annotations.
    dequalify: bool = False
    # Log the underlying error when annotating a file in place fails.
    debug_infer: bool = False

    def __post_init__(self) -> None:
        """
        Reject incompatible option combinations at construction time.

        This hook must be spelled `__post_init__` for `dataclasses` to call it.
        The previous spelling, `__post__init__`, was never invoked, so the
        check below never ran.

        Raises:
            ValueError: if `quote_annotations` is combined with
                `use_future_annotations` or `dequalify`.
        """
        if self.quote_annotations and (self.use_future_annotations or self.dequalify):
            raise ValueError(
                "You should not mix the `quote_annotations` option, which causes pyre "
                "to generate quoted and qualified annotations, with the "
                "`use_future_annotations` or `dequalify` options."
            )
@dataclasses.dataclass(frozen=True)
class TypeAnnotation:
    """A possibly-missing inferred type plus the context needed to render it."""

    annotation: Optional[str]
    qualifier: str
    options: StubGenerationOptions
    runtime_defined: bool

    @staticmethod
    def from_raw(
        annotation: Optional[str],
        options: StubGenerationOptions,
        qualifier: str,
        runtime_defined: bool = True,
    ) -> "TypeAnnotation":
        """Wrap a raw backend annotation string (or None) with its context."""
        return TypeAnnotation(annotation, qualifier, options, runtime_defined)

    @staticmethod
    def is_simple(sanitized_annotation: str) -> bool:
        """
        An annotation is "simple" when it contains no dot at all.

        This is a flexible definition that should expand if/when our ability to
        handle annotations or imports without manual adjustment improves.
        """
        return "." not in sanitized_annotation

    def to_stub(self, prefix: str = "") -> str:
        """
        Render the sanitized annotation preceded by `prefix`.

        Returns an empty string when the annotation is missing, or when
        `simple_annotations` is enabled and the sanitized form is not simple.
        """
        if self.annotation is None:
            return ""
        sanitized = AnnotationFixer.sanitize(
            annotation=self.annotation,
            qualifier=self.qualifier,
            quote_annotations=self.options.quote_annotations,
            dequalify_all=self.options.dequalify,
            runtime_defined=self.runtime_defined,
        )
        keep = not self.options.simple_annotations or TypeAnnotation.is_simple(
            sanitized
        )
        return prefix + sanitized if keep else ""

    @property
    def missing(self) -> bool:
        """True when the backend inferred no type."""
        return self.annotation is None
@dataclasses.dataclass(frozen=True)
class Parameter:
    """A function parameter with its (possibly missing) inferred type and default."""

    name: str
    annotation: TypeAnnotation
    # Default value as source text; a falsy value means "no default".
    value: Optional[str]

    def to_stub(self) -> str:
        """
        Render the parameter as it appears in a stub signature, e.g. `x: int = 1`.

        Follows PEP 8 spacing: ` = ` around the default only when an annotation
        is actually emitted, `=` otherwise. The choice is based on the rendered
        annotation, not on `annotation.missing`. An annotation dropped by
        `simple_annotations` is therefore spaced like a missing one.
        """
        annotation_stub = self.annotation.to_stub(prefix=": ")
        delimiter = " = " if annotation_stub else "="
        default = f"{delimiter}{self.value}" if self.value else ""
        return f"{self.name}{annotation_stub}{default}"
@dataclasses.dataclass(frozen=True)
class FunctionAnnotation:
    """An inferred function signature, renderable as a one-line stub `def`."""

    name: str
    return_annotation: TypeAnnotation
    parameters: Sequence[Parameter]
    is_async: bool

    def to_stub(self) -> str:
        """Render e.g. `async def f(x: int) -> str: ...`."""
        bare_name = _sanitize_name(self.name)
        keyword = "async def" if self.is_async else "def"
        signature = ", ".join(map(lambda parameter: parameter.to_stub(), self.parameters))
        returns = self.return_annotation.to_stub(prefix=" -> ")
        return "{} {}({}){}: ...".format(keyword, bare_name, signature, returns)
@dataclasses.dataclass(frozen=True)
class MethodAnnotation(FunctionAnnotation):
    """A `FunctionAnnotation` that belongs to a class."""

    # Qualified name of the owning class.
    parent: str
@dataclasses.dataclass(frozen=True)
class FieldAnnotation:
    """Inferred type of a named field (a global or a class attribute)."""

    name: str
    annotation: TypeAnnotation

    def __post_init__(self) -> None:
        """
        Raises:
            RuntimeError: if the field has no inferred type; fields are only
                emitted when a type was inferred.
        """
        if self.annotation.missing:
            raise RuntimeError(f"Illegal missing FieldAnnotation for {self.name}")

    def to_stub(self) -> str:
        """
        Render the field as a stub line, e.g. `x: int = ...`.

        The annotation can render to an empty string, for instance when
        `simple_annotations` filters it out. In that case emit `x = ...`
        instead of the syntactically invalid `x:  = ...`. That invalid form
        would make the whole stub unparseable when it is applied in place.
        """
        name = _sanitize_name(self.name)
        rendered = self.annotation.to_stub()
        if not rendered:
            return f"{name} = ..."
        return f"{name}: {rendered} = ..."
@dataclasses.dataclass(frozen=True)
class GlobalAnnotation(FieldAnnotation):
    """Inferred type of a module-level global."""

    pass
@dataclasses.dataclass(frozen=True)
class AttributeAnnotation(FieldAnnotation):
    """Inferred type of a class attribute."""

    # Qualified name of the owning class.
    parent: str
@dataclasses.dataclass(frozen=True)
class ModuleAnnotations:
    """Inferred annotations for one source file, renderable as a `.pyi` stub."""

    # Path of the source file relative to the directory infer runs from.
    path: str
    options: StubGenerationOptions
    globals_: List[GlobalAnnotation] = dataclasses.field(default_factory=list)
    attributes: List[AttributeAnnotation] = dataclasses.field(default_factory=list)
    functions: List[FunctionAnnotation] = dataclasses.field(default_factory=list)
    methods: List[MethodAnnotation] = dataclasses.field(default_factory=list)

    @staticmethod
    def from_infer_output(
        path: str,
        infer_output: RawInferOutputForPath,
        options: StubGenerationOptions,
    ) -> "ModuleAnnotations":
        """
        Convert the raw backend output for one file into `ModuleAnnotations`.

        Attributes are kept only when `options.annotate_attributes` is set.
        Defines with a `parent` become methods; the others become functions.
        """

        def type_annotation(
            annotation: Optional[str], parent_class: Optional[str] = None
        ) -> TypeAnnotation:
            # An annotation equal to its own enclosing class is marked as not
            # runtime-defined, which makes `sanitize` quote it.
            return TypeAnnotation.from_raw(
                annotation,
                qualifier=infer_output.qualifier,
                options=options,
                runtime_defined=parent_class != annotation if parent_class else True,
            )

        return ModuleAnnotations(
            path=path,
            globals_=[
                GlobalAnnotation(
                    name=global_.name, annotation=type_annotation(global_.annotation)
                )
                for global_ in infer_output.global_annotations
            ],
            attributes=[
                AttributeAnnotation(
                    parent=attribute.parent,
                    name=attribute.name,
                    annotation=type_annotation(attribute.annotation, attribute.parent),
                )
                for attribute in infer_output.attribute_annotations
            ]
            if options.annotate_attributes
            else [],
            functions=[
                FunctionAnnotation(
                    name=define.name,
                    return_annotation=type_annotation(define.return_),
                    parameters=[
                        Parameter(
                            name=parameter.name,
                            annotation=type_annotation(parameter.annotation),
                            value=parameter.value,
                        )
                        for parameter in define.parameters
                    ],
                    is_async=define.is_async,
                )
                for define in infer_output.define_annotations
                if define.parent is None
            ],
            methods=[
                MethodAnnotation(
                    parent=define.parent,
                    name=define.name,
                    return_annotation=type_annotation(define.return_, define.parent),
                    parameters=[
                        Parameter(
                            name=parameter.name,
                            annotation=type_annotation(
                                parameter.annotation, define.parent
                            ),
                            value=parameter.value,
                        )
                        for parameter in define.parameters
                    ],
                    is_async=define.is_async,
                )
                for define in infer_output.define_annotations
                if define.parent is not None
            ],
            options=options,
        )

    def is_empty(self) -> bool:
        """True when there is nothing to write for this module."""
        return (
            len(self.globals_)
            + len(self.attributes)
            + len(self.functions)
            + len(self.methods)
        ) == 0

    @staticmethod
    def _indent(stub: str) -> str:
        """Indent every line of `stub` for placement inside a class body."""
        return " " + stub.replace("\n", "\n ")

    def _relativize(self, parent: str) -> Sequence[str]:
        """
        Split a qualified class name into components relative to this module.

        The module name is derived from `path` using path components, so
        OS-native separators and dots in directory names are handled. The file
        suffix and a trailing `__init__` component are dropped. The module
        name is removed from `parent` only when it is an actual prefix,
        e.g. `pkg.mod.A.B` relative to `pkg/mod.py` gives `["A", "B"]`.
        """
        module_parts = list(Path(self.path).with_suffix("").parts)
        if module_parts and module_parts[-1] == "__init__":
            module_parts.pop()
        module = ".".join(module_parts)
        if module and (parent == module or parent.startswith(module + ".")):
            parent = parent[len(module) :]
        return parent.strip(".").split(".")

    @property
    def classes(self) -> Dict[str, List[Union[AttributeAnnotation, MethodAnnotation]]]:
        """
        Find all classes with attributes or methods to annotate.

        Anything in nested classes is currently ignored (and counted in a
        warning), e.g.:
        ```
        class X:
            class Y:
                [ALL OF THIS IS IGNORED]
        ```
        """
        classes: Dict[str, List[Union[AttributeAnnotation, MethodAnnotation]]] = {}
        nested_class_count = 0
        for annotation in [*self.attributes, *self.methods]:
            parent = self._relativize(annotation.parent)
            if len(parent) == 1:
                classes.setdefault(parent[0], []).append(annotation)
            else:
                nested_class_count += 1
        if nested_class_count > 0:
            LOG.warning(
                f"In file {self.path}, ignored {nested_class_count} nested classes"
            )
        return classes

    def _class_stub(
        self,
        classname: str,
        annotations: Sequence[Union[AttributeAnnotation, MethodAnnotation]],
    ) -> str:
        """Render one class stub containing the given members."""
        body = "\n".join(
            self._indent(annotation.to_stub()) for annotation in annotations
        )
        return f"class {classname}:\n{body}\n"

    def to_stubs(self) -> str:
        """
        Output annotation information as a stub file: globals, then functions,
        then classes, with a trailing newline.
        """
        return "\n".join(
            [
                *(global_.to_stub() for global_ in self.globals_),
                *(function.to_stub() for function in self.functions),
                *(
                    self._class_stub(classname, annotations)
                    for classname, annotations in self.classes.items()
                ),
                "",
            ]
        )

    def stubs_path(self, directory: Path) -> Path:
        """Location of this module's `.pyi` file under `directory`."""
        return (directory / self.path).with_suffix(".pyi")

    def write_stubs(self, type_directory: Path) -> None:
        """Write the stub file, creating parent directories as needed."""
        path = self.stubs_path(type_directory)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(self.to_stubs())
@dataclasses.dataclass
class AnnotateModuleInPlace:
    """A unit of work: merge one stub file into one source file, in place."""

    full_stub_path: str
    full_code_path: str
    options: StubGenerationOptions

    @staticmethod
    def _annotated_code(
        code_path: str,
        stub: str,
        code: str,
        options: StubGenerationOptions,
    ) -> Optional[str]:
        """
        Return `code` with the annotations of `stub` applied, or None when the
        source carries the generated-file marker and must be left alone.
        """
        # Written as two adjacent literals, presumably so that this source file
        # does not itself contain the marker.
        generated_marker = "@" "generated"
        if generated_marker in code:
            LOG.warning(f"Skipping generated file {code_path}")
            return None
        context = CodemodContext()
        ApplyTypeAnnotationsVisitor.store_stub_in_context(
            context=context,
            stub=libcst.parse_module(stub),
            use_future_annotations=options.use_future_annotations,
        )
        visitor = ApplyTypeAnnotationsVisitor(context)
        annotated_tree = visitor.transform_module(libcst.parse_module(code))
        return annotated_tree.code

    @staticmethod
    def annotate_code(
        stub_path: str,
        code_path: str,
        options: StubGenerationOptions,
    ) -> None:
        """
        Merge a stub file of inferred annotations into a code file in place.

        Best-effort: any failure is logged (with details under `debug_infer`)
        and swallowed so that one bad file does not abort the whole run.
        """
        try:
            stub_text = Path(stub_path).read_text()
            code_text = Path(code_path).read_text()
            result = AnnotateModuleInPlace._annotated_code(
                code_path=code_path,
                stub=stub_text,
                code=code_text,
                options=options,
            )
            if result is None:
                return
            Path(code_path).write_text(result)
            LOG.info(f"Annotated {code_path}")
        except Exception as error:
            LOG.warning(f"Failed to annotate {code_path}")
            if options.debug_infer:
                LOG.warning(f"\tError: {error}")

    def run(self) -> None:
        """Annotate `full_code_path` using the stub at `full_stub_path`."""
        self.annotate_code(
            stub_path=self.full_stub_path,
            code_path=self.full_code_path,
            options=self.options,
        )

    @staticmethod
    def run_task(task: "AnnotateModuleInPlace") -> None:
        """Module-level-callable wrapper around `run`, for multiprocessing."""
        task.run()
def create_infer_arguments(
    configuration: configuration_module.Configuration,
    infer_arguments: command_arguments.InferArguments,
) -> Arguments:
    """
    Translate client configurations to backend infer arguments.

    This API is not pure since it needs to access filesystem to filter out
    nonexistent directories. It is idempotent though, since it does not alter
    any filesystem state.
    """
    source_paths = backend_arguments.get_source_path_for_check(configuration)
    profiling_output = (
        backend_arguments.get_profiling_log_path(Path(configuration.log_directory))
        if infer_arguments.enable_profiling
        else None
    )
    memory_profiling_output = (
        backend_arguments.get_profiling_log_path(Path(configuration.log_directory))
        if infer_arguments.enable_memory_profiling
        else None
    )

    logger = configuration.logger
    # Named so it does not shadow the `remote_logging` module imported at
    # file scope.
    remote_logging_arguments = (
        backend_arguments.RemoteLogging(
            logger=logger, identifier=infer_arguments.log_identifier or ""
        )
        if logger is not None
        else None
    )
    return Arguments(
        base_arguments=backend_arguments.BaseArguments(
            log_path=configuration.log_directory,
            global_root=configuration.project_root,
            checked_directory_allowlist=backend_arguments.get_checked_directory_allowlist(
                configuration, source_paths
            ),
            checked_directory_blocklist=(
                configuration.get_existent_ignore_all_errors_paths()
            ),
            debug=infer_arguments.debug_infer,
            excludes=configuration.excludes,
            extensions=configuration.get_valid_extension_suffixes(),
            relative_local_root=configuration.relative_local_root,
            memory_profiling_output=memory_profiling_output,
            number_of_workers=configuration.get_number_of_workers(),
            parallel=not infer_arguments.sequential,
            profiling_output=profiling_output,
            python_version=configuration.get_python_version(),
            shared_memory=configuration.shared_memory,
            remote_logging=remote_logging_arguments,
            search_paths=configuration.expand_and_get_existent_search_paths(),
            source_paths=source_paths,
        ),
        ignore_infer=configuration.get_existent_ignore_infer_paths(),
        paths_to_modify=infer_arguments.paths_to_modify,
    )
@contextlib.contextmanager
def create_infer_arguments_and_cleanup(
    configuration: configuration_module.Configuration,
    infer_arguments: command_arguments.InferArguments,
) -> Iterator[Arguments]:
    """
    Yield backend infer arguments and clean up their source paths on exit,
    including when the body raises.
    """
    backend_infer_arguments = create_infer_arguments(configuration, infer_arguments)
    try:
        yield backend_infer_arguments
    finally:
        # Safe: artifact directories created for infer are not reused by any
        # other command.
        backend_infer_arguments.base_arguments.source_paths.cleanup()
def _check_working_directory(
working_directory: Path, global_root: Path, relative_local_root: Optional[str]
) -> None:
candidate_locations: List[str] = []
if working_directory == global_root:
return
candidate_locations.append(f"`{global_root}` with `--local-configuration` set")
if relative_local_root is not None:
local_root = global_root / relative_local_root
if working_directory == local_root:
return
candidate_locations.append(f"`{local_root}`")
valid_locations = " or from ".join(candidate_locations)
raise ValueError(
f"Infer must run from {valid_locations}. "
f"Cannot run from current working directory `{working_directory}`."
)
def _run_infer_command_get_output(command: Sequence[str]) -> str:
    """
    Run the backend infer `command` and return its standard output.

    Backend stderr is sent to a temporary `pyre_infer` log file, which is
    passed to `start.background_logging` while the process runs.

    Raises:
        commands.ClientException: for any non-zero return code.
    """
    with backend_arguments.backend_log_file(prefix="pyre_infer") as log_file:
        with start.background_logging(Path(log_file.name)):
            completed = subprocess.run(
                command,
                stdout=subprocess.PIPE,
                stderr=log_file.file,
                universal_newlines=True,
            )
            status = completed.returncode
            if status == 0:
                return completed.stdout

            # Interpretation of the return code needs to be kept in sync with
            # `source/command/inferCommand.ml`.
            known_failures = {
                1: (
                    "Pyre encountered an internal failure",
                    commands.ExitCode.FAILURE,
                ),
                2: (
                    "Pyre encountered a failure within buck.",
                    commands.ExitCode.BUCK_INTERNAL_ERROR,
                ),
                3: (
                    "Pyre encountered an error when building the buck targets.",
                    commands.ExitCode.BUCK_USER_ERROR,
                ),
            }
            message, exit_code = known_failures.get(
                status,
                (
                    f"Infer command exited with unexpected return code: {status}.",
                    commands.ExitCode.FAILURE,
                ),
            )
            raise commands.ClientException(message=message, exit_code=exit_code)
def _get_infer_command_output(
    configuration: configuration_module.Configuration,
    infer_arguments: command_arguments.InferArguments,
) -> str:
    """
    Run the Pyre binary's `newinfer` subcommand and return its raw stdout.

    Raises:
        configuration_module.InvalidConfiguration: if no Pyre binary is found.
    """
    binary_location = configuration.get_binary_respecting_override()
    if binary_location is None:
        raise configuration_module.InvalidConfiguration(
            "Cannot locate a Pyre binary to run."
        )
    with create_infer_arguments_and_cleanup(
        configuration, infer_arguments
    ) as arguments, backend_arguments.temporary_argument_file(
        arguments
    ) as argument_file_path:
        return _run_infer_command_get_output(
            command=[binary_location, "newinfer", str(argument_file_path)]
        )
def _load_output(
    configuration: configuration_module.Configuration,
    infer_arguments: command_arguments.InferArguments,
) -> str:
    """
    Return the raw infer JSON text: read from stdin when `read_stdin` is set,
    otherwise produced by running the backend.
    """
    if not infer_arguments.read_stdin:
        return _get_infer_command_output(configuration, infer_arguments)
    return sys.stdin.read()
def _relativize_path(path: str, against: Path) -> Optional[str]:
given_path = Path(path)
return (
None
if against not in given_path.parents
else str(given_path.relative_to(against))
)
def create_module_annotations(
    infer_output: RawInferOutput, base_path: Path, options: StubGenerationOptions
) -> List[ModuleAnnotations]:
    """
    Build one `ModuleAnnotations` per source file under `base_path`; files
    outside `base_path` are skipped.
    """
    by_relative_path: Dict[str, RawInferOutputForPath] = {}
    for absolute_path, output_for_path in infer_output.split_by_path().items():
        relative_path = _relativize_path(absolute_path, against=base_path)
        if relative_path is not None:
            by_relative_path[relative_path] = output_for_path
    return [
        ModuleAnnotations.from_infer_output(
            path=relative_path,
            infer_output=output_for_path,
            options=options,
        )
        for relative_path, output_for_path in by_relative_path.items()
    ]
def _print_inferences(
    infer_output: RawInferOutput, module_annotations: Sequence[ModuleAnnotations]
) -> None:
    """Log the raw backend output, then every generated stub, at SUCCESS level."""
    level = log.SUCCESS
    LOG.log(level, "Raw Infer Outputs:")
    # pyre-ignore[16]: Pyre does not understand `dataclasses_json`
    raw_dump = json.dumps(infer_output.to_dict(), indent=2)
    LOG.log(level, raw_dump)
    LOG.log(level, "Generated Stubs:")
    stub_sections = [
        f"*{module.path}*\n{module.to_stubs()}" for module in module_annotations
    ]
    LOG.log(level, "\n\n".join(stub_sections))
def _get_type_directory(log_directory: Path) -> Path:
return log_directory / "types"
def _write_stubs(
    type_directory: Path, module_annotations: Sequence[ModuleAnnotations]
) -> None:
    """Recreate `type_directory` from scratch and write each module's `.pyi`."""
    has_stale_output = type_directory.exists()
    if has_stale_output:
        LOG.log(log.SUCCESS, f"Deleting {type_directory}")
        shutil.rmtree(type_directory)
    type_directory.mkdir(parents=True, exist_ok=True)

    LOG.log(log.SUCCESS, f"Outputting inferred stubs to {type_directory}...")
    for annotations in module_annotations:
        annotations.write_stubs(type_directory=type_directory)
def should_annotate_in_place(
    path: Path,
    paths_to_modify: Optional[Set[Path]],
) -> bool:
    """
    Decide whether `path` may be rewritten in place.

    With no restriction (`paths_to_modify is None`), every path qualifies.
    Otherwise `path` qualifies when it, or any of its ancestors, is listed.
    """
    if paths_to_modify is None:
        return True
    for candidate in itertools.chain((path,), path.parents):
        if candidate in paths_to_modify:
            return True
    return False
def _annotate_in_place(
    working_directory: Path,
    type_directory: Path,
    paths_to_modify: Optional[Set[Path]],
    options: StubGenerationOptions,
    number_of_workers: int,
) -> None:
    """
    Apply every stub under `type_directory` to its matching `.py` file under
    `working_directory`, in parallel across `number_of_workers` processes.
    Only files accepted by `should_annotate_in_place` are touched.
    """
    stub_and_code_paths = (
        (
            stub_path,
            working_directory
            / stub_path.relative_to(type_directory).with_suffix(".py"),
        )
        for stub_path in type_directory.rglob("*.pyi")
    )
    tasks: List[AnnotateModuleInPlace] = [
        AnnotateModuleInPlace(
            full_stub_path=str(stub_path),
            full_code_path=str(code_path),
            options=options,
        )
        for stub_path, code_path in stub_and_code_paths
        if should_annotate_in_place(code_path, paths_to_modify)
    ]
    with multiprocessing.Pool(number_of_workers) as pool:
        # Drain the iterator; each task returns None and order is irrelevant.
        for _ in pool.imap_unordered(AnnotateModuleInPlace.run_task, tasks):
            pass
def run_infer(
    configuration: configuration_module.Configuration,
    infer_arguments: command_arguments.InferArguments,
) -> commands.ExitCode:
    """
    Run `pyre infer`.

    Either apply stubs already present in the type directory
    (`--annotate-from-existing-stubs`, which requires `--in-place`), or run
    inference. After inference, either print the results (`print_only`) or
    write stubs and optionally apply them in place.

    Raises:
        ValueError: if run from an unsupported working directory, or if
            `--annotate-from-existing-stubs` is used without `--in-place`.
    """
    working_directory = infer_arguments.working_directory
    _check_working_directory(
        working_directory=working_directory,
        global_root=Path(configuration.project_root),
        relative_local_root=configuration.relative_local_root,
    )

    type_directory = _get_type_directory(Path(configuration.log_directory))
    in_place = infer_arguments.in_place
    options = StubGenerationOptions(
        annotate_attributes=infer_arguments.annotate_attributes,
        use_future_annotations=infer_arguments.use_future_annotations,
        dequalify=infer_arguments.dequalify,
        quote_annotations=infer_arguments.quote_annotations,
        simple_annotations=infer_arguments.simple_annotations,
        # Forwarded so that per-file in-place annotation failures log their
        # underlying error when `--debug-infer` is passed; without it the
        # option always stayed False.
        debug_infer=infer_arguments.debug_infer,
    )
    if infer_arguments.annotate_from_existing_stubs:
        if not in_place:
            raise ValueError(
                "`--annotate-from-existing-stubs` cannot be used without the"
                " `--in-place` flag"
            )
        _annotate_in_place(
            working_directory=working_directory,
            type_directory=type_directory,
            paths_to_modify=infer_arguments.paths_to_modify,
            options=options,
            number_of_workers=configuration.get_number_of_workers(),
        )
    else:
        infer_output = RawInferOutput.create_from_json(
            json.loads(_load_output(configuration, infer_arguments))[0]
        )
        module_annotations = create_module_annotations(
            infer_output=infer_output,
            base_path=working_directory,
            options=options,
        )
        if infer_arguments.print_only:
            _print_inferences(infer_output, module_annotations)
        else:
            _write_stubs(type_directory, module_annotations)
            if in_place:
                _annotate_in_place(
                    working_directory=working_directory,
                    type_directory=type_directory,
                    paths_to_modify=infer_arguments.paths_to_modify,
                    options=options,
                    number_of_workers=configuration.get_number_of_workers(),
                )
    return commands.ExitCode.SUCCESS
@remote_logging.log_usage(command_name="infer")
def run(
    configuration: configuration_module.Configuration,
    infer_arguments: command_arguments.InferArguments,
) -> commands.ExitCode:
    """
    Entry point for `pyre infer`.

    A `ClientException` propagates unchanged. Any other exception is printed
    with its traceback to stderr and re-raised as a chained `ClientException`.
    """
    try:
        return run_infer(configuration, infer_arguments)
    except Exception as error:
        if isinstance(error, commands.ClientException):
            raise
        traceback.print_exc(file=sys.stderr)
        raise commands.ClientException(
            f"Exception occurred during Pyre infer: {error}"
        ) from error