# tools/upgrade/errors.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import json
import logging
import re
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast
import libcst
import libcst.matchers as libcst_matchers
from . import UserError, ast
LOG: logging.Logger = logging.getLogger(__name__)
MAX_LINES_PER_FIXME: int = 4
PyreError = Dict[str, Any]
PathsToErrors = Dict[Path, List[PyreError]]
LineRange = Tuple[int, int]
LineToErrors = Dict[int, List[Dict[str, str]]]
class LineBreakTransformer(libcst.CSTTransformer):
    """Rewrites statements whose physical lines are joined with backslash
    continuations so that the continuation is expressed with parentheses
    instead.

    Backslash-continued lines cannot carry a `# pyre-fixme` comment without
    breaking the statement, so the suppression machinery normalizes them with
    this transformer before inserting comments (see
    `_add_error_to_line_break_block`).
    """

    def leave_SimpleWhitespace(
        self,
        original_node: libcst.SimpleWhitespace,
        updated_node: libcst.SimpleWhitespace,
    ) -> Union[libcst.SimpleWhitespace, libcst.ParenthesizedWhitespace]:
        """Turn backslash-continued whitespace into parenthesized whitespace."""
        # Drop the backslash itself; a remaining newline marks a continuation.
        whitespace = original_node.value.replace("\\", "")
        if "\n" in whitespace:
            first_line = libcst.TrailingWhitespace(
                whitespace=libcst.SimpleWhitespace(
                    value=whitespace.split("\n")[0].rstrip()
                ),
                comment=None,
                newline=libcst.Newline(),
            )
            # NOTE(review): only index [1] of the split is used, so whitespace
            # spanning more than one line break would lose lines — presumably
            # a single whitespace node never spans more than one break.
            # TODO confirm.
            last_line = libcst.SimpleWhitespace(value=whitespace.split("\n")[1])
            return libcst.ParenthesizedWhitespace(
                first_line=first_line, empty_lines=[], indent=True, last_line=last_line
            )
        return updated_node

    @staticmethod
    def basic_parenthesize(
        node: libcst.CSTNode,
        whitespace: Optional[libcst.BaseParenthesizableWhitespace] = None,
    ) -> libcst.CSTNode:
        """Wrap `node` in one pair of parentheses, optionally placing
        `whitespace` right after the opening paren.

        Nodes without an `lpar` attribute cannot take parentheses and are
        returned unchanged.
        """
        if not hasattr(node, "lpar"):
            return node
        if whitespace:
            return node.with_changes(
                lpar=[libcst.LeftParen(whitespace_after=whitespace)],
                rpar=[libcst.RightParen()],
            )
        return node.with_changes(lpar=[libcst.LeftParen()], rpar=[libcst.RightParen()])

    def leave_Assert(
        self, original_node: libcst.Assert, updated_node: libcst.Assert
    ) -> libcst.Assert:
        """Parenthesize the test (and message) of a continued `assert`."""
        test = updated_node.test
        message = updated_node.msg
        comma = updated_node.comma
        if not test:
            return updated_node
        if message and isinstance(comma, libcst.Comma):
            # Move the line break that followed the comma into the message's
            # parentheses and collapse the comma spacing to one space.
            message = LineBreakTransformer.basic_parenthesize(
                message, comma.whitespace_after
            )
            comma = comma.with_changes(
                whitespace_after=libcst.SimpleWhitespace(
                    value=" ",
                )
            )
        assert_whitespace = updated_node.whitespace_after_assert
        if isinstance(assert_whitespace, libcst.ParenthesizedWhitespace):
            # The break was right after `assert`: move it into the test's
            # parentheses and normalize the keyword spacing.
            return updated_node.with_changes(
                test=LineBreakTransformer.basic_parenthesize(test, assert_whitespace),
                msg=message,
                comma=comma,
                whitespace_after_assert=libcst.SimpleWhitespace(value=" "),
            )
        return updated_node.with_changes(
            test=LineBreakTransformer.basic_parenthesize(test),
            msg=message,
            comma=comma,
        )

    def leave_Assign(
        self, original_node: libcst.Assign, updated_node: libcst.Assign
    ) -> libcst.Assign:
        """Parenthesize the right-hand side of a continued assignment."""
        assign_value = updated_node.value
        assign_whitespace = updated_node.targets[-1].whitespace_after_equal
        if libcst_matchers.matches(
            assign_whitespace, libcst_matchers.ParenthesizedWhitespace()
        ):
            # The break was right after `=`: move it inside the value's
            # parentheses and normalize the `=` spacing.
            adjusted_target = updated_node.targets[-1].with_changes(
                whitespace_after_equal=libcst.SimpleWhitespace(value=" ")
            )
            updated_targets = list(updated_node.targets[:-1])
            updated_targets.append(adjusted_target)
            return updated_node.with_changes(
                targets=tuple(updated_targets),
                value=LineBreakTransformer.basic_parenthesize(
                    assign_value, assign_whitespace
                ),
            )
        return updated_node.with_changes(
            value=LineBreakTransformer.basic_parenthesize(assign_value)
        )

    def leave_AnnAssign(
        self, original_node: libcst.AnnAssign, updated_node: libcst.AnnAssign
    ) -> libcst.AnnAssign:
        """Parenthesize the right-hand side of a continued annotated assignment."""
        assign_value = updated_node.value
        equal = updated_node.equal
        if not isinstance(equal, libcst.AssignEqual):
            # Annotation-only statement (no `= value`); nothing to rewrite.
            return updated_node
        assign_whitespace = equal.whitespace_after
        updated_value = (
            LineBreakTransformer.basic_parenthesize(assign_value, assign_whitespace)
            if assign_value
            else None
        )
        if libcst_matchers.matches(
            assign_whitespace, libcst_matchers.ParenthesizedWhitespace()
        ):
            updated_equal = equal.with_changes(
                whitespace_after=libcst.SimpleWhitespace(value=" ")
            )
            return updated_node.with_changes(
                equal=updated_equal,
                value=updated_value,
            )
        return updated_node.with_changes(value=updated_value)

    def leave_Del(
        self, original_node: libcst.Del, updated_node: libcst.Del
    ) -> libcst.Del:
        """Parenthesize the target of a continued `del`."""
        delete_target = updated_node.target
        delete_whitespace = updated_node.whitespace_after_del
        if isinstance(delete_whitespace, libcst.ParenthesizedWhitespace):
            return updated_node.with_changes(
                target=LineBreakTransformer.basic_parenthesize(
                    delete_target, delete_whitespace
                ),
                whitespace_after_del=libcst.SimpleWhitespace(value=" "),
            )
        return updated_node.with_changes(
            target=LineBreakTransformer.basic_parenthesize(delete_target)
        )

    def leave_Raise(
        self, original_node: libcst.Raise, updated_node: libcst.Raise
    ) -> libcst.Raise:
        """Parenthesize the exception of a continued `raise`."""
        exception = updated_node.exc
        if not exception:
            return updated_node
        raise_whitespace = updated_node.whitespace_after_raise
        if isinstance(raise_whitespace, libcst.ParenthesizedWhitespace):
            return updated_node.with_changes(
                exc=LineBreakTransformer.basic_parenthesize(
                    exception, raise_whitespace
                ),
                whitespace_after_raise=libcst.SimpleWhitespace(value=" "),
            )
        return updated_node.with_changes(
            exc=LineBreakTransformer.basic_parenthesize(exception)
        )

    def leave_Return(
        self, original_node: libcst.Return, updated_node: libcst.Return
    ) -> libcst.Return:
        """Parenthesize the value of a continued `return`."""
        return_value = updated_node.value
        if not return_value:
            return updated_node
        return_whitespace = updated_node.whitespace_after_return
        if isinstance(return_whitespace, libcst.ParenthesizedWhitespace):
            return updated_node.with_changes(
                value=LineBreakTransformer.basic_parenthesize(
                    return_value, return_whitespace
                ),
                whitespace_after_return=libcst.SimpleWhitespace(value=" "),
            )
        return updated_node.with_changes(
            value=LineBreakTransformer.basic_parenthesize(return_value)
        )
class PartialErrorSuppression(Exception):
    """Raised when error suppression could not be completed for every file.

    The files that still contain unsuppressed errors are available on
    ``unsuppressed_paths`` so callers can report them or retry with
    ``--unsafe``.
    """

    unsuppressed_paths: List[str]

    def __init__(self, message: str, unsuppressed_paths: List[str]) -> None:
        self.unsuppressed_paths = unsuppressed_paths
        super().__init__(message)
def error_path(error: Dict[str, Any]) -> str:
    """Sort/group key: the file path an error was reported against."""
    path: str = error["path"]
    return path
class Errors:
    """A collection of Pyre error dicts, with constructors for parsing them
    from JSON and a `suppress` entry point that writes `# pyre-fixme`
    comments into the affected files."""

    @classmethod
    def empty(cls) -> "Errors":
        """Alternate constructor: an empty collection."""
        return cls([])

    @staticmethod
    def from_json(
        json_string: str,
        only_fix_error_code: Optional[int] = None,
        from_stdin: bool = False,
    ) -> "Errors":
        """Parse a JSON list of errors, optionally keeping only one error code.

        :raises UserError: if `json_string` is not valid JSON; the message is
            tailored to whether the input arrived via stdin.
        """
        try:
            errors = json.loads(json_string)
            return Errors(_filter_errors(errors, only_fix_error_code))
        except json.decoder.JSONDecodeError:
            if from_stdin:
                raise UserError(
                    "Received invalid JSON as input. "
                    "If piping from `pyre check` be sure to use `--output=json`."
                )
            else:
                raise UserError(
                    "Encountered invalid output when checking for pyre errors: "
                    f"`{json_string}`."
                )

    @staticmethod
    def from_stdin(only_fix_error_code: Optional[int] = None) -> "Errors":
        """Read a JSON error list from standard input."""
        input = sys.stdin.read()
        return Errors.from_json(input, only_fix_error_code, from_stdin=True)

    def __init__(self, errors: List[Dict[str, Any]]) -> None:
        self.errors: List[Dict[str, Any]] = errors

    def __len__(self) -> int:
        return len(self.errors)

    def __eq__(self, other: object) -> bool:
        # Fixed: previously assumed `other` was an Errors instance and raised
        # AttributeError when compared against anything else. Returning
        # NotImplemented lets Python fall back to its default handling.
        if not isinstance(other, Errors):
            return NotImplemented
        return self.errors == other.errors

    @property
    def paths_to_errors(self) -> "Dict[str, List[PyreError]]":
        """Group the errors by the file path they were reported against."""
        return {
            path: list(errors)
            for path, errors in itertools.groupby(
                sorted(self.errors, key=error_path), key=error_path
            )
        }

    def suppress(
        self,
        comment: Optional[str] = None,
        max_line_length: Optional[int] = None,
        truncate: bool = False,
        unsafe: bool = False,
    ) -> None:
        """Write suppression comments for every error into its file.

        :param comment: custom text to use instead of each error description.
        :param max_line_length: wrap/truncate comments to this width when > 0.
        :param truncate: always truncate long comments to a single line.
        :param unsafe: suppress even in generated files and skip the AST
            stability check.
        :raises PartialErrorSuppression: if any file could not be fully
            suppressed; its `unsuppressed_paths` names the failing files.
        """
        unsuppressed_paths_and_exceptions = []
        for path_to_suppress, errors in self.paths_to_errors.items():
            LOG.info("Processing `%s`", path_to_suppress)
            try:
                path = Path(path_to_suppress)
                input = path.read_text()
                output = _suppress_errors(
                    input,
                    _build_error_map(errors),
                    comment,
                    max_line_length
                    if max_line_length and max_line_length > 0
                    else None,
                    truncate,
                    unsafe,
                )
                path.write_text(output)
            except SkippingGeneratedFileException:
                LOG.warning(f"Skipping generated file at {path_to_suppress}")
            except SkippingUnparseableFileException:
                LOG.warning(f"Skipping unparseable file at {path_to_suppress}")
            except LineBreakParsingException:
                LOG.warning(
                    f"Skipping file with unparseable line breaks at {path_to_suppress}"
                )
            except (ast.UnstableAST, SyntaxError) as exception:
                # Collect failures so one bad file does not abort the rest.
                unsuppressed_paths_and_exceptions.append((path_to_suppress, exception))
        if unsuppressed_paths_and_exceptions:
            exception_messages = "\n".join(
                f"{path} - {str(exception)}"
                for path, exception in unsuppressed_paths_and_exceptions
            )
            raise PartialErrorSuppression(
                "Could not fully suppress errors due to the following exceptions: "
                f"{exception_messages}\n Run with `--unsafe` to suppress anyway.",
                [path for path, _ in unsuppressed_paths_and_exceptions],
            )
def _filter_errors(
errors: List[Dict[str, Any]], only_fix_error_code: Optional[int] = None
) -> List[Dict[str, Any]]:
if only_fix_error_code is not None:
errors = [error for error in errors if error["code"] == only_fix_error_code]
return errors
def _remove_comment_preamble(lines: List[str]) -> None:
# Deprecated: leaving remove logic until live old-style comments are cleaned up.
while lines:
old_line = lines.pop()
new_line = re.sub(r"# pyre: .*$", "", old_line).rstrip()
if old_line == "" or new_line != "":
# The preamble has ended.
lines.append(new_line)
return
def _add_error_to_line_break_block(lines: List[str], errors: List[List[str]]) -> None:
    """Rewrite the trailing backslash-continued statement of `lines` to use
    parentheses, then interleave the collected suppression comments.

    `errors` holds one list of comment lines per physical line of the block
    (empty lists for lines without errors); `lines` is mutated in place.
    May raise libcst.ParserSyntaxError if the block cannot be re-parsed.
    """
    # Gather unbroken lines.
    line_break_block = [lines.pop() for _ in range(0, len(errors))]
    line_break_block.reverse()
    # Transform line break block to use parenthesis.
    indent = len(line_break_block[0]) - len(line_break_block[0].lstrip())
    line_break_block = [line[indent:] for line in line_break_block]
    statement = "\n".join(line_break_block)
    transformed_statement = libcst.Module([]).code_for_node(
        cast(
            libcst.CSTNode,
            libcst.parse_statement(statement).visit(LineBreakTransformer()),
        )
    )
    transformed_lines = transformed_statement.split("\n")
    transformed_lines = [" " * indent + line for line in transformed_lines]
    # Add to lines.
    # NOTE(review): zip truncates if the transformed statement has a different
    # number of physical lines than `errors`; presumably the transformer
    # preserves the line count — TODO confirm.
    for line, comment in zip(transformed_lines, errors):
        lines.extend(comment)
        lines.append(line)
def _split_across_lines(
comment: str, indent: int, max_line_length: Optional[int]
) -> List[str]:
if not max_line_length or len(comment) <= max_line_length:
return [comment]
comment = comment.lstrip()
available_columns = max_line_length - indent - len("# ")
buffered_line = ""
result = []
prefix = " " * indent
for token in comment.split():
if buffered_line and (
len(buffered_line) + len(token) + len(" ") > available_columns
):
# This new token would make the line exceed the limit,
# hence terminate what we have accumulated.
result.append(("{}{}".format(prefix, buffered_line)).rstrip())
# The first line already has a comment token on it, so don't prefix #. For
# the rest, we need to add the comment symbol manually.
prefix = "{}# ".format(" " * indent)
buffered_line = ""
buffered_line = buffered_line + token + " "
result.append(("{}{}".format(prefix, buffered_line)).rstrip())
return result
class SkippingGeneratedFileException(Exception):
    """Raised to skip suppression in a file carrying the generated marker."""
    pass
class SkippingUnparseableFileException(Exception):
    """Raised to skip suppression in a file with parse (404) errors."""
    pass
class LineBreakParsingException(Exception):
    """Raised when a backslash line-break block cannot be re-parsed by libcst."""
    pass
def _str_to_int(digits: str) -> Optional[int]:
try:
return int(digits)
except ValueError:
return None
def _get_unused_ignore_codes(errors: List[Dict[str, str]]) -> List[int]:
    """Collect, sorted, the integer codes named in unused-suppression
    (code "0") error descriptions; malformed codes are skipped."""
    codes: List[int] = []
    for error in (candidate for candidate in errors if candidate["code"] == "0"):
        match = re.search(
            r"The `pyre-ignore\[(.*?)\]` or `pyre-fixme\[.*?\]`", error["description"]
        )
        if not match:
            continue
        for raw_code in match.group(1).split(","):
            parsed = _str_to_int(raw_code.strip())
            if parsed is not None:
                codes.append(parsed)
    return sorted(codes)
def _remove_unused_ignores(line: str, errors: List[Dict[str, str]]) -> str:
    """Rewrite `line` to drop pyre-ignore/pyre-fixme codes reported as unused.

    If every listed code is unused (or the comment lists no codes at all),
    the whole trailing suppression comment is stripped; otherwise only the
    unused codes are removed from the bracketed list.

    :param line: the source line carrying the suppression comment.
    :param errors: errors on this line; code-"0" entries name the unused codes.
    :return: the updated line (rstripped when the comment is removed).
    """
    unused_ignore_codes = _get_unused_ignore_codes(errors)
    match = re.search(r"pyre-(ignore|fixme) *\[([0-9, ]+)\]", line)
    stripped_line = re.sub(r"# pyre-(ignore|fixme).*$", "", line).rstrip()
    if not match:
        # No explicit code list: remove the entire suppression comment.
        return stripped_line
    # One or more codes are specified in the ignore comment.
    # Remove only the codes that are erroring as unused.
    ignore_codes_string = match.group(2)
    ignore_codes = [
        int(code.strip()) for code in ignore_codes_string.split(",") if code != ""
    ]
    remaining_ignore_codes = set(ignore_codes) - set(unused_ignore_codes)
    if len(remaining_ignore_codes) == 0 or len(unused_ignore_codes) == 0:
        return stripped_line
    # Fixed: sort the surviving codes for a deterministic comment; iterating
    # the set directly yields an arbitrary order that could churn the file
    # between runs.
    return line.replace(
        ignore_codes_string,
        ", ".join(str(code) for code in sorted(remaining_ignore_codes)),
    )
def _line_ranges_spanned_by_format_strings(
    source: str,
) -> Dict[libcst.CSTNode, LineRange]:
    """Map each f-string node in `source` to the inclusive (start, end) line
    numbers it spans, so fixmes landing inside one can be moved above it.

    Returns an empty mapping (and logs a warning) if libcst cannot parse
    `source`.
    """

    def _code_range_to_line_range(
        code_range: libcst._position.CodeRange,
    ) -> LineRange:
        # Collapse a full CodeRange (line + column) to just the line span.
        return code_range.start.line, code_range.end.line

    try:
        wrapper = libcst.metadata.MetadataWrapper(libcst.parse_module(source))
    except libcst._exceptions.ParserSyntaxError as exception:
        # NOTE: This should not happen. If a file is unparseable for libcst,
        # it would probably have been unparseable for Pyre as well; Pyre would
        # then have emitted a 404 parse error and we would never have reached
        # this point. Still, catch the exception and just skip the special
        # handling of format strings.
        LOG.warning(
            "Not moving out fixmes from f-strings because"
            f" libcst failed to parse the file: {exception}"
        )
        return {}
    position_map = wrapper.resolve(libcst.metadata.PositionProvider)
    return {
        format_string: _code_range_to_line_range(position_map[format_string])
        for format_string in libcst_matchers.findall(
            wrapper.module, libcst_matchers.FormattedString()
        )
    }
def _map_line_to_start_of_range(line_ranges: List[LineRange]) -> Dict[int, int]:
    """Map every line covered by `line_ranges` (inclusive) to the first line
    of its range.

    When ranges overlap, the earliest range in the input wins: the list is
    walked in reverse, so earlier entries overwrite later ones.
    """
    return {
        line: start
        for start, end in reversed(line_ranges)
        for line in range(start, end + 1)
    }
class LineBreakBlock:
    """Tracks a run of physical lines joined by backslash continuations (plus
    any parenthesized expressions opened within it), collecting the
    suppression comments destined for each line until the run ends."""

    # One list of comment lines per physical line seen in the block.
    error_comments: List[List[str]]
    # Count of currently-open `(` groups inside the block.
    opened_expressions: int
    # True while we are inside an unfinished line-break block.
    is_active: bool

    def __init__(self) -> None:
        self.error_comments = []
        self.opened_expressions = 0
        self.is_active = False

    def ready_to_suppress(self) -> bool:
        """True once the block has ended and collected at least one entry."""
        return bool(self.error_comments) and not self.is_active

    def process_line(self, line: str, error_comments: List[str]) -> None:
        """Feed the next physical line along with its suppression comments."""
        code_only = line.split("#")[0].rstrip()
        if not self.is_active:
            # A trailing backslash opens a new block; anything else is a
            # plain line we ignore.
            if code_only.endswith("\\"):
                self.is_active = True
                self.error_comments.append(error_comments)
            return
        self.error_comments.append(error_comments)
        if code_only.endswith("\\"):
            return
        if code_only.endswith("("):
            self.opened_expressions += 1
            return
        if code_only.endswith(")"):
            self.opened_expressions -= 1
        # No continuation marker: the block ends unless a parenthesized
        # expression is still open.
        self.is_active = self.opened_expressions > 0
def _lines_after_suppressing_errors(
    lines: List[str],
    errors: Dict[int, List[Dict[str, str]]],
    custom_comment: Optional[str],
    max_line_length: Optional[int],
    truncate: bool,
) -> List[str]:
    """Return a new list of source lines with suppression comments inserted.

    `errors` maps 1-based line numbers to the errors reported on that line.
    Three awkward placements get special handling: comments that must follow
    an `@manual=` build annotation, backslash line-break blocks (rewritten via
    libcst), and errors on the closing line of a multi-line string (suppressed
    inline on that line). Unused suppressions (code "0") are removed.

    NOTE(review): the `errors[number + 1]` access below relies on `errors`
    being a defaultdict (as produced by _build_error_map / _relocate_errors);
    a plain dict could raise KeyError — confirm callers.
    """
    new_lines = []
    removing_pyre_comments = False
    line_break_block = LineBreakBlock()
    in_multi_line_string = False
    for index, line in enumerate(lines):
        if removing_pyre_comments:
            # Currently deleting an old-style comment preamble: skip comment
            # lines until we hit code or another pyre suppression.
            stripped = line.lstrip()
            if stripped.startswith("#") and not re.match(
                r"# *pyre-(ignore|fixme).*$", stripped
            ):
                continue
            else:
                removing_pyre_comments = False
        number = index + 1
        if line.startswith("#") and re.match(r"# *@manual=.*$", line):
            # Apply suppressions for lines following @manual to current line.
            errors[number] = errors[number + 1]
            del errors[number + 1]
        # Deduplicate errors
        error_mapping = {
            error["code"] + error["description"]: error
            for error in errors.get(number, [])
        }
        relevant_errors = list(error_mapping.values())
        if any(error["code"] == "0" for error in relevant_errors):
            # Code "0" marks an unused suppression: strip it from the line.
            replacement = _remove_unused_ignores(line, relevant_errors)
            if replacement == "":
                # The whole line was a suppression comment; also remove any
                # preamble comments already emitted above it.
                removing_pyre_comments = True
                _remove_comment_preamble(new_lines)
                continue
            else:
                line = replacement
        indent = len(line) - len(line.lstrip(" "))
        comments = [
            line
            for error in relevant_errors
            for line in _error_to_fixme_comment_lines(
                error, indent, truncate, max_line_length, custom_comment
            )
        ]
        # Handle suppressions in line break blocks.
        line_break_block.process_line(line, comments)
        if line_break_block.ready_to_suppress():
            new_lines.append(line)
            try:
                line_break_block_errors = line_break_block.error_comments
                if sum(len(errors) for errors in line_break_block_errors) > 0:
                    _add_error_to_line_break_block(new_lines, line_break_block_errors)
            except libcst.ParserSyntaxError as exception:
                raise LineBreakParsingException(exception)
            line_break_block = LineBreakBlock()
            continue
        # Handle suppressions around multi-line strings.
        contains_multi_line_string_token = line.count('"""') % 2 != 0
        if contains_multi_line_string_token:
            is_end_of_multi_line_string = in_multi_line_string
            in_multi_line_string = not in_multi_line_string
        else:
            is_end_of_multi_line_string = False
        if is_end_of_multi_line_string and len(comments) > 0:
            # Use a simple same-line suppression for errors on a multi-line string close
            error_codes = [
                error["code"] for error in relevant_errors if error["code"] != "0"
            ]
            line = line + " # pyre-fixme[{}]".format(", ".join(error_codes))
            new_lines.append(line)
            continue
        # Add suppression comments.
        if not line_break_block.is_active and len(comments) > 0:
            LOG.info(
                "Adding comment%s on line %d: %s",
                "s" if len(comments) > 1 else "",
                number,
                " \n".join(comments),
            )
            new_lines.extend(comments)
        new_lines.append(line)
    return new_lines
def _relocate_errors(
    errors: LineToErrors, target_line_map: Dict[int, int]
) -> LineToErrors:
    """Move each line's errors to the line given by `target_line_map`.

    Lines with no mapping, or mapped to themselves, keep their errors in
    place; every actual move is logged.
    """
    relocated = defaultdict(list)
    for line, errors in errors.items():
        target_line = target_line_map.get(line, line)
        if target_line != line:
            LOG.info(
                f"Relocating the following fixmes from line {line}"
                f" to line {target_line} because line {line} is within"
                f" a multi-line format string:\n{errors}"
            )
        relocated[target_line].extend(errors)
    return relocated
def _relocate_errors_inside_format_strings(
    errors: LineToErrors, source: str
) -> LineToErrors:
    """Move fixmes that land inside a multi-line f-string up to the line where
    the f-string opens, since a comment inside the literal would not suppress
    anything."""

    def _expression_to_string(expression: libcst.BaseExpression) -> str:
        # Render a lone expression back to source text for debug logging.
        return libcst.Module(
            [libcst.SimpleStatementLine([libcst.Expr(expression)])]
        ).code.strip()

    format_string_line_ranges = _line_ranges_spanned_by_format_strings(source)
    if not format_string_line_ranges:
        return errors
    debug_lines = ["Lines spanned by format strings:"]
    for format_string, line_range in format_string_line_ranges.items():
        # pyre-fixme[6]: Expected BaseExpression but got CSTNode.
        debug_lines.append(f"{_expression_to_string(format_string)}: {line_range}")
    LOG.debug("\n".join(debug_lines))
    return _relocate_errors(
        errors, _map_line_to_start_of_range(list(format_string_line_ranges.values()))
    )
def _suppress_errors(
    input: str,
    errors: LineToErrors,
    custom_comment: Optional[str] = None,
    max_line_length: Optional[int] = None,
    truncate: bool = False,
    unsafe: bool = False,
) -> str:
    """Return `input` with `# pyre-fixme` comments inserted for `errors`.

    :param errors: map of 1-based line numbers to the errors on that line.
    :raises SkippingGeneratedFileException: the file carries the generated
        marker and `unsafe` is not set.
    :raises SkippingUnparseableFileException: any error is a parse error
        (code 404) — suppressions cannot help an unparseable file.
    """
    # The marker is written as two adjacent string literals so that this
    # source file does not itself contain the generated-file marker.
    if not unsafe and "@" "generated" in input:
        raise SkippingGeneratedFileException()
    lines: List[str] = input.split("\n")
    # Do not suppress parse errors.
    if any(
        error["code"] == "404" for error_list in errors.values() for error in error_list
    ):
        raise SkippingUnparseableFileException()
    # Fixmes inside multi-line f-strings must be hoisted above the literal.
    errors = _relocate_errors_inside_format_strings(errors, input)
    new_lines = _lines_after_suppressing_errors(
        lines, errors, custom_comment, max_line_length, truncate
    )
    output = "\n".join(new_lines)
    # Unless running unsafely, verify the edit did not change the parse tree.
    if not unsafe:
        ast.check_stable(input, output)
    return output
def _error_to_fixme_comment_lines(
    error: Dict[str, Any],
    indent: int,
    truncate: bool,
    max_line_length: Optional[int],
    custom_comment: Optional[str],
) -> List[str]:
    """Render a `# pyre-fixme[code]: description` comment for `error`.

    The comment is word-wrapped to `max_line_length`; it is truncated to one
    line when `truncate` is set, when wrapping would exceed
    MAX_LINES_PER_FIXME lines, or when a wrapped line still overflows.
    Unused-suppression errors (code "0") produce no comment.
    """
    if error["code"] == "0":
        return []
    description = custom_comment or error["description"]
    comment = "{}# pyre-fixme[{}]: {}".format(" " * indent, error["code"], description)
    if not max_line_length:
        return [comment]
    wrapped = _split_across_lines(comment, indent, max_line_length)
    wrapping_failed = len(wrapped) > MAX_LINES_PER_FIXME or any(
        len(line) > max_line_length for line in wrapped
    )
    if truncate or wrapping_failed:
        return [comment[: (max_line_length - 3)] + "..."]
    return wrapped
def _build_error_map(
errors: Iterable[Dict[str, Any]]
) -> Dict[int, List[Dict[str, str]]]:
error_map = defaultdict(lambda: [])
for error in errors:
if error["concise_description"]:
description = error["concise_description"]
else:
description = error["description"]
match = re.search(r"\[(\d+)\]: (.*)", description)
if match:
error_map[error["line"]].append(
{"code": match.group(1), "description": match.group(2)}
)
return error_map