experimental/piranha_playground/rule_inference/template_parser.py (74 lines of code) (raw):

import re from collections import defaultdict from typing import Dict, List, Optional, Tuple import attr from polyglot_piranha import PiranhaArguments, Rule, RuleGraph, execute_piranha from tree_sitter import Tree from tree_sitter_languages import get_language, get_parser WILDCARD = "_" @attr.s class TemplateParser: """This class provides parsing utilities for piranha's template language""" language = attr.ib(default="java") tree_sitter_language = attr.ib(default=None) parser = attr.ib(default=None) language_mappings = { "java": "java", "kt": "kotlin", } # This is necessary because get_parser and piranha expect different naming conventions comments_node_names = { "java": "line_comment", "go": "comment", "kt": "comment", "swift": "comment", } template_holes = attr.ib(default=attr.Factory(dict)) def __attrs_post_init__(self): """ Initialize parser and language attributes for the given language after the agent object is created. """ self.tree_sitter_language = get_language( self.language_mappings.get(self.language, self.language) ) self.parser = get_parser( self.language_mappings.get(self.language, self.language) ) def get_tree_from_code(self, code: str, remove_comments: bool = False) -> Tree: """ Parse the given code and return its abstract syntax tree (AST). :param code: The source code to parse :param remove_comments: Whether to remove comments from the code before parsing :return: AST of the source code """ code = self.replace_template_holes(code) if remove_comments: code = self.remove_comments_from_code(code) tree = self.parser.parse(bytes(code, "utf8")) return tree def remove_comments_from_code(self, code: str) -> str: """ Removes all comments from the given code using Piranha. :param code: The source code from which to remove comments :type code: str :return: Source code without comments :rtype: str """ rule = Rule( name="remove_comments", query=f"({self.comments_node_names[self.language]}) @comment", replace_node="comment", replace="", ) graph = RuleGraph(rules=[rule], edges=[]) args = PiranhaArguments( code_snippet=code, language=self.language, rule_graph=graph, dry_run=True, ) output_summaries = execute_piranha(args) if output_summaries: return output_summaries[0].content return code def replace_template_holes(self, code: str) -> str: """ Replace the template holes in source and target with identifiers that can be parsed by tree sitter. :param code: The source code to parse :return: mapping of template holes to identifiers, and corresponding types """ template_pattern = re.compile(r":\[(?P<content>[^[]*)\]") matches = template_pattern.finditer(code) template_holes = {} replaced_code = code for match in matches: name, alternations = self.parse_content(match.group("content")) template_holes[name] = alternations or [WILDCARD] # replace match with identifier replaced_code = replaced_code.replace(match.group(), name) self.template_holes.update(template_holes) return replaced_code @staticmethod def parse_content(content: str) -> Tuple[str, Optional[List[str]]]: if ":" in content: name, alternations = [x.strip() for x in content.split(":", 1)] alternations = [x.strip() for x in alternations.split(",")] return name, alternations else: return content, None