experimental/piranha_playground/rule_inference/utils/node_utils.py (107 lines of code) (raw):

# Copyright (c) 2023 Uber Technologies, Inc.
#
# <p>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
# <p>http://www.apache.org/licenses/LICENSE-2.0
#
# <p>Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import List

from tree_sitter import Node, TreeCursor, Tree


class NodeUtils:
    """
    NodeUtils is a utility class that provides static methods for performing operations on AST nodes.
    The methods include generating s-expressions, converting nodes to source code, getting
    non-overlapping nodes, removing partial nodes, and more.
    """

    @staticmethod
    def generate_sexpr(node: Node, depth: int = 0, prefix: str = "") -> str:
        """
        Creates a pretty s-expression representation of a given node.

        Named children are rendered recursively, each on its own line and indented one
        level deeper than their parent. Anonymous (unnamed) children are only rendered
        when they are attached to a field, in the form ``field: ("type")``.

        :param node: Node to generate the s-expression for.
        :param depth: Depth of the node in the AST; controls the indentation.
        :param prefix: Prefix string placed right before the opening parenthesis
            (used for ``field_name: `` labels).
        :return: The generated s-expression.
        """
        indent = " " * depth
        cursor: TreeCursor = node.walk()
        s_exp = indent + f"{prefix}({node.type} "
        next_child = cursor.goto_first_child()
        while next_child:
            child_node: Node = cursor.node
            field_name = cursor.current_field_name()
            if child_node.is_named:
                s_exp += "\n"
                child_prefix = f"{field_name}: " if field_name else ""
                s_exp += NodeUtils.generate_sexpr(child_node, depth + 1, child_prefix)
            elif field_name:
                # Anonymous tokens are only interesting when they fill a named field.
                s_exp += "\n" + " " * (depth + 1)
                s_exp += f'{field_name}: ("{child_node.type}")'
            next_child = cursor.goto_next_sibling()
        return s_exp + ")"

    @staticmethod
    def convert_to_source(
        node: Node, depth: int = 0, exclude: List[Node] = None
    ) -> str:
        """
        Convert a given node to its source code representation (unified).

        Leaf tokens are joined with single spaces, so the output is a normalized
        rendering rather than the original formatting. Any node lying within one of
        the ``exclude`` nodes is rendered as the literal ``{placeholder}``.

        :param node: Node to convert.
        :param depth: Depth of the node in the AST (passed through on recursion only).
        :param exclude: List of nodes to be excluded from the source code.
        :return: Source code representation of the node.
        """
        if exclude is None:
            exclude = []
        if any(NodeUtils.contains(to_exclude, node) for to_exclude in exclude):
            return "{placeholder}"

        cursor: TreeCursor = node.walk()
        if not cursor.goto_first_child():
            # Leaf node: emit its text verbatim.
            return node.text.decode("utf8")

        parts = []
        has_next_child = True
        while has_next_child:
            parts.append(NodeUtils.convert_to_source(cursor.node, depth + 1, exclude))
            has_next_child = cursor.goto_next_sibling()
        return " ".join(parts).strip()

    @staticmethod
    def get_smallest_nonoverlapping_set(nodes: List[Node]) -> List[Node]:
        """
        Get the smallest non-overlapping set of nodes from the given list.

        Nodes are ordered by start position (and, for equal starts, the longest node
        first), then picked greedily: a node is kept only if it starts at or after the
        end of the last kept node. Nodes nested in an already kept node are dropped.

        :param nodes: List of nodes.
        :return: The smallest non-overlapping set of nodes, in source order.
        """
        ordered = sorted(
            nodes, key=lambda x: (x.start_point, tuple(-n for n in x.end_point))
        )
        result: List[Node] = []
        for node in ordered:
            # tree-sitter end points are exclusive, so a node starting exactly where
            # the previous one ends is adjacent (e.g. `foo` and `(x)`), not overlapping.
            if not result or node.start_point >= result[-1].end_point:
                result.append(node)
        return result

    @staticmethod
    def remove_partial_nodes(nodes: List[Node]) -> List[Node]:
        """
        Remove nodes whose children are not all contained in the given list of nodes.

        Removing a node can make its parent partial in turn, so the filtering is
        repeated until a fixed point is reached where no more nodes can be removed.
        Leaf nodes (no children) are always kept.

        :param nodes: List of nodes.
        :return: The updated list of nodes after removing partial nodes.
        """
        while True:
            # Compare by id (as find_lowest_common_ancestor does) for O(1) lookups.
            present_ids = {node.id for node in nodes}
            kept = [
                node
                for node in nodes
                if all(child.id in present_ids for child in node.children)
            ]
            if len(kept) == len(nodes):
                return kept
            nodes = kept

    @staticmethod
    def normalize_code(code: str) -> str:
        """
        Eliminates all whitespace (spaces, tabs, newlines, ...) from code.

        This function is a preprocessing step before comparing the refactored code
        with the target code, so that formatting differences do not matter. Note that
        whitespace inside string literals is removed as well.

        :param code: str, Code to normalize.
        :return: str, Normalized code.
        """
        # `\s+` matches every run of whitespace, newlines included, so a single
        # substitution removes everything; no further space/newline clean-up is needed.
        return re.sub(r"\s+", "", code)

    @staticmethod
    def contains(node: Node, other: Node) -> bool:
        """
        Checks if the given node contains the other node.

        Containment is based on (row, column) points only; a node contains itself and
        any node with an identical span.

        :param node: Node, Node to check if it contains the other node.
        :param other: Node, Node to check if it is contained by ``node``.
        :return: bool, True if the given node contains the other node, False otherwise.
        """
        return (
            node.start_point <= other.start_point and node.end_point >= other.end_point
        )

    @staticmethod
    def find_lowest_common_ancestor(nodes: List[Node]) -> Node:
        """
        Find the smallest (deepest) common ancestor of the provided nodes.

        A node counts as its own ancestor, so if one of the nodes is an ancestor of all
        the others, that node is returned.

        :param nodes: list of nodes for which to find the smallest common ancestor.
        :return: Node which is the smallest common ancestor.
        :raises ValueError: if the nodes share no ancestor (e.g. different trees).
        """
        # Ensure the list of nodes isn't empty
        assert len(nodes) > 0

        def ancestor_chain(start: Node) -> List[Node]:
            """Return ``start`` followed by its ancestors, ordered up to the root."""
            chain = []
            current = start
            while current is not None:
                chain.append(current)
                current = current.parent
            return chain

        first_chain = ancestor_chain(nodes[0])
        common_ancestor_ids = {ancestor.id for ancestor in first_chain}
        for node in nodes[1:]:
            common_ancestor_ids &= {ancestor.id for ancestor in ancestor_chain(node)}

        # If there are no common ancestors, there's a problem with the input tree
        if not common_ancestor_ids:
            raise ValueError("Nodes have no common ancestor")

        # first_chain runs from nodes[0] upwards, so its first common entry is the
        # deepest one. (Picking by max start_byte is ambiguous: a parent and its first
        # child frequently start at the same byte.)
        return next(
            ancestor for ancestor in first_chain if ancestor.id in common_ancestor_ids
        )

    @staticmethod
    def get_nodes_in_range(root: Node, start_byte: int, end_byte: int) -> List[Node]:
        """
        Collect the outermost descendants of ``root`` that lie fully within a byte range.

        A child fully inside ``[start_byte, end_byte]`` is taken as a whole. A child
        that only touches or partially overlaps the range is searched recursively for
        fully contained descendants.

        :param root: Node whose descendants are searched.
        :param start_byte: Start of the byte range.
        :param end_byte: End of the byte range.
        :return: Non-overlapping nodes inside the range, in source order.
        """
        nodes = []
        for child in root.children:
            if start_byte <= child.end_byte and end_byte >= child.start_byte:
                if start_byte <= child.start_byte and end_byte >= child.end_byte:
                    nodes.append(child)
                else:
                    nodes += NodeUtils.get_nodes_in_range(child, start_byte, end_byte)
        return NodeUtils.get_smallest_nonoverlapping_set(nodes)