experimental/piranha_playground/rule_inference/static_inference.py (201 lines of code) (raw):

# Copyright (c) 2023 Uber Technologies, Inc.
#
# <p>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
# <p>http://www.apache.org/licenses/LICENSE-2.0
#
# <p>Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Dict, List, Optional

import attr
from comby import Comby
from tree_sitter import Node, TreeCursor

from piranha_playground.rule_inference.utils.node_utils import NodeUtils
from piranha_playground.rule_inference.utils.rule_utils import RawRule


@attr.s
class ChildProcessingStrategy(ABC):
    """
    Abstract Base Class to define a strategy to process a named child node.

    :ivar query_writer: The QueryWriter whose counters, constraints and
        template holes the strategy reads and updates.
    """

    query_writer = attr.ib(type="QueryWriter")

    @abstractmethod
    def process_child(self, cursor: TreeCursor, depth: int) -> str:
        """
        Decide how a child of a node should be handled when writing a query:
        either expand the query for the child, simply capture it without
        expanding, etc.

        :param cursor: The cursor pointing to the child node.
        :param depth: The depth of the child node in the tree.
        :return: The query fragment for the child.
        """


@attr.s
class SimplifyStrategy(ChildProcessingStrategy):
    """
    Strategy to simplify the processing of a named child node.
    """

    def process_child(self, cursor: TreeCursor, depth: int) -> str:
        """
        Simply capture the child node without expanding it.

        :param cursor: The cursor pointing to the child node.
        :param depth: The depth of the parent node; the child is indented one
            level deeper.
        :return: The query fragment for the child, starting with a newline.
        """
        writer = self.query_writer
        child_node: Node = cursor.node
        node_rep = child_node.text.decode("utf8")
        tag = f"@tag{writer.count}n"
        indent = " " * (depth + 1)

        if node_rep in writer.template_holes:
            # Template hole: match on any of the allowed node types. No text
            # constraint is added, mirroring QueryWriter.write_query, which
            # never constrains template holes by their text.
            node_types = writer.template_holes[node_rep]
            alternations = " ".join(f"({node_type})" for node_type in node_types)
            s_exp = "\n" + indent + f"[{alternations}] {tag}"
        else:
            s_exp = "\n" + indent + f"({child_node.type}) {tag}"
            # A node without children is pinned to its exact text.
            if child_node.child_count == 0:
                writer.query_ctrs.append(f'(#eq? {tag} "{node_rep}")')

        writer.count += 1
        return s_exp


@attr.s
class RegularStrategy(ChildProcessingStrategy):
    """
    Strategy for the regular processing of a named child node.
    """

    def process_child(self, cursor: TreeCursor, depth: int) -> str:
        """
        Write the query as if the child node was the root of the tree.

        :param cursor: The cursor pointing to the child node.
        :param depth: The depth of the parent node; the child is written one
            level deeper.
        :return: The query fragment for the child, starting with a newline.
        """
        field_name = cursor.current_field_name()
        prefix = f"{field_name}: " if field_name else ""
        return "\n" + self.query_writer.write_query(
            cursor.node, depth + 1, prefix, simplify=False
        )


@attr.s
class QueryWriter:
    """
    Class to represent a query writer for nodes.

    :ivar seq_nodes: Sequence of nodes for which the query will be written.
    :ivar template_holes: A dictionary, keyed by node text, listing the node
        types that may stand in for that text. Such nodes are matched by type
        instead of being expanded.
    :ivar capture_groups: A dictionary mapping each capture name to the node
        it captures.
    :ivar count: A counter for naming the capture groups.
    :ivar query_str: The query produced by the last call to ``write``.
    :ivar query_ctrs: List of ``#eq?`` constraints appended to the query.
    :ivar outer_most_node: Name of the capture group written last, i.e. the
        capture of the root of the most recently written node.
    :ivar strategy: The child processing strategy selected by ``write``.
    """

    seq_nodes = attr.ib(type=List[Node])
    template_holes = attr.ib(type=Dict[str, List[str]])
    capture_groups = attr.ib(default=attr.Factory(dict))
    count = attr.ib(default=0)
    query_str = attr.ib(default="")
    query_ctrs = attr.ib(default=attr.Factory(list))
    outer_most_node = attr.ib(default=None)
    strategy = attr.ib(default=None)

    def write(self, simplify=False):
        """
        Get textual representation of the sequence.
        Find for each named child of source_node, can we replace it with its
        respective target group.

        :param simplify: If True, simplify the query.
        :return: The query string
        """
        self.strategy = SimplifyStrategy(self) if simplify else RegularStrategy(self)
        self.query_str = ""
        self.query_ctrs = []

        node_queries = [
            self.write_query(node, simplify=simplify) for node in self.seq_nodes
        ]

        # Patterns first, then one constraint per line, all wrapped in parens.
        body = ".".join(node_queries) + "\n" + "\n".join(self.query_ctrs)
        self.query_str = f"({body})"
        return self.query_str

    def write_query(self, node: Node, depth=0, prefix="", simplify=False):
        """
        Write a query for a given node, considering its depth and prefix.

        :param node: The node for which the query will be written.
        :param depth: The current depth of the node.
        :param prefix: Prefix for the current node.
        :param simplify: Unused here; children are handled by ``self.strategy``,
            which is selected in ``write``.
        :return: The query string for this node.
        """
        indent = " " * depth
        self.count += 1
        node_repr = node.text.decode("utf8")
        node_name = f"@tag{self.count}n"

        if node_repr in self.template_holes:
            # Template hole: match by type only, never expand or constrain.
            node_types = self.template_holes[node_repr]
            if len(node_types) == 1:
                s_exp = indent + f"{prefix}({node_types[0]})"
            else:
                alternations = " ".join(f"({node_type})" for node_type in node_types)
                s_exp = indent + f"{prefix}[{alternations}]"
        else:
            # Regular node: open the pattern and append every named child.
            s_exp = indent + f"{prefix}({node.type} "
            cursor: TreeCursor = node.walk()
            visited = 0
            has_child = cursor.goto_first_child()
            while has_child:
                if cursor.node.is_named:
                    visited += 1
                    s_exp += self.strategy.process_child(cursor, depth)
                has_child = cursor.goto_next_sibling()

            # A node without named children (e.g. an identifier) is pinned to
            # its text through an equality constraint.
            if visited == 0:
                text = node_repr.replace("\n", " ")
                self.query_ctrs.append(f'(#eq? {node_name} "{text}")')
            s_exp += ")"

        self.capture_groups[node_name] = node
        self.outer_most_node = node_name
        return s_exp + f" {node_name}"

    def simplify_query(self, capture_group):
        """
        Simplify a query removing all the children of capture_group and
        replacing it with a wildcard node.
        This should be replaced with Piranha at some point.

        :param capture_group: The capture group to simplify.
        :raise KeyError: If capture_group is not a known capture group.
        """
        comby = Comby()
        match = f"(:[[node_name]] :[_]) {capture_group}"
        rewrite = f"(:[[node_name]]) {capture_group}"
        self.query_str = comby.rewrite(self.query_str, match, rewrite)

        # Now for every descendant of capture_group, remove its equality
        # check from the query and forget its capture group.
        stack = [self.capture_groups[capture_group]]
        while stack:
            current = stack.pop()
            to_remove = next(
                (
                    key
                    for key, value in self.capture_groups.items()
                    if value == current
                ),
                None,
            )
            # Descendants that were never captured (for instance children of a
            # template hole, which write_query does not expand) have nothing
            # to remove; skip them instead of failing with StopIteration.
            if to_remove is not None:
                match = f"(#eq? {to_remove} :[_])"
                self.query_str = comby.rewrite(self.query_str, match, "")
                self.capture_groups.pop(to_remove, None)
            stack.extend(current.named_children)

    def replace_with_tags(self, replace_str: str) -> str:
        """
        Replace nodes with their corresponding capture group in the replace_str.

        :param replace_str: The string that needs replacement.
        :return: The replaced string.
        """
        # Longest nodes first, so an enclosing node is substituted before any
        # of the nodes it contains.
        by_length = sorted(
            self.capture_groups.items(), key=lambda item: -len(item[1].text)
        )
        for capture_group, node in by_length:
            # Only relevant if capture groups are dropped while iterating
            # (e.g. by simplify_query); the sorted snapshot may then be stale.
            if capture_group not in self.capture_groups:
                continue
            text_repr = NodeUtils.convert_to_source(node)
            # str.replace with an empty pattern would insert the tag between
            # every character, so empty representations are skipped.
            if not text_repr:
                continue
            replace_str = replace_str.replace(text_repr, f"{capture_group}")
        return replace_str


@attr.s
class Inference:
    """
    Class to represent inference on nodes.

    :ivar nodes_before: The list of nodes before the transformation.
    :ivar nodes_after: The list of nodes after the transformation.
    :ivar template_holes: The template holes for the inference.
    :ivar name: Name of the inference rule.
    :ivar _counter: A class-wide counter for naming the inference rules.
    """

    nodes_before = attr.ib(type=List[Node], validator=attr.validators.instance_of(list))
    nodes_after = attr.ib(type=List[Node], validator=attr.validators.instance_of(list))
    template_holes = attr.ib(type=Dict[str, List[str]], default=attr.Factory(dict))
    name = attr.ib(
        init=False,
    )

    _counter = 0

    def __attrs_post_init__(self):
        """
        Initialization method to increment the counter and set the name for
        the rule.
        """
        type(self)._counter += 1
        self.name = f"rule_{type(self)._counter}"

    def static_infer(self) -> Optional[RawRule]:
        """
        Infer a raw rule based on the nodes before and after.

        :return: A raw rule inferred from the nodes, or None when both node
            lists are empty.
        :raise NotImplementedError: If there are only nodes after the change
            (pure additions are not supported).
        """
        if self.nodes_before:
            # Covers both replacement (nodes_after non-empty) and deletion
            # (nodes_after empty).
            return self.create_rule(self.nodes_before, self.nodes_after)
        if self.nodes_after:
            return self.create_addition()
        return None

    def find_nodes_to_change(self, node_before: Node, node_after: Node):
        """
        Function to find nodes to change if there's only one diverging node.

        :param node_before: The node before the change.
        :param node_after: The node after the change.
        :return: The pair of nodes (before, after) that need to be changed.
        """
        if node_before.type != node_after.type:
            return node_before, node_after

        # Check if there's one and only one diverging named child.
        # If there's more than one, then we can't narrow the change down.
        diverging_nodes = [
            (child_before, child_after)
            for child_before, child_after in zip(
                node_before.named_children, node_after.named_children
            )
            if NodeUtils.convert_to_source(child_before)
            != NodeUtils.convert_to_source(child_after)
        ]
        if (
            len(diverging_nodes) == 1
            and node_before.child_count == node_after.child_count
        ):
            return self.find_nodes_to_change(*diverging_nodes[0])
        return node_before, node_after

    def create_rule(self, nodes_before: List[Node], nodes_after: List[Node]) -> RawRule:
        """
        Create a rule based on the nodes before and after.

        Note: when both lists hold exactly one node, their first elements are
        replaced in place by the narrowed-down pair of diverging nodes.

        :param nodes_before: The list of nodes before the change.
        :param nodes_after: The list of nodes after the change.
        :return: A raw rule representing the transformation from nodes_before
            to nodes_after.
        """
        if len(nodes_before) == 1:
            if len(nodes_after) == 1:
                nodes_before[0], nodes_after[0] = self.find_nodes_to_change(
                    nodes_before[0], nodes_after[0]
                )
            target = nodes_before[0]
            replacement_str = " ".join(
                NodeUtils.convert_to_source(after) for after in nodes_after
            )
        else:
            # If there are multiple nodes, match their lowest common ancestor
            # and rebuild its source with the new nodes spliced in.
            target = NodeUtils.find_lowest_common_ancestor(nodes_before)
            replacement_str = NodeUtils.convert_to_source(target, exclude=nodes_before)
            # Drop all placeholders except one ...
            replacement_str = replacement_str.replace(
                "{placeholder}", "", len(nodes_before) - 1
            )
            lines_affected = " ".join(
                NodeUtils.convert_to_source(after) for after in nodes_after
            )
            # ... and put the new source in the remaining one.
            replacement_str = replacement_str.replace(
                "{placeholder}", lines_affected, 1
            )

        qw = QueryWriter([target], self.template_holes)
        query = qw.write()
        replacement_str = qw.replace_with_tags(replacement_str)
        return RawRule(
            name=self.name,
            query=query,
            # Strip the leading "@" from the capture name.
            replace_node=qw.outer_most_node[1:],
            replace=replacement_str,
        )

    def create_addition(self) -> str:
        """
        A method to create addition rules. Currently not implemented.

        :raise NotImplementedError: Always.
        """
        raise NotImplementedError