experimental/piranha_playground/rule_inference/graph_parser.py (54 lines of code) (raw):

# Copyright (c) 2023 Uber Technologies, Inc.
#
# <p>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
# <p>http://www.apache.org/licenses/LICENSE-2.0
#
# <p>Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict, deque
from typing import Deque, Dict, List, Optional, Set, Tuple

import attr
from tree_sitter import Node, Tree

from piranha_playground.rule_inference.utils.node_utils import NodeUtils


@attr.s
class GraphParser:
    """
    Extracts rule templates and rule-graph edges from two annotated syntax trees.

    The GraphParser class performs depth-first search on two given Abstract Syntax Trees (ASTs)
    to identify the mapping from each 'before' code snippet (sub-AST) to the 'after' code snippet
    (sub-AST) required to construct the rules of the graph. It also parses special comments that
    specify the flow between the rules.

    Each template is associated with a unique identifier, enclosed by line comments that
    delineate the template's start and end. The closing comment must end with the word ``end``.
    A sample input format is shown below:

    Templates:
        // 1
        x = someMethod()
        // end

    Edges:
        // 1 -> 2
        // 1 -> 3

    :param source_tree: The AST containing source templates.
    :type source_tree: Tree
    :param target_tree: The AST containing target templates.
    :type target_tree: Tree
    """

    source_tree = attr.ib(type=Tree)
    target_tree = attr.ib(type=Tree)
    # NOTE(review): the two replacement_* mappings are not read or written by any method
    # of this class; presumably they are kept for external users - confirm before removing.
    replacement_source = attr.ib(
        type=Dict[str, List[Tree]], default=attr.Factory(lambda: defaultdict(list))
    )
    replacement_target = attr.ib(
        type=Dict[str, List[Tree]], default=attr.Factory(lambda: defaultdict(list))
    )
    # Adjacency sets built from "a -> b" comments while traversing either tree.
    edges = attr.ib(
        type=Dict[str, Set[str]], default=attr.Factory(lambda: defaultdict(set))
    )

    def parse_templates(self) -> Dict[str, Tuple[List[Node], List[Node]]]:
        """
        Executes the actual tree traversal on both 'source_tree' and 'target_tree'.

        It finds corresponding template pairs using identifiers specified in the comments.
        As a side effect, the edge comments met in either tree are accumulated in
        ``self.edges``.

        Only identifiers found in the source tree appear in the result. An identifier with
        no counterpart in the target tree is paired with an empty list; identifiers that
        exist only in the target tree are not reported.

        :return: A dictionary mapping identifiers to (source nodes, target nodes) pairs.
        :rtype: Dict[str, Tuple[List[Node], List[Node]]]
        :raises ValueError: If an edge comment does not contain exactly one ``->``.
        """
        source_dict = self._traverse_tree(self.source_tree)
        target_dict = self._traverse_tree(self.target_tree)

        return {
            identifier: (source_nodes, target_dict.get(identifier, []))
            for identifier, source_nodes in source_dict.items()
        }

    def _record_edge(self, comment_text: str) -> None:
        """
        Registers the edge described by a comment of the form ``// a -> b``.

        :param comment_text: Full text of the comment node, including its two-character
            comment marker.
        :type comment_text: str
        :raises ValueError: If the comment does not contain exactly one ``->``.
        """
        # The first two characters are dropped, as the templates use `//` line comments.
        endpoints = [part.strip() for part in comment_text[2:].split("->")]
        if len(endpoints) != 2:
            raise ValueError(
                f"Malformed edge comment {comment_text!r}: expected exactly one '->'"
            )
        origin, destination = endpoints
        self.edges[origin].add(destination)

    def _traverse_tree(self, tree: Tree) -> Dict[str, List[Node]]:
        """
        Performs a depth-first search (DFS) on the given tree to find nodes that are used
        for templates.

        Comment nodes are classified in this order:

        * text containing ``->`` is an edge declaration and is added to ``self.edges``;
        * text ending with ``end`` closes the template opened by the most recent start
          comment;
        * any other comment is remembered as the start of a template, and its text (minus
          the two-character comment marker) is the template identifier.

        An ``end`` comment that appears before any start comment is ignored, because there
        is no identifier to attach the captured nodes to.

        :param tree: The tree to be traversed.
        :type tree: Tree
        :return: A dictionary mapping template identifiers to corresponding nodes.
        :rtype: Dict[str, List[Node]]
        :raises ValueError: If an edge comment does not contain exactly one ``->``.
        """
        root = tree.root_node
        stack: Deque[Node] = deque([root])
        replacement_pairs: List[Tuple[Node, Node]] = []
        # None until the first start comment is seen, so a stray `end` cannot be paired
        # with the root node.
        node_start: Optional[Node] = None

        while stack:
            node = stack.pop()

            if "comment" in node.type:
                node_text = node.text.decode("utf8")
                if "->" in node_text:
                    self._record_edge(node_text)
                elif node_text.strip().endswith("end"):
                    if node_start is not None:
                        replacement_pairs.append((node_start, node))
                else:
                    node_start = node

            # Children are pushed right-to-left so they are popped in document order.
            stack.extend(reversed(node.children))

        # Kept as a defaultdict so the returned mapping behaves as before on missing keys.
        replacement_dict: Dict[str, List[Node]] = defaultdict(list)
        for start, end in replacement_pairs:
            # Capture what lies strictly between the two delimiting comments: from the byte
            # after the start comment up to the first byte of the end comment.
            nodes = NodeUtils.get_nodes_in_range(root, start.end_byte, end.start_byte)
            identifier = start.text.decode("utf8")[2:].strip()
            replacement_dict[identifier] = nodes

        return replacement_dict