experimental/piranha_playground/rule_inference/utils/rule_utils.py (153 lines of code) (raw):

# Copyright (c) 2023 Uber Technologies, Inc.
#
# <p>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
# <p>http://www.apache.org/licenses/LICENSE-2.0
#
# <p>Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.

"""Raw (TOML-serializable) representations of Piranha filters, rules and rule graphs."""

import json
from typing import Dict, List, Set

import attr
import toml
from polyglot_piranha import Filter, OutgoingEdges, Rule, RuleGraph

# u32::MAX; used as the "unbounded" default for the numeric filter fields below.
_U32_MAX = 4294967295


@attr.s(eq=False)
class RawFilter:
    """A rule filter that can be written to / read from TOML and converted to a Piranha Filter.

    Fields that are unset (falsy) or still at their default bound are omitted from the
    TOML output.
    """

    id = attr.ib(type=str, default="0")
    enclosing_node = attr.ib(type=str, default=None)
    not_enclosing_node = attr.ib(type=str, default=None)
    # factory=list gives every instance its own list; default=[] would be shared by all instances.
    not_contains = attr.ib(type=List[str], factory=list)
    contains = attr.ib(type=str, default=None)
    at_least = attr.ib(type=int, default=1)
    at_most = attr.ib(type=int, default=_U32_MAX)
    child_count = attr.ib(type=int, default=_U32_MAX)
    sibling_count = attr.ib(type=int, default=_U32_MAX)

    def to_toml(self) -> str:
        """Render this filter as a ``[[rules.filters]]`` TOML table.

        :return: The TOML text, starting with the ``[[rules.filters]]`` header.
        """
        str_reps = []
        if self.enclosing_node:
            str_reps.append(f'enclosing_node = """{self.enclosing_node}"""')
        if self.not_enclosing_node:
            str_reps.append(f'not_enclosing_node = """{self.not_enclosing_node}"""')
        if self.not_contains:
            str_reps.append(f"not_contains = {json.dumps(self.not_contains)}")
        if self.contains:
            str_reps.append(f'contains = """{self.contains}"""')
        if self.at_least != 1:
            str_reps.append(f"at_least = {self.at_least}")
        if self.at_most != _U32_MAX:
            str_reps.append(f"at_most = {self.at_most}")
        if self.child_count != _U32_MAX:
            str_reps.append(f"child_count = {self.child_count}")
        if self.sibling_count != _U32_MAX:
            str_reps.append(f"sibling_count = {self.sibling_count}")
        return "\n".join(["[[rules.filters]]"] + str_reps)

    def to_filter(self) -> Filter:
        """Convert to a ``polyglot_piranha.Filter``.

        :return: A Filter built from the enclosing/contains/count fields of this object.
        """
        # NOTE(review): child_count and sibling_count are deliberately not forwarded here;
        # confirm that polyglot_piranha.Filter accepts them before passing them through.
        return Filter(
            enclosing_node=self.enclosing_node,
            not_enclosing_node=self.not_enclosing_node,
            not_contains=self.not_contains,
            contains=self.contains,
            at_least=self.at_least,
            at_most=self.at_most,
        )

    @staticmethod
    def from_toml(toml_dict: Dict) -> "RawFilter":
        """Build a RawFilter from a parsed ``[[rules.filters]]`` table.

        :param toml_dict: The table as a dict; missing keys fall back to the field defaults.
        :return: The corresponding RawFilter.
        """
        return RawFilter(
            enclosing_node=toml_dict.get("enclosing_node", None),
            not_enclosing_node=toml_dict.get("not_enclosing_node", None),
            not_contains=toml_dict.get("not_contains", []),
            contains=toml_dict.get("contains", None),
            at_least=toml_dict.get("at_least", 1),
            at_most=toml_dict.get("at_most", _U32_MAX),
            child_count=toml_dict.get("child_count", _U32_MAX),
            sibling_count=toml_dict.get("sibling_count", _U32_MAX),
        )


@attr.s(eq=False)
class RawRule:
    """A rewrite rule that can be written to / read from TOML and converted to a Piranha Rule."""

    name = attr.ib(type=str)
    query = attr.ib(type=str, default=None)
    replace_node = attr.ib(type=str, default=None)
    replace = attr.ib(type=str, default=None)
    # factory=set gives every instance its own set; default=set() would be shared by all instances.
    groups = attr.ib(type=Set[str], factory=set)
    holes = attr.ib(type=Set[str], factory=set)
    filters = attr.ib(type=Set[RawFilter], factory=set)
    is_seed_rule = attr.ib(type=bool, default=True)

    def to_toml(self) -> str:
        """Render this rule as a ``[[rules]]`` TOML table, followed by its filter tables.

        :return: The TOML text, starting with the ``[[rules]]`` header.
        """
        str_reprs = [f'name = "{self.name}"']
        if self.query:
            str_reprs.append(f'query = """{self.query}"""')
        if self.replace_node:
            str_reprs.append(f'replace_node = "{self.replace_node}"')
        if self.replace:
            str_reprs.append(f'replace = """{self.replace}"""')
        # Sets are not JSON-serializable; emit sorted lists so the output is valid and deterministic.
        if self.groups:
            str_reprs.append(f"groups = {json.dumps(sorted(self.groups))}")
        if self.holes:
            str_reprs.append(f"holes = {json.dumps(sorted(self.holes))}")
        if not self.is_seed_rule:
            str_reprs.append("is_seed_rule = false")
        # Filters must come last: in TOML, every key written after a [[rules.filters]]
        # header belongs to that filter table, not to the enclosing rule.
        if self.filters:
            str_reprs.append(self._filters_to_toml())
        return "[[rules]]\n" + "\n".join(str_reprs)

    def _filters_to_toml(self) -> str:
        """Render all filters of this rule, one ``[[rules.filters]]`` table each."""
        return "\n".join(f.to_toml() for f in self.filters)

    def to_rule(self) -> Rule:
        """Convert to a ``polyglot_piranha.Rule`` (filters are converted via ``RawFilter.to_filter``)."""
        return Rule(
            name=self.name,
            query=self.query,
            replace_node=self.replace_node,
            replace=self.replace,
            groups=self.groups,
            holes=self.holes,
            filters={f.to_filter() for f in self.filters},
            is_seed_rule=self.is_seed_rule,
        )

    @staticmethod
    def from_toml(toml_dict: Dict) -> "RawRule":
        """Build a RawRule from a parsed ``[[rules]]`` table.

        :param toml_dict: The table as a dict; ``name`` is required, other keys are optional.
        :return: The corresponding RawRule.
        :raises KeyError: If ``name`` is missing.
        """
        return RawRule(
            name=toml_dict["name"],
            query=toml_dict.get("query", None),
            replace_node=toml_dict.get("replace_node", None),
            replace=toml_dict.get("replace", None),
            groups=set(toml_dict.get("groups", set())),
            holes=set(toml_dict.get("holes", set())),
            filters={RawFilter.from_toml(f) for f in toml_dict.get("filters", [])},
            is_seed_rule=toml_dict.get("is_seed_rule", True),
        )


@attr.s
class RawRuleGraph:
    """A list of RawRules plus the edges between them, serializable to TOML."""

    rules = attr.ib(type=List[RawRule], validator=attr.validators.instance_of(list))
    # Each edge is a dict with "from", "to" and "scope" keys (see edge_to_toml / to_graph).
    edges = attr.ib(type=List[Dict], validator=attr.validators.instance_of(list))

    def to_toml(self) -> str:
        """Render all rules, then all edges, as a single TOML document."""
        rules_str = "\n\n".join(rule.to_toml() for rule in self.rules)
        edges_str = "\n\n"
        if self.edges:
            edges_str = "\n\n".join(self.edge_to_toml(edge) for edge in self.edges)
        return f"{rules_str}\n{edges_str}"

    def to_graph(self) -> RuleGraph:
        """Convert to a ``polyglot_piranha.RuleGraph``."""
        return RuleGraph(
            [rule.to_rule() for rule in self.rules],
            [
                OutgoingEdges(edge["from"], edge["to"], edge["scope"])
                for edge in self.edges
            ],
        )

    @staticmethod
    def edge_to_toml(edge_dict: Dict) -> str:
        """Render one edge dict (keys ``scope``, ``from``, ``to``) as an ``[[edges]]`` table."""
        return "\n".join(
            [
                "[[edges]]",
                f'scope = "{edge_dict["scope"]}"',
                f'from = "{edge_dict["from"]}"',
                f"to = {json.dumps(edge_dict['to'])}",
            ]
        )

    @staticmethod
    def from_toml(toml_dict: Dict) -> "RawRuleGraph":
        """Build a RawRuleGraph from a parsed TOML document.

        :param toml_dict: Must contain a ``rules`` list; ``edges`` is optional.
        :return: The corresponding RawRuleGraph.
        :raises KeyError: If ``rules`` (or a rule's ``name``) is missing.
        """
        rules = [RawRule.from_toml(toml_rule) for toml_rule in toml_dict["rules"]]
        edges = toml_dict.get("edges", [])
        return RawRuleGraph(rules=rules, edges=edges)

    @staticmethod
    def validate(instance, attribute, value):
        """
        Validator function for TOML input. If value is not valid TOML, raises ValueError.

        :param instance: The instance of the class the attribute is attached to.
        :param attribute: The attribute that this validator function is checking.
        :param value: The value of the attribute that is being validated.
        :raises ValueError: If value is not valid TOML.
        """
        try:
            toml_dict = toml.loads(value)
            RawRuleGraph.from_toml(toml_dict)
        except Exception as e:
            # Boundary conversion: any parse/shape error becomes a user-facing ValueError.
            raise ValueError(
                "Invalid rule format. Please refer to the rule format in the README"
            ) from e