tools/upgrade/errors.py (588 lines of code) (raw):

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import json
import logging
import re
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast

import libcst
import libcst.matchers as libcst_matchers

from . import UserError, ast


LOG: logging.Logger = logging.getLogger(__name__)
# A fixme comment that would need more lines than this is truncated instead.
MAX_LINES_PER_FIXME: int = 4

PyreError = Dict[str, Any]
PathsToErrors = Dict[Path, List[PyreError]]
# Inclusive (first line, last line) pair.
LineRange = Tuple[int, int]
# Line number -> list of {"code": ..., "description": ...} dicts.
LineToErrors = Dict[int, List[Dict[str, str]]]


class LineBreakTransformer(libcst.CSTTransformer):
    """Rewrite backslash line continuations as parenthesized expressions.

    Whitespace that contains a backslash-newline is turned into
    `ParenthesizedWhitespace`, and the expression that follows it in an
    assert/assign/del/raise/return statement is wrapped in parentheses so
    that the statement can be broken across lines without backslashes.
    """

    def leave_SimpleWhitespace(
        self,
        original_node: libcst.SimpleWhitespace,
        updated_node: libcst.SimpleWhitespace,
    ) -> Union[libcst.SimpleWhitespace, libcst.ParenthesizedWhitespace]:
        """Convert whitespace spanning a line break to parenthesized whitespace.

        Backslashes are dropped; if a newline remains, the text before the
        first newline becomes the trailing whitespace of the first line and the
        text between the first and second newline becomes the last line.
        """
        whitespace = original_node.value.replace("\\", "")
        if "\n" in whitespace:
            # Split once and reuse both pieces.
            segments = whitespace.split("\n")
            first_line = libcst.TrailingWhitespace(
                whitespace=libcst.SimpleWhitespace(value=segments[0].rstrip()),
                comment=None,
                newline=libcst.Newline(),
            )
            last_line = libcst.SimpleWhitespace(value=segments[1])
            return libcst.ParenthesizedWhitespace(
                first_line=first_line, empty_lines=[], indent=True, last_line=last_line
            )
        return updated_node

    @staticmethod
    def basic_parenthesize(
        node: libcst.CSTNode,
        whitespace: Optional[libcst.BaseParenthesizableWhitespace] = None,
    ) -> libcst.CSTNode:
        """Wrap `node` in a single pair of parentheses.

        Nodes without an `lpar` attribute are returned unchanged. When
        `whitespace` is given it is placed right after the opening parenthesis.
        Any parentheses already on the node are replaced.
        """
        if not hasattr(node, "lpar"):
            return node
        if whitespace:
            return node.with_changes(
                lpar=[libcst.LeftParen(whitespace_after=whitespace)],
                rpar=[libcst.RightParen()],
            )
        return node.with_changes(lpar=[libcst.LeftParen()], rpar=[libcst.RightParen()])

    def leave_Assert(
        self, original_node: libcst.Assert, updated_node: libcst.Assert
    ) -> libcst.Assert:
        """Parenthesize the test (and message, if any) of an assert statement."""
        test = updated_node.test
        message = updated_node.msg
        comma = updated_node.comma
        if not test:
            return updated_node
        if message and isinstance(comma, libcst.Comma):
            # Move the whitespace after the comma inside the message's
            # parentheses and leave a single space after the comma.
            message = LineBreakTransformer.basic_parenthesize(
                message, comma.whitespace_after
            )
            comma = comma.with_changes(
                whitespace_after=libcst.SimpleWhitespace(
                    value=" ",
                )
            )
        assert_whitespace = updated_node.whitespace_after_assert
        if isinstance(assert_whitespace, libcst.ParenthesizedWhitespace):
            # The line break after `assert` moves inside the parentheses.
            return updated_node.with_changes(
                test=LineBreakTransformer.basic_parenthesize(test, assert_whitespace),
                msg=message,
                comma=comma,
                whitespace_after_assert=libcst.SimpleWhitespace(value=" "),
            )
        return updated_node.with_changes(
            test=LineBreakTransformer.basic_parenthesize(test),
            msg=message,
            comma=comma,
        )

    def leave_Assign(
        self, original_node: libcst.Assign, updated_node: libcst.Assign
    ) -> libcst.Assign:
        """Parenthesize the right-hand side of an assignment."""
        assign_value = updated_node.value
        # Only the whitespace after the last `=` precedes the value.
        assign_whitespace = updated_node.targets[-1].whitespace_after_equal
        if libcst_matchers.matches(
            assign_whitespace, libcst_matchers.ParenthesizedWhitespace()
        ):
            adjusted_target = updated_node.targets[-1].with_changes(
                whitespace_after_equal=libcst.SimpleWhitespace(value=" ")
            )
            updated_targets = list(updated_node.targets[:-1])
            updated_targets.append(adjusted_target)
            return updated_node.with_changes(
                targets=tuple(updated_targets),
                value=LineBreakTransformer.basic_parenthesize(
                    assign_value, assign_whitespace
                ),
            )
        return updated_node.with_changes(
            value=LineBreakTransformer.basic_parenthesize(assign_value)
        )

    def leave_AnnAssign(
        self, original_node: libcst.AnnAssign, updated_node: libcst.AnnAssign
    ) -> libcst.AnnAssign:
        """Parenthesize the value of an annotated assignment, if it has one."""
        assign_value = updated_node.value
        equal = updated_node.equal
        if not isinstance(equal, libcst.AssignEqual):
            return updated_node
        assign_whitespace = equal.whitespace_after
        updated_value = (
            LineBreakTransformer.basic_parenthesize(assign_value, assign_whitespace)
            if assign_value
            else None
        )
        if libcst_matchers.matches(
            assign_whitespace, libcst_matchers.ParenthesizedWhitespace()
        ):
            updated_equal = equal.with_changes(
                whitespace_after=libcst.SimpleWhitespace(value=" ")
            )
            return updated_node.with_changes(
                equal=updated_equal,
                value=updated_value,
            )
        return updated_node.with_changes(value=updated_value)

    def leave_Del(
        self, original_node: libcst.Del, updated_node: libcst.Del
    ) -> libcst.Del:
        """Parenthesize the target of a del statement."""
        delete_target = updated_node.target
        delete_whitespace = updated_node.whitespace_after_del
        if isinstance(delete_whitespace, libcst.ParenthesizedWhitespace):
            return updated_node.with_changes(
                target=LineBreakTransformer.basic_parenthesize(
                    delete_target, delete_whitespace
                ),
                whitespace_after_del=libcst.SimpleWhitespace(value=" "),
            )
        return updated_node.with_changes(
            target=LineBreakTransformer.basic_parenthesize(delete_target)
        )

    def leave_Raise(
        self, original_node: libcst.Raise, updated_node: libcst.Raise
    ) -> libcst.Raise:
        """Parenthesize the exception of a raise statement, if it has one."""
        exception = updated_node.exc
        if not exception:
            return updated_node
        raise_whitespace = updated_node.whitespace_after_raise
        if isinstance(raise_whitespace, libcst.ParenthesizedWhitespace):
            return updated_node.with_changes(
                exc=LineBreakTransformer.basic_parenthesize(
                    exception, raise_whitespace
                ),
                whitespace_after_raise=libcst.SimpleWhitespace(value=" "),
            )
        return updated_node.with_changes(
            exc=LineBreakTransformer.basic_parenthesize(exception)
        )

    def leave_Return(
        self, original_node: libcst.Return, updated_node: libcst.Return
    ) -> libcst.Return:
        """Parenthesize the value of a return statement, if it has one."""
        return_value = updated_node.value
        if not return_value:
            return updated_node
        return_whitespace = updated_node.whitespace_after_return
        if isinstance(return_whitespace, libcst.ParenthesizedWhitespace):
            return updated_node.with_changes(
                value=LineBreakTransformer.basic_parenthesize(
                    return_value, return_whitespace
                ),
                whitespace_after_return=libcst.SimpleWhitespace(value=" "),
            )
        return updated_node.with_changes(
            value=LineBreakTransformer.basic_parenthesize(return_value)
        )


class PartialErrorSuppression(Exception):
    """Raised by `Errors.suppress` when some files could not be suppressed."""

    def __init__(self, message: str, unsuppressed_paths: List[str]) -> None:
        super().__init__(message)
        # Paths of the files whose errors were left unsuppressed.
        self.unsuppressed_paths: List[str] = unsuppressed_paths


def error_path(error: Dict[str, Any]) -> str:
    """Return the `path` entry of an error; used as a sort/group key."""
    return error["path"]


class Errors:
    """A list of pyre error dicts plus the logic to suppress them in place."""

    @classmethod
    def empty(cls) -> "Errors":
        """Return an instance holding no errors."""
        return cls([])

    @staticmethod
    def from_json(
        json_string: str,
        only_fix_error_code: Optional[int] = None,
        from_stdin: bool = False,
    ) -> "Errors":
        """Build an instance from a JSON string of errors.

        Parameters:
            json_string: JSON text to decode.
            only_fix_error_code: when not None, keep only errors whose `code`
                equals this value.
            from_stdin: selects which message is used when decoding fails.

        Raises:
            UserError: if `json_string` is not valid JSON.
        """
        # Keep the `try` limited to the call that can raise the decode error.
        try:
            errors = json.loads(json_string)
        except json.decoder.JSONDecodeError as decode_error:
            if from_stdin:
                raise UserError(
                    "Received invalid JSON as input. "
                    "If piping from `pyre check` be sure to use `--output=json`."
                ) from decode_error
            raise UserError(
                "Encountered invalid output when checking for pyre errors: "
                f"`{json_string}`."
            ) from decode_error
        return Errors(_filter_errors(errors, only_fix_error_code))

    @staticmethod
    def from_stdin(only_fix_error_code: Optional[int] = None) -> "Errors":
        """Build an instance from JSON read from standard input."""
        json_string = sys.stdin.read()
        return Errors.from_json(json_string, only_fix_error_code, from_stdin=True)

    def __init__(self, errors: List[Dict[str, Any]]) -> None:
        self.errors: List[Dict[str, Any]] = errors

    def __len__(self) -> int:
        return len(self.errors)

    def __eq__(self, other: object) -> bool:
        # Comparing with an unrelated type must not raise AttributeError;
        # returning NotImplemented lets Python fall back to its default.
        if not isinstance(other, Errors):
            return NotImplemented
        return self.errors == other.errors

    @property
    def paths_to_errors(self) -> Dict[str, List[PyreError]]:
        """Group the errors by their `path` entry."""
        # groupby only merges adjacent items, hence the sort by the same key.
        return {
            path: list(errors)
            for path, errors in itertools.groupby(
                sorted(self.errors, key=error_path), key=error_path
            )
        }

    def suppress(
        self,
        comment: Optional[str] = None,
        max_line_length: Optional[int] = None,
        truncate: bool = False,
        unsafe: bool = False,
    ) -> None:
        """Rewrite each affected file with suppression comments added.

        Parameters:
            comment: replaces the error description in the generated comments.
            max_line_length: comment length limit; values that are None or not
                positive disable the limit.
            truncate: truncate over-long comments instead of splitting them.
            unsafe: forwarded to `_suppress_errors`, where it disables the
                generated-file check and the AST stability check.

        Raises:
            PartialErrorSuppression: if `ast.UnstableAST` or `SyntaxError` was
                raised for at least one file. Remaining files are still
                processed before this is raised.
        """
        unsuppressed_paths_and_exceptions = []

        for path_to_suppress, errors in self.paths_to_errors.items():
            LOG.info("Processing `%s`", path_to_suppress)
            try:
                path = Path(path_to_suppress)
                source = path.read_text()
                output = _suppress_errors(
                    source,
                    _build_error_map(errors),
                    comment,
                    max_line_length
                    if max_line_length and max_line_length > 0
                    else None,
                    truncate,
                    unsafe,
                )
                path.write_text(output)
            except SkippingGeneratedFileException:
                LOG.warning(f"Skipping generated file at {path_to_suppress}")
            except SkippingUnparseableFileException:
                LOG.warning(f"Skipping unparseable file at {path_to_suppress}")
            except LineBreakParsingException:
                LOG.warning(
                    f"Skipping file with unparseable line breaks at {path_to_suppress}"
                )
            except (ast.UnstableAST, SyntaxError) as exception:
                # Collected rather than raised so the other files still run.
                unsuppressed_paths_and_exceptions.append((path_to_suppress, exception))

        if unsuppressed_paths_and_exceptions:
            exception_messages = "\n".join(
                f"{path} - {str(exception)}"
                for path, exception in unsuppressed_paths_and_exceptions
            )
            raise PartialErrorSuppression(
                "Could not fully suppress errors due to the following exceptions: "
                f"{exception_messages}\n Run with `--unsafe` to suppress anyway.",
                [path for path, _ in unsuppressed_paths_and_exceptions],
            )


def _filter_errors(
    errors: List[Dict[str, Any]], only_fix_error_code: Optional[int] = None
) -> List[Dict[str, Any]]:
    """Keep only errors with the given code; keep everything when it is None."""
    if only_fix_error_code is not None:
        errors = [error for error in errors if error["code"] == only_fix_error_code]
    return errors


def _remove_comment_preamble(lines: List[str]) -> None:
    """Strip trailing old-style `# pyre: ...` comment lines from `lines` in place.

    Lines are popped from the end until an empty line, or a line that still
    has content after removing the `# pyre: ...` part, is found; that line is
    pushed back with the comment part removed.
    """
    # Deprecated: leaving remove logic until live old-style comments are cleaned up.
    while lines:
        old_line = lines.pop()
        new_line = re.sub(r"# pyre: .*$", "", old_line).rstrip()
        if old_line == "" or new_line != "":
            # The preamble has ended.
            lines.append(new_line)
            return


def _add_error_to_line_break_block(lines: List[str], errors: List[List[str]]) -> None:
    """Re-emit the last `len(errors)` lines with backslashes removed and fixmes added.

    The trailing lines of `lines` are popped, parsed as one statement, run
    through `LineBreakTransformer`, and appended back with each line preceded
    by its entry of `errors`. `lines` is modified in place.
    """
    # Gather unbroken lines.
    line_break_block = [lines.pop() for _ in range(0, len(errors))]
    line_break_block.reverse()

    # Transform line break block to use parenthesis.
    # The first line's indentation is removed from every line so the block
    # parses as a statement, and is added back after the transformation.
    indent = len(line_break_block[0]) - len(line_break_block[0].lstrip())
    line_break_block = [line[indent:] for line in line_break_block]
    statement = "\n".join(line_break_block)
    transformed_statement = libcst.Module([]).code_for_node(
        cast(
            libcst.CSTNode,
            libcst.parse_statement(statement).visit(LineBreakTransformer()),
        )
    )
    transformed_lines = transformed_statement.split("\n")
    transformed_lines = [" " * indent + line for line in transformed_lines]

    # Add to lines.
    # zip stops at the shorter input, so transformed lines beyond
    # len(errors) are not emitted.
    for line, comment in zip(transformed_lines, errors):
        lines.extend(comment)
        lines.append(line)


def _split_across_lines(
    comment: str, indent: int, max_line_length: Optional[int]
) -> List[str]:
    """Word-wrap `comment` so that lines aim to fit within `max_line_length`.

    Returns `[comment]` unchanged when there is no limit or it already fits.
    Otherwise the first returned line is prefixed with `indent` spaces only
    (the comment text carries its own `#`), and later lines with `indent`
    spaces plus `# `. A single token longer than the available width is kept
    whole, so a returned line may still exceed the limit.
    """
    if not max_line_length or len(comment) <= max_line_length:
        return [comment]

    comment = comment.lstrip()
    available_columns = max_line_length - indent - len("# ")

    buffered_line = ""
    result = []
    prefix = " " * indent
    for token in comment.split():
        if buffered_line and (
            len(buffered_line) + len(token) + len(" ") > available_columns
        ):
            # This new token would make the line exceed the limit,
            # hence terminate what we have accumulated.
            result.append(("{}{}".format(prefix, buffered_line)).rstrip())
            # The first line already has a comment token on it, so don't prefix #. For
            # the rest, we need to add the comment symbol manually.
            prefix = "{}# ".format(" " * indent)
            buffered_line = ""

        buffered_line = buffered_line + token + " "

    result.append(("{}{}".format(prefix, buffered_line)).rstrip())
    return result


class SkippingGeneratedFileException(Exception):
    """Raised by `_suppress_errors` for files carrying the generated marker."""

    pass


class SkippingUnparseableFileException(Exception):
    """Raised by `_suppress_errors` when a file has an error with code 404."""

    pass


class LineBreakParsingException(Exception):
    """Raised when libcst cannot parse a backslash-continued statement."""

    pass


def _str_to_int(digits: str) -> Optional[int]:
    """Return `int(digits)`, or None when `digits` is not a valid integer."""
    try:
        return int(digits)
    except ValueError:
        return None


def _get_unused_ignore_codes(errors: List[Dict[str, str]]) -> List[int]:
    """Collect, sorted, the codes named in the descriptions of code "0" errors.

    Only descriptions matching the "The `pyre-ignore[...]` or
    `pyre-fixme[...]`" pattern contribute; non-numeric entries are skipped.
    """
    unused_ignore_codes: List[int] = []
    ignore_errors = [error for error in errors if error["code"] == "0"]
    for error in ignore_errors:
        match = re.search(
            r"The `pyre-ignore\[(.*?)\]` or `pyre-fixme\[.*?\]`", error["description"]
        )
        if match:
            unused_ignore_codes.extend(
                int_code
                for int_code in (
                    _str_to_int(code.strip()) for code in match.group(1).split(",")
                )
                if int_code is not None
            )
    unused_ignore_codes.sort()
    return unused_ignore_codes


def _remove_unused_ignores(line: str, errors: List[Dict[str, str]]) -> str:
    """Drop unused codes from the pyre-ignore/pyre-fixme comment on `line`.

    Returns `line` with the whole suppression comment removed when the
    comment lists no codes, when every listed code is unused, or when no
    unused codes could be extracted from `errors`. Otherwise returns `line`
    with only the still-needed codes left inside the brackets, in their
    original order.
    """
    unused_ignore_codes = _get_unused_ignore_codes(errors)
    match = re.search(r"pyre-(ignore|fixme) *\[([0-9, ]+)\]", line)
    stripped_line = re.sub(r"# pyre-(ignore|fixme).*$", "", line).rstrip()
    if not match:
        return stripped_line

    # One or more codes are specified in the ignore comment.
    # Remove only the codes that are erroring as unused.
    ignore_codes_string = match.group(2)
    # Filter on the stripped segment: the pattern admits blank segments such
    # as the one in "[6, ]", and int() of a blank string would raise.
    ignore_codes = [
        int(code) for code in ignore_codes_string.split(",") if code.strip()
    ]
    unused = set(unused_ignore_codes)
    # dict.fromkeys de-duplicates while keeping the order the author wrote.
    remaining_ignore_codes = [
        code for code in dict.fromkeys(ignore_codes) if code not in unused
    ]
    if len(remaining_ignore_codes) == 0 or len(unused_ignore_codes) == 0:
        return stripped_line

    # Splice at the matched span so identical text elsewhere on the line
    # (e.g. in the code before the comment) is left untouched.
    start, end = match.span(2)
    return (
        line[:start]
        + ", ".join(str(code) for code in remaining_ignore_codes)
        + line[end:]
    )


def _line_ranges_spanned_by_format_strings(
    source: str,
) -> Dict[libcst.CSTNode, LineRange]:
    """Map every f-string node in `source` to its (start line, end line).

    Returns an empty dict when libcst fails to parse `source`.
    """

    def _code_range_to_line_range(
        code_range: libcst._position.CodeRange,
    ) -> LineRange:
        return code_range.start.line, code_range.end.line

    try:
        wrapper = libcst.metadata.MetadataWrapper(libcst.parse_module(source))
    except libcst._exceptions.ParserSyntaxError as exception:
        # NOTE: This should not happen. If a file is unparseable for libcst, it
        # would probably have been unparseable for Pyre as well. In that case,
        # a 404 parse error would have been reported and `_suppress_errors`
        # would have bailed out before reaching here. Still, catch the
        # exception and just skip the special handling of format strings.
        LOG.warning(
            "Not moving out fixmes from f-strings because"
            f" libcst failed to parse the file: {exception}"
        )
        return {}

    position_map = wrapper.resolve(libcst.metadata.PositionProvider)
    return {
        format_string: _code_range_to_line_range(position_map[format_string])
        for format_string in libcst_matchers.findall(
            wrapper.module, libcst_matchers.FormattedString()
        )
    }


def _map_line_to_start_of_range(line_ranges: List[LineRange]) -> Dict[int, int]:
    """Map each line covered by a range to that range's first line.

    Ranges are processed in reverse, so when ranges overlap the one that
    appears earlier in `line_ranges` wins.
    """
    target_line_map = {}
    for start, end in reversed(line_ranges):
        for line in range(start, end + 1):
            target_line_map[line] = start
    return target_line_map


class LineBreakBlock:
    """Tracks a run of lines belonging to one backslash-continued statement."""

    # One list of fixme comment lines per source line seen in the block.
    error_comments: List[List[str]]
    # Count of lines ending in "(" minus lines ending in ")" seen so far.
    opened_expressions: int
    # True while the block is still collecting lines.
    is_active: bool

    def __init__(self) -> None:
        self.error_comments = []
        self.opened_expressions = 0
        self.is_active = False

    def ready_to_suppress(self) -> bool:
        # Line break block has been filled and then ended; errors can be applied.
        return not self.is_active and len(self.error_comments) > 0

    def process_line(self, line: str, error_comments: List[str]) -> None:
        """Feed the next source line and its fixme comments into the block."""
        # Everything from the first "#" on is treated as a comment.
        comment_free_line = line.split("#")[0].rstrip()
        if not self.is_active:
            # Check if line break block is beginning.
            self.is_active = comment_free_line.endswith("\\")
            if self.is_active:
                self.error_comments.append(error_comments)
            return

        # Check if line break block is ending.
        self.error_comments.append(error_comments)
        if comment_free_line.endswith("\\"):
            return
        if comment_free_line.endswith("("):
            self.opened_expressions += 1
            return
        if comment_free_line.endswith(")"):
            self.opened_expressions -= 1
        self.is_active = self.opened_expressions > 0


def _lines_after_suppressing_errors(
    lines: List[str],
    errors: Dict[int, List[Dict[str, str]]],
    custom_comment: Optional[str],
    max_line_length: Optional[int],
    truncate: bool,
) -> List[str]:
    """Return `lines` with suppression comments added and unused ones removed.

    Parameters:
        lines: source lines without line terminators.
        errors: 1-based line number -> errors reported on that line. The
            mapping is not modified.
        custom_comment: replaces the error description in generated comments.
        max_line_length: length limit for generated comments, or None.
        truncate: truncate over-long comments instead of splitting them.

    Raises:
        LineBreakParsingException: if a backslash-continued statement cannot
            be parsed by libcst.
    """
    # Work on a shallow copy: the `@manual` handling below moves entries
    # between line numbers and must not alter the caller's mapping.
    errors = dict(errors)

    new_lines = []
    removing_pyre_comments = False
    line_break_block = LineBreakBlock()
    in_multi_line_string = False
    for index, line in enumerate(lines):
        if removing_pyre_comments:
            # A suppression comment line was just dropped; also drop the
            # plain comment lines that follow it.
            stripped = line.lstrip()
            if stripped.startswith("#") and not re.match(
                r"# *pyre-(ignore|fixme).*$", stripped
            ):
                continue
            else:
                removing_pyre_comments = False
        number = index + 1

        if line.startswith("#") and re.match(r"# *@manual=.*$", line):
            # Apply suppressions for lines following @manual to current line.
            # Move the entry only when the next line actually has errors, so
            # a missing entry neither raises nor discards this line's errors.
            if number + 1 in errors:
                errors[number] = errors.pop(number + 1)

        # Deduplicate errors
        error_mapping = {
            error["code"] + error["description"]: error
            for error in errors.get(number, [])
        }
        relevant_errors = list(error_mapping.values())

        if any(error["code"] == "0" for error in relevant_errors):
            replacement = _remove_unused_ignores(line, relevant_errors)
            if replacement == "":
                # Nothing but the suppression comment was on this line.
                removing_pyre_comments = True
                _remove_comment_preamble(new_lines)
                continue
            else:
                line = replacement

        indent = len(line) - len(line.lstrip(" "))
        comments = [
            comment_line
            for error in relevant_errors
            for comment_line in _error_to_fixme_comment_lines(
                error, indent, truncate, max_line_length, custom_comment
            )
        ]

        # Handle suppressions in line break blocks.
        line_break_block.process_line(line, comments)
        if line_break_block.ready_to_suppress():
            new_lines.append(line)
            try:
                line_break_block_errors = line_break_block.error_comments
                # Rewrite the block only if some line in it has a comment.
                if any(
                    len(block_comments) > 0
                    for block_comments in line_break_block_errors
                ):
                    _add_error_to_line_break_block(new_lines, line_break_block_errors)
            except libcst.ParserSyntaxError as exception:
                raise LineBreakParsingException(exception) from exception
            line_break_block = LineBreakBlock()
            continue

        # Handle suppressions around multi-line strings.
        # An odd number of triple-quote tokens toggles the in-string state.
        contains_multi_line_string_token = line.count('"""') % 2 != 0
        if contains_multi_line_string_token:
            is_end_of_multi_line_string = in_multi_line_string
            in_multi_line_string = not in_multi_line_string
        else:
            is_end_of_multi_line_string = False
        if is_end_of_multi_line_string and len(comments) > 0:
            # Use a simple same-line suppression for errors on a multi-line string close
            error_codes = [
                error["code"] for error in relevant_errors if error["code"] != "0"
            ]
            line = line + " # pyre-fixme[{}]".format(", ".join(error_codes))
            new_lines.append(line)
            continue

        # Add suppression comments.
        # While a line break block is active its comments are held in the
        # block and emitted when the block ends.
        if not line_break_block.is_active and len(comments) > 0:
            LOG.info(
                "Adding comment%s on line %d: %s",
                "s" if len(comments) > 1 else "",
                number,
                " \n".join(comments),
            )
            new_lines.extend(comments)
        new_lines.append(line)
    return new_lines


def _relocate_errors(
    errors: LineToErrors, target_line_map: Dict[int, int]
) -> LineToErrors:
    """Move each line's errors to the line given by `target_line_map`.

    Lines absent from the map keep their errors. Errors from several lines
    that map to the same target are concatenated.
    """
    relocated = defaultdict(list)
    for line, line_errors in errors.items():
        target_line = target_line_map.get(line)
        if target_line is None or target_line == line:
            target_line = line
        else:
            LOG.info(
                f"Relocating the following fixmes from line {line}"
                f" to line {target_line} because line {line} is within"
                f" a multi-line format string:\n{line_errors}"
            )
        relocated[target_line].extend(line_errors)
    return relocated


def _relocate_errors_inside_format_strings(
    errors: LineToErrors, source: str
) -> LineToErrors:
    """Move errors on lines inside an f-string to the f-string's first line.

    Returns `errors` itself when `source` contains no f-strings (or could not
    be parsed by libcst).
    """

    def _expression_to_string(expression: libcst.BaseExpression) -> str:
        return libcst.Module(
            [libcst.SimpleStatementLine([libcst.Expr(expression)])]
        ).code.strip()

    format_string_line_ranges = _line_ranges_spanned_by_format_strings(source)
    if len(format_string_line_ranges) == 0:
        return errors

    log_lines = ["Lines spanned by format strings:"]
    for format_string, line_range in format_string_line_ranges.items():
        # pyre-fixme[6]: Expected BaseExpression but got CSTNode.
        log_lines.append(f"{_expression_to_string(format_string)}: {line_range}")
    LOG.debug("\n".join(log_lines))
    return _relocate_errors(
        errors, _map_line_to_start_of_range(list(format_string_line_ranges.values()))
    )


def _suppress_errors(
    input: str,
    errors: LineToErrors,
    custom_comment: Optional[str] = None,
    max_line_length: Optional[int] = None,
    truncate: bool = False,
    unsafe: bool = False,
) -> str:
    """Return the text of `input` with the given errors suppressed.

    Parameters:
        input: full source text of one file.
        errors: 1-based line number -> errors reported on that line.
        custom_comment: replaces the error description in generated comments.
        max_line_length: length limit for generated comments, or None.
        truncate: truncate over-long comments instead of splitting them.
        unsafe: skip the generated-file check and the AST stability check.

    Raises:
        SkippingGeneratedFileException: `input` contains the generated marker
            and `unsafe` is False.
        SkippingUnparseableFileException: some error has code "404".
        LineBreakParsingException: a backslash-continued statement could not
            be parsed.
        NOTE(review): `ast.check_stable` presumably raises `ast.UnstableAST`
        when the rewrite changes the AST; confirm against the `ast` module.
    """
    # The marker is spelled as two adjacent literals so that this file does
    # not itself contain the contiguous marker text.
    if not unsafe and "@" "generated" in input:
        raise SkippingGeneratedFileException()

    lines: List[str] = input.split("\n")

    # Do not suppress parse errors.
    if any(
        error["code"] == "404" for error_list in errors.values() for error in error_list
    ):
        raise SkippingUnparseableFileException()

    errors = _relocate_errors_inside_format_strings(errors, input)
    new_lines = _lines_after_suppressing_errors(
        lines, errors, custom_comment, max_line_length, truncate
    )
    output = "\n".join(new_lines)
    if not unsafe:
        ast.check_stable(input, output)
    return output


def _error_to_fixme_comment_lines(
    error: Dict[str, Any],
    indent: int,
    truncate: bool,
    max_line_length: Optional[int],
    custom_comment: Optional[str],
) -> List[str]:
    """Build the `# pyre-fixme[code]: ...` comment line(s) for one error.

    Errors with code "0" produce no comment. Without a length limit a single
    line is returned. With a limit, the comment is split across lines, unless
    `truncate` is set, the split needs more than `MAX_LINES_PER_FIXME` lines,
    or a split line still exceeds the limit; in those cases one line cut to
    `max_line_length - 3` characters followed by "..." is returned.
    """
    if error["code"] == "0":
        return []

    description = custom_comment if custom_comment else error["description"]
    comment = "{}# pyre-fixme[{}]: {}".format(" " * indent, error["code"], description)

    if not max_line_length:
        return [comment]

    truncated_comment = comment[: (max_line_length - 3)] + "..."
    split_comment_lines = _split_across_lines(comment, indent, max_line_length)
    should_truncate = (
        truncate
        or len(split_comment_lines) > MAX_LINES_PER_FIXME
        or any(len(line) > max_line_length for line in split_comment_lines)
    )
    return [truncated_comment] if should_truncate else split_comment_lines


def _build_error_map(
    errors: Iterable[Dict[str, Any]]
) -> Dict[int, List[Dict[str, str]]]:
    """Group errors by line as `{"code": ..., "description": ...}` dicts.

    The code and description are parsed out of the error's concise
    description when it is present and non-empty, otherwise out of its full
    description. Errors whose text does not match `[<digits>]: <text>` are
    skipped.
    """
    error_map = defaultdict(list)
    for error in errors:
        # .get: fall back to the full description when the concise one is
        # absent, not only when it is empty.
        description = error.get("concise_description") or error["description"]
        match = re.search(r"\[(\d+)\]: (.*)", description)
        if match:
            error_map[error["line"]].append(
                {"code": match.group(1), "description": match.group(2)}
            )
    return error_map