client/statistics_collectors.py (331 lines of code) (raw):
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import dataclasses
from collections import defaultdict
from enum import Enum
from re import compile
from typing import Iterable, Any, Dict, List, Pattern, Sequence
import libcst as cst
from libcst.metadata import CodeRange, PositionProvider
@dataclasses.dataclass(frozen=True)
class AnnotationInfo:
    """Immutable record pairing a syntax node with its annotation status."""

    # The LibCST node this record describes.
    node: cst.CSTNode
    # Whether the node is counted as annotated.
    is_annotated: bool
    # Source position associated with the record.
    code_range: CodeRange
class FunctionAnnotationKind(Enum):
    """Degree to which a function signature is annotated."""

    NOT_ANNOTATED = 0
    PARTIALLY_ANNOTATED = 1
    FULLY_ANNOTATED = 2

    @staticmethod
    def from_function_data(
        is_return_annotated: bool,
        annotated_parameter_count: int,
        is_method_or_classmethod: bool,
        parameters: Sequence[cst.Param],
    ) -> "FunctionAnnotationKind":
        """Classify a function from its return and parameter annotation data.

        Args:
            is_return_annotated: Whether the function has a return annotation.
            annotated_parameter_count: How many parameters the caller counted
                as annotated.
            is_method_or_classmethod: Whether the function is treated as a
                method or classmethod by the caller.
            parameters: The function's parameters; only the sequence length and
                the first element's `annotation` attribute are read.

        Returns:
            FULLY_ANNOTATED when the return is annotated and the annotated
            count equals the number of parameters, PARTIALLY_ANNOTATED when
            only some annotations are present, NOT_ANNOTATED otherwise.
        """
        if is_return_annotated:
            if annotated_parameter_count == len(parameters):
                return FunctionAnnotationKind.FULLY_ANNOTATED
            return FunctionAnnotationKind.PARTIALLY_ANNOTATED

        # No return annotation from here on. An untyped `self`/`cls` must not be
        # enough to make the function partially annotated: without a return
        # type the function is skipped by the typechecker, so at least one
        # explicitly annotated parameter is expected before it counts.
        first_is_untyped_self_or_cls = (
            is_method_or_classmethod
            and len(parameters) > 0
            and parameters[0].annotation is None
        )
        discounted = 1 if first_is_untyped_self_or_cls else 0
        if annotated_parameter_count <= discounted:
            return FunctionAnnotationKind.NOT_ANNOTATED
        return FunctionAnnotationKind.PARTIALLY_ANNOTATED
@dataclasses.dataclass(frozen=True)
class FunctionAnnotationInfo:
    """Immutable summary of the annotations found on one function definition."""

    # The function definition node.
    node: cst.CSTNode
    # Overall classification of the signature.
    annotation_kind: FunctionAnnotationKind
    # Source position associated with the function.
    code_range: CodeRange
    # Annotation record for the return type.
    returns: AnnotationInfo
    # Annotation records for the parameters, in declaration order.
    parameters: List[AnnotationInfo]
    # True when the first parameter is treated as an implicit `self`/`cls`.
    is_method_or_classmethod: bool

    def non_self_cls_parameters(self) -> Iterable[AnnotationInfo]:
        """Yield the parameter records, skipping the first one for methods."""
        yield from (
            self.parameters[1:] if self.is_method_or_classmethod else self.parameters
        )

    @property
    def is_annotated(self) -> bool:
        """True unless the function is classified as NOT_ANNOTATED."""
        return self.annotation_kind is not FunctionAnnotationKind.NOT_ANNOTATED

    @property
    def is_partially_annotated(self) -> bool:
        """True when the function is classified as PARTIALLY_ANNOTATED."""
        return self.annotation_kind is FunctionAnnotationKind.PARTIALLY_ANNOTATED

    @property
    def is_fully_annotated(self) -> bool:
        """True when the function is classified as FULLY_ANNOTATED."""
        return self.annotation_kind is FunctionAnnotationKind.FULLY_ANNOTATED
class AnnotationCollector(cst.CSTVisitor):
    """Visitor recording annotation data for functions, globals and attributes.

    Function definitions are recorded wherever they appear. Assignments are
    recorded only outside function bodies: as attributes when inside a class
    definition, as globals otherwise. The module's line count is stored when
    the module is left.
    """

    METADATA_DEPENDENCIES = (PositionProvider,)
    path: str = ""

    def __init__(self) -> None:
        self.globals: List[AnnotationInfo] = []
        self.attributes: List[AnnotationInfo] = []
        self.functions: List[FunctionAnnotationInfo] = []
        # Nesting counters maintained by the visit_*/leave_* hooks below.
        self.class_definition_depth = 0
        self.function_definition_depth = 0
        self.static_function_definition_depth = 0
        self.line_count = 0

    def returns(self) -> Iterable[AnnotationInfo]:
        """Yield the return-annotation record of every collected function."""
        yield from (info.returns for info in self.functions)

    def parameters(self) -> Iterable[AnnotationInfo]:
        """Yield every collected parameter record except implicit self/cls."""
        for info in self.functions:
            yield from info.non_self_cls_parameters()

    def in_class_definition(self) -> bool:
        """Return True while at least one class definition is open."""
        return self.class_definition_depth > 0

    def in_function_definition(self) -> bool:
        """Return True while at least one function definition is open."""
        return self.function_definition_depth > 0

    def in_static_function_definition(self) -> bool:
        """Return True while at least one `@staticmethod` function is open."""
        return self.static_function_definition_depth > 0

    def _is_method_or_classmethod(self) -> bool:
        # Any function seen while a class is open counts as a method unless a
        # staticmethod is currently open.
        return self.in_class_definition() and not self.in_static_function_definition()

    def _is_self_or_cls(self, index: int) -> bool:
        # Only the first parameter of a method/classmethod is implicit.
        return index == 0 and self._is_method_or_classmethod()

    def _code_range(self, node: cst.CSTNode) -> CodeRange:
        return self.get_metadata(PositionProvider, node)

    @staticmethod
    def _has_staticmethod_decorator(node: cst.FunctionDef) -> bool:
        # Only a bare-name `@staticmethod` decorator is recognised.
        return any(
            isinstance(decorator.decorator, cst.Name)
            and decorator.decorator.value == "staticmethod"
            for decorator in node.decorators
        )

    def _parameter_annotations(
        self, parameters: Sequence[cst.Param]
    ) -> Iterable[AnnotationInfo]:
        """Yield one record per parameter; implicit self/cls counts as annotated."""
        for position, parameter in enumerate(parameters):
            explicitly_annotated = parameter.annotation is not None
            yield AnnotationInfo(
                parameter,
                explicitly_annotated or self._is_self_or_cls(position),
                self._code_range(parameter),
            )

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        """Record annotation data for `node` and enter its scope."""
        # Depth counters are updated first because the method/classmethod
        # check below depends on them.
        if self._has_staticmethod_decorator(node):
            self.static_function_definition_depth += 1
        self.function_definition_depth += 1

        return_info = AnnotationInfo(
            node,
            is_annotated=node.returns is not None,
            code_range=self._code_range(node.name),
        )
        # Only `node.params.params` is inspected here.
        parameter_infos = list(self._parameter_annotations(node.params.params))
        annotated_parameter_count = sum(
            1 for info in parameter_infos if info.is_annotated
        )
        is_method = self._is_method_or_classmethod()
        kind = FunctionAnnotationKind.from_function_data(
            return_info.is_annotated,
            annotated_parameter_count,
            is_method,
            parameters=node.params.params,
        )
        self.functions.append(
            FunctionAnnotationInfo(
                node,
                kind,
                self._code_range(node),
                return_info,
                parameter_infos,
                is_method,
            )
        )

    def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
        """Undo the depth bookkeeping done when the function was entered."""
        self.function_definition_depth -= 1
        if self._has_staticmethod_decorator(original_node):
            self.static_function_definition_depth -= 1

    def visit_Assign(self, node: cst.Assign) -> None:
        """Record an un-annotated assignment found outside any function body."""
        if self.in_function_definition():
            return

        assigned = node.value
        is_literal = isinstance(assigned, (cst.BaseNumber, cst.BaseString))
        # An over-approximation of values that do not need an explicit
        # annotation. Erring on the side of reporting these as annotated to
        # avoid showing false positives to users.
        is_name_or_call = isinstance(assigned, (cst.Name, cst.Call))

        record = AnnotationInfo(
            node, is_literal or is_name_or_call, self._code_range(node)
        )
        destination = self.attributes if self.in_class_definition() else self.globals
        destination.append(record)

    def visit_AnnAssign(self, node: cst.AnnAssign) -> None:
        """Record an annotated assignment found outside any function body."""
        if self.in_function_definition():
            return

        record = AnnotationInfo(node, True, self._code_range(node))
        destination = self.attributes if self.in_class_definition() else self.globals
        destination.append(record)

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        self.class_definition_depth += 1

    def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
        self.class_definition_depth -= 1

    def leave_Module(self, original_node: cst.Module) -> None:
        """Store the module's line count from its position metadata."""
        last_line = self.get_metadata(PositionProvider, original_node).end.line
        # Seems to be a quirk in LibCST, the module CodeRange still goes 1 over
        # even when there is no trailing new line in the file.
        self.line_count = (
            last_line if original_node.has_trailing_newline else last_line - 1
        )
class StatisticsCollector(cst.CSTVisitor):
    """Base visitor whose collected statistics can be exported as a dict."""

    def build_json(self) -> Dict[str, int]:
        """Return the collected statistics; the base class has none to report."""
        return {}
class AnnotationCountCollector(StatisticsCollector, AnnotationCollector):
    """Reduces the records gathered by `AnnotationCollector` to plain counts."""

    @staticmethod
    def _count_annotated(records: Iterable[AnnotationInfo]) -> int:
        # Number of records flagged as annotated.
        return sum(1 for record in records if record.is_annotated)

    def return_count(self) -> int:
        """Number of return-annotation records (one per collected function)."""
        return sum(1 for _ in self.returns())

    def annotated_return_count(self) -> int:
        """Number of collected functions with an annotated return."""
        return self._count_annotated(self.returns())

    def globals_count(self) -> int:
        """Number of collected module-level assignments."""
        return len(self.globals)

    def annotated_globals_count(self) -> int:
        """Number of collected module-level assignments counted as annotated."""
        return self._count_annotated(self.globals)

    def parameters_count(self) -> int:
        """Number of collected parameters, excluding implicit self/cls."""
        return sum(1 for _ in self.parameters())

    def annotated_parameters_count(self) -> int:
        """Number of annotated parameters, excluding implicit self/cls."""
        return self._count_annotated(self.parameters())

    def attributes_count(self) -> int:
        """Number of collected class-level assignments."""
        return len(self.attributes)

    def annotated_attributes_count(self) -> int:
        """Number of collected class-level assignments counted as annotated."""
        return self._count_annotated(self.attributes)

    def function_count(self) -> int:
        """Number of collected function definitions."""
        return len(self.functions)

    def partially_annotated_functions_count(self) -> int:
        """Number of collected functions classified as partially annotated."""
        return sum(1 for info in self.functions if info.is_partially_annotated)

    def fully_annotated_functions_count(self) -> int:
        """Number of collected functions classified as fully annotated."""
        return sum(1 for info in self.functions if info.is_fully_annotated)

    def build_json(self) -> Dict[str, int]:
        """Return every count, plus the module line count, keyed by name."""
        partially_annotated = self.partially_annotated_functions_count()
        fully_annotated = self.fully_annotated_functions_count()
        return {
            "return_count": self.return_count(),
            "annotated_return_count": self.annotated_return_count(),
            "globals_count": self.globals_count(),
            "annotated_globals_count": self.annotated_globals_count(),
            "parameter_count": self.parameters_count(),
            "annotated_parameter_count": self.annotated_parameters_count(),
            "attribute_count": self.attributes_count(),
            "annotated_attribute_count": self.annotated_attributes_count(),
            "function_count": self.function_count(),
            "partially_annotated_function_count": partially_annotated,
            "fully_annotated_function_count": fully_annotated,
            "line_count": self.line_count,
        }
class CountCollector(StatisticsCollector):
    """Counts comments matching a regex, keyed by the codes in its first group."""

    def __init__(self, regex: str) -> None:
        """Compile `regex`; its group 1 is read as an optional bracketed code list."""
        self.counts: Dict[str, int] = defaultdict(int)
        self.regex: Pattern[str] = compile(regex)

    def visit_Comment(self, node: cst.Comment) -> None:
        """Increment the counter of each code listed in a matching comment."""
        found = self.regex.match(node.value)
        if found is None:
            return

        bracketed = found.group(1)
        # A match without the bracketed group is tallied under "No Code".
        # Note that an empty list such as "[]" still takes the first branch and
        # is tallied under the empty string.
        codes = bracketed.strip("[] ").split(",") if bracketed else ["No Code"]
        for code in codes:
            self.counts[code.strip()] += 1

    def build_json(self) -> Dict[str, int]:
        """Return the tallies as a plain dict."""
        return {code: total for code, total in self.counts.items()}
class FixmeCountCollector(CountCollector):
    """Counts `pyre-fixme` comments, keyed by the error codes they list."""

    def __init__(self) -> None:
        # The bracketed list is matched with one character class instead of a
        # nested quantifier such as `(\d* *,? *)*`. Both accept exactly the
        # strings made of digits, spaces and commas, so group 1 is unchanged,
        # but the nested form backtracks exponentially on an unterminated
        # bracket followed by a long run of digits.
        super().__init__(r".*# *pyre-fixme(\[[\d, ]*\])?")
class IgnoreCountCollector(CountCollector):
    """Counts `pyre-ignore` comments, keyed by the error codes they list."""

    def __init__(self) -> None:
        # The bracketed list is matched with one character class instead of a
        # nested quantifier such as `(\d* *,? *)*`. Both accept exactly the
        # strings made of digits, spaces and commas, so group 1 is unchanged,
        # but the nested form backtracks exponentially on an unterminated
        # bracket followed by a long run of digits.
        super().__init__(r".*# *pyre-ignore(\[[\d, ]*\])?")
class StrictCountCollector(StatisticsCollector):
    """Tallies visited modules as strict or unsafe based on their comments."""

    def __init__(self, strict_by_default: bool) -> None:
        """Set up counters and mode patterns.

        Args:
            strict_by_default: Whether a module without any mode comment is
                counted as strict.
        """
        # Flags for the module currently being visited.
        self.is_strict: bool = False
        self.is_unsafe: bool = False
        # Running totals across all visited modules.
        self.strict_count: int = 0
        self.unsafe_count: int = 0
        self.strict_by_default: bool = strict_by_default
        self.unsafe_regex: Pattern[str] = compile(r" ?#+ *pyre-unsafe")
        self.strict_regex: Pattern[str] = compile(r" ?#+ *pyre-strict")
        self.ignore_all_regex: Pattern[str] = compile(r" ?#+ *pyre-ignore-all-errors")
        self.ignore_all_by_code_regex: Pattern[str] = compile(
            r" ?#+ *pyre-ignore-all-errors\[[0-9]+[0-9, ]*\]"
        )

    def is_unsafe_module(self) -> bool:
        """Return True when the current module counts as unsafe.

        An explicit unsafe flag wins; otherwise the module is unsafe unless it
        is flagged strict or strict is the default.
        """
        if self.is_unsafe:
            return True
        return not (self.is_strict or self.strict_by_default)

    def is_strict_module(self) -> bool:
        """Return the negation of `is_unsafe_module`."""
        return not self.is_unsafe_module()

    def visit_Module(self, node: cst.Module) -> None:
        """Reset the per-module flags before visiting a new module."""
        self.is_strict = False
        self.is_unsafe = False

    def visit_Comment(self, node: cst.Comment) -> None:
        """Update the per-module flags from a single comment."""
        text = node.value
        if self.strict_regex.match(text):
            self.is_strict = True
        elif self.unsafe_regex.match(text):
            self.is_unsafe = True
        elif self.ignore_all_regex.match(
            text
        ) and not self.ignore_all_by_code_regex.match(text):
            # A blanket ignore-all marks the module unsafe; an ignore-all
            # restricted to specific codes does not.
            self.is_unsafe = True

    def leave_Module(self, original_node: cst.Module) -> None:
        """Add the module just visited to the matching total."""
        if not self.is_unsafe_module():
            self.strict_count += 1
        else:
            self.unsafe_count += 1

    def build_json(self) -> Dict[str, int]:
        """Return the unsafe and strict totals."""
        totals: Dict[str, int] = {}
        totals["unsafe_count"] = self.unsafe_count
        totals["strict_count"] = self.strict_count
        return totals
class CodeQualityIssue:
    """A reportable issue: category, message, file path and source span."""

    def __init__(
        self, code_range: CodeRange, path: str, category: str, message: str
    ) -> None:
        """Store the issue details, flattening `code_range` into line/column fields.

        Args:
            code_range: Span whose `start` and `end` line/column are recorded.
            path: Path stored on the issue.
            category: Category label stored on the issue.
            message: Text stored as the issue's detail message.
        """
        span_start, span_end = code_range.start, code_range.end
        self.category: str = category
        self.detail_message: str = message
        self.line: int = span_start.line
        self.line_end: int = span_end.line
        self.column: int = span_start.column
        self.column_end: int = span_end.column
        self.path: str = path

    def build_json(self) -> Dict[str, Any]:
        """Return the issue's fields as a dict, in a fixed key order."""
        field_names = (
            "category",
            "detail_message",
            "line",
            "line_end",
            "column",
            "column_end",
            "path",
        )
        return {name: getattr(self, name) for name in field_names}
class FunctionsCollector(cst.CSTVisitor):
    """Visitor that reports every function definition lacking a return annotation."""

    METADATA_DEPENDENCIES = (PositionProvider,)
    path: str = ""

    def __init__(self) -> None:
        # Let the LibCST base classes run their own initialisation.
        super().__init__()
        self.issues: List[CodeQualityIssue] = []

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        """Append an issue for `node` when it has no return annotation."""
        if node.returns is not None:
            return

        self.issues.append(
            CodeQualityIssue(
                self.get_metadata(PositionProvider, node),
                self.path,
                "PYRE_MISSING_ANNOTATIONS",
                # Adjacent literals are concatenated by the parser, so the
                # message text does not depend on how this line is indented
                # (a backslash continuation inside the literal would).
                "This function is missing a return annotation. "
                "Bodies of unannotated functions are not typechecked by Pyre.",
            )
        )
class StrictIssueCollector(StrictCountCollector):
    """Reports a `PYRE_STRICT` issue for every visited module that is unsafe."""

    METADATA_DEPENDENCIES = (PositionProvider,)
    path: str = ""
    issues: List[CodeQualityIssue]

    def __init__(self, strict_by_default: bool) -> None:
        """Initialise the strict/unsafe tracking and an empty issue list.

        Args:
            strict_by_default: Whether a module without any mode comment is
                treated as strict.
        """
        super().__init__(strict_by_default)
        # Created per instance: a list assigned at class level would be shared
        # by every collector, so issues from one file would leak into the next.
        self.issues = []

    def _create_issue(self, node: cst.Module) -> None:
        """Append an issue anchored at the very start of the module."""
        file_range = self.get_metadata(PositionProvider, node)
        # Zero-width range: the issue points at the beginning of the file.
        code_range = CodeRange(start=file_range.start, end=file_range.start)
        issue = CodeQualityIssue(
            code_range, self.path, "PYRE_STRICT", "Unsafe Pyre file."
        )
        self.issues.append(issue)

    def leave_Module(self, original_node: cst.Module) -> None:
        """Record an issue when the module just visited is unsafe."""
        # NOTE(review): this override does not update `unsafe_count` or
        # `strict_count`; confirm that is intended.
        if self.is_unsafe_module():
            self._create_issue(original_node)