experimental/piranha_playground/rule_inference/utils/rule_utils.py (153 lines of code) (raw):
# Copyright (c) 2023 Uber Technologies, Inc.
#
# <p>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
# <p>http://www.apache.org/licenses/LICENSE-2.0
#
# <p>Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Dict, List, Set
import attr
import toml
from polyglot_piranha import Filter, OutgoingEdges, Rule, RuleGraph
@attr.s(eq=False)
class RawFilter:
    """Plain-Python mirror of a Piranha ``Filter`` that can round-trip through TOML.

    ``eq=False`` keeps attrs' identity-based equality/hashing, so instances can be
    stored in sets (see ``RawRule.filters``) even though ``not_contains`` is a list.
    """

    # Default for the numeric limits below (u32::MAX). A value equal to it is
    # treated as "unset" and is omitted from the TOML output.
    _U32_MAX = 4294967295

    id = attr.ib(type=str, default="0")
    enclosing_node = attr.ib(type=str, default=None)
    not_enclosing_node = attr.ib(type=str, default=None)
    # factory=list gives each instance its own list; a literal ``[]`` default
    # would be a single list object shared by every instance using the default.
    not_contains = attr.ib(type=List[str], factory=list)
    contains = attr.ib(type=str, default=None)
    at_least = attr.ib(type=int, default=1)
    at_most = attr.ib(type=int, default=_U32_MAX)
    child_count = attr.ib(type=int, default=_U32_MAX)
    sibling_count = attr.ib(type=int, default=_U32_MAX)

    def to_toml(self) -> str:
        """Serialize this filter as one ``[[rules.filters]]`` TOML table.

        Only fields that differ from their defaults (or are truthy, for the
        string/list fields) are written. ``id`` is never written.

        :return: The TOML text, starting with the ``[[rules.filters]]`` header line.
        """
        str_reps = []
        if self.enclosing_node:
            str_reps.append(f'enclosing_node = """{self.enclosing_node}"""')
        if self.not_enclosing_node:
            str_reps.append(f'not_enclosing_node = """{self.not_enclosing_node}"""')
        if self.not_contains:
            # JSON array syntax for a list of strings is also a valid TOML array.
            str_reps.append(f"not_contains = {json.dumps(self.not_contains)}")
        if self.contains:
            str_reps.append(f'contains = """{self.contains}"""')
        if self.at_least != 1:
            str_reps.append(f"at_least = {self.at_least}")
        if self.at_most != self._U32_MAX:
            str_reps.append(f"at_most = {self.at_most}")
        if self.child_count != self._U32_MAX:
            str_reps.append(f"child_count = {self.child_count}")
        if self.sibling_count != self._U32_MAX:
            str_reps.append(f"sibling_count = {self.sibling_count}")
        return "\n".join(["[[rules.filters]]"] + str_reps)

    def to_filter(self):
        """Convert to a ``polyglot_piranha.Filter``.

        NOTE: ``child_count`` and ``sibling_count`` are deliberately not forwarded
        (they were commented out in the original implementation).
        NOTE(review): confirm whether ``Filter`` now accepts them before enabling.
        """
        return Filter(
            enclosing_node=self.enclosing_node,
            not_enclosing_node=self.not_enclosing_node,
            not_contains=self.not_contains,
            contains=self.contains,
            at_least=self.at_least,
            at_most=self.at_most,
        )

    @staticmethod
    def from_toml(toml_dict: Dict) -> "RawFilter":
        """Build a ``RawFilter`` from one parsed ``[[rules.filters]]`` table.

        Missing keys fall back to the same defaults as the attrs fields.
        ``id`` is not read and keeps its default.

        :param toml_dict: A filter table as returned by ``toml.loads``.
        :return: The corresponding ``RawFilter``.
        """
        return RawFilter(
            enclosing_node=toml_dict.get("enclosing_node", None),
            not_enclosing_node=toml_dict.get("not_enclosing_node", None),
            not_contains=toml_dict.get("not_contains", []),
            contains=toml_dict.get("contains", None),
            at_least=toml_dict.get("at_least", 1),
            at_most=toml_dict.get("at_most", RawFilter._U32_MAX),
            child_count=toml_dict.get("child_count", RawFilter._U32_MAX),
            sibling_count=toml_dict.get("sibling_count", RawFilter._U32_MAX),
        )
@attr.s(eq=False)
class RawRule:
    """Plain-Python mirror of a Piranha ``Rule`` that can round-trip through TOML."""

    name = attr.ib(type=str)
    query = attr.ib(type=str, default=None)
    replace_node = attr.ib(type=str, default=None)
    replace = attr.ib(type=str, default=None)
    # factory=set gives each instance its own set; a literal ``set()`` default
    # would be one object shared by every instance constructed without a value.
    groups = attr.ib(type=Set[str], factory=set)
    holes = attr.ib(type=Set[str], factory=set)
    filters = attr.ib(type=Set[RawFilter], factory=set)
    is_seed_rule = attr.ib(type=bool, default=True)

    def to_toml(self) -> str:
        """Serialize this rule as one ``[[rules]]`` TOML table, followed by its filters.

        Falsy optional fields are omitted. ``is_seed_rule`` is only written when it
        is false.

        :return: The TOML text, starting with the ``[[rules]]`` header line.
        """
        # json.dumps(str(...)) yields a quoted, escaped string that TOML accepts as
        # a basic string; for plain identifiers it equals '"<value>"'.
        str_reprs = [f"name = {json.dumps(str(self.name))}"]
        if self.query:
            # NOTE(review): multi-line basic strings still interpret backslash
            # escapes, so queries containing '\' or '\"\"\"' may not round-trip.
            str_reprs.append(f'query = """{self.query}"""')
        if self.replace_node:
            str_reprs.append(f"replace_node = {json.dumps(str(self.replace_node))}")
        if self.replace:
            str_reprs.append(f'replace = """{self.replace}"""')
        if self.groups:
            # json cannot encode a set; sorting also makes the output deterministic.
            str_reprs.append(f"groups = {json.dumps(sorted(self.groups))}")
        if self.holes:
            str_reprs.append(f"holes = {json.dumps(sorted(self.holes))}")
        if not self.is_seed_rule:
            # Must come before the filter tables: after a [[rules.filters]] header,
            # every following key belongs to that filter, not to this rule.
            str_reprs.append("is_seed_rule = false")
        if self.filters:
            str_reprs.append(self._filters_to_toml())
        return "[[rules]]\n" + "\n".join(str_reprs)

    def _filters_to_toml(self) -> str:
        """Serialize every filter of this rule, one ``[[rules.filters]]`` table each."""
        return "\n".join(raw_filter.to_toml() for raw_filter in self.filters)

    def to_rule(self):
        """Convert to a ``polyglot_piranha.Rule``, converting each filter as well."""
        return Rule(
            name=self.name,
            query=self.query,
            replace_node=self.replace_node,
            replace=self.replace,
            groups=self.groups,
            holes=self.holes,
            filters={f.to_filter() for f in self.filters},
            is_seed_rule=self.is_seed_rule,
        )

    @staticmethod
    def from_toml(toml_dict: Dict) -> "RawRule":
        """Build a ``RawRule`` from one parsed ``[[rules]]`` table.

        :param toml_dict: A rule table as returned by ``toml.loads``. ``"name"`` is
            required; ``groups`` and ``holes`` are converted to sets, and each entry
            of ``filters`` is parsed with ``RawFilter.from_toml``.
        :return: The corresponding ``RawRule``.
        :raises KeyError: If ``"name"`` is missing.
        """
        return RawRule(
            name=toml_dict["name"],
            query=toml_dict.get("query", None),
            replace_node=toml_dict.get("replace_node", None),
            replace=toml_dict.get("replace", None),
            groups=set(toml_dict.get("groups", set())),
            holes=set(toml_dict.get("holes", set())),
            filters={RawFilter.from_toml(f) for f in toml_dict.get("filters", [])},
            is_seed_rule=toml_dict.get("is_seed_rule", True),
        )
@attr.s
class RawRuleGraph:
    """A list of ``RawRule`` objects plus the edges between them, serializable to TOML.

    Each edge is a dict indexed by ``"from"``, ``"to"`` and ``"scope"`` (see
    ``edge_to_toml`` and ``to_graph``).
    """

    rules = attr.ib(type=List[RawRule], validator=attr.validators.instance_of(list))
    edges = attr.ib(type=List[Dict], validator=attr.validators.instance_of(list))

    def to_toml(self) -> str:
        """Serialize all rules, then all edges, as one TOML document.

        Rule tables and edge tables are each separated by a blank line. When there
        are no edges, the rules are followed by three newlines.
        """
        rules_str = "\n\n".join(rule.to_toml() for rule in self.rules)
        if self.edges:
            edges_str = "\n\n".join(self.edge_to_toml(edge) for edge in self.edges)
        else:
            edges_str = "\n\n"
        return f"{rules_str}\n{edges_str}"

    def to_graph(self):
        """Convert to a ``polyglot_piranha.RuleGraph`` of ``Rule`` and ``OutgoingEdges``."""
        return RuleGraph(
            [rule.to_rule() for rule in self.rules],
            [
                OutgoingEdges(edge["from"], edge["to"], edge["scope"])
                for edge in self.edges
            ],
        )

    @staticmethod
    def edge_to_toml(edge_dict: Dict) -> str:
        """Serialize one edge as an ``[[edges]]`` TOML table.

        :param edge_dict: Dict with ``"scope"``, ``"from"`` and ``"to"`` keys;
            ``"to"`` is written via ``json.dumps`` (a list of rule names is expected).
        :return: The TOML text for the edge.
        :raises KeyError: If any of the three keys is missing.
        """
        # json.dumps(str(...)) escapes quotes/backslashes so the result is a valid
        # TOML basic string; for plain names it is identical to '"<value>"'.
        return "\n".join(
            [
                "[[edges]]",
                f"scope = {json.dumps(str(edge_dict['scope']))}",
                f"from = {json.dumps(str(edge_dict['from']))}",
                f"to = {json.dumps(edge_dict['to'])}",
            ]
        )

    @staticmethod
    def from_toml(toml_dict: Dict) -> "RawRuleGraph":
        """Build a ``RawRuleGraph`` from a parsed TOML document.

        :param toml_dict: Output of ``toml.loads``; ``"rules"`` is required and
            ``"edges"`` defaults to an empty list.
        :return: The corresponding ``RawRuleGraph``.
        :raises KeyError: If ``"rules"`` (or a rule's ``"name"``) is missing.
        """
        rules = [RawRule.from_toml(toml_rule) for toml_rule in toml_dict["rules"]]
        edges = toml_dict.get("edges", [])
        return RawRuleGraph(rules=rules, edges=edges)

    @staticmethod
    def validate(instance, attribute, value):
        """
        Validator function for TOML input. If value is not valid TOML, raises ValueError.
        :param instance: The instance of the class the attribute is attached to.
        :param attribute: The attribute that this validator function is checking.
        :param value: The value of the attribute that is being validated.
        :raises ValueError: If value is not valid TOML, or does not parse into a rule graph.
        """
        try:
            toml_dict = toml.loads(value)
            RawRuleGraph.from_toml(toml_dict)
        except Exception as e:
            # Boundary check: any parse/shape failure is reported uniformly, with
            # the original exception chained for debugging.
            raise ValueError(
                "Invalid rule format. Please refer to the rule format in the README"
            ) from e