plugins/spark_upgrade/execute_piranha.py (89 lines of code) (raw):
# Copyright (c) 2023 Uber Technologies, Inc.
# <p>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
# <p>http://www.apache.org/licenses/LICENSE-2.0
# <p>Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Union, List, Dict
from polyglot_piranha import PiranhaArguments, execute_piranha, Rule, RuleGraph, OutgoingEdges
class ExecutePiranha(ABC):
    '''
    This abstract class implements the higher level strategy for
    applying a specific polyglot piranha configuration i.e. rules/edges.

    Subclasses must implement `step_name`, `get_rules` and
    `summaries_to_custom_dict`. They may additionally override `get_edges`,
    `get_rule_graph` or `path_to_configuration`. Calling an instance runs
    piranha with the resulting arguments and returns a dict.
    '''

    def __init__(self, paths_to_codebase: List[str], language: str, substitutions: Dict[str, str], dry_run=False, allow_dirty_ast=False):
        '''
        Store the settings that are later forwarded to `PiranhaArguments`.

        :param paths_to_codebase: forwarded as `paths_to_codebase`.
        :param language: forwarded as `language`.
        :param substitutions: forwarded as `substitutions`.
        :param dry_run: forwarded as `dry_run`.
        :param allow_dirty_ast: forwarded as `allow_dirty_ast`.
        '''
        self.paths_to_codebase = paths_to_codebase
        self.language = language
        self.substitutions = substitutions
        self.dry_run = dry_run
        self.allow_dirty_ast = allow_dirty_ast
        # Overwritten by __call__. Starts empty so that get_matches() returns
        # [] instead of raising AttributeError when used before the first run.
        self.summaries: List[Any] = []

    def __call__(self) -> dict:
        '''
        Run piranha and return the custom dict built from its summaries.

        The value returned by `execute_piranha` is stored on `self.summaries`
        and handed to `summaries_to_custom_dict`. The returned dict has one
        extra entry, keyed by `step_name()`, which is True when
        `summaries_to_custom_dict` produced a truthy result and False
        otherwise (in which case the dict contains only that entry).
        '''
        piranha_args = self.get_piranha_arguments()
        self.summaries = execute_piranha(piranha_args)

        output = self.summaries_to_custom_dict(self.summaries)
        success = bool(output)
        if not success:
            # Also covers a None result: start from a fresh dict.
            output = {}
        output[self.step_name()] = success
        return output

    @abstractmethod
    def step_name(self) -> str:
        '''
        The overriding method should return the name of the strategy.
        '''
        ...

    @abstractmethod
    def get_rules(self) -> List[Rule]:
        '''
        The list of rules.
        '''
        ...

    def get_edges(self) -> List[OutgoingEdges]:
        '''
        The list of edges. Defaults to no edges.
        '''
        return []

    def get_rule_graph(self) -> RuleGraph:
        '''
        Strategy to construct a rule graph from rules/edges.
        '''
        return RuleGraph(rules=self.get_rules(), edges=self.get_edges())

    def path_to_configuration(self) -> Union[None, str]:
        '''
        Path to rules/edges toml file (in case rule graph is not specified).
        Defaults to None.
        '''
        return None

    def get_piranha_arguments(self) -> PiranhaArguments:
        '''
        Build the `PiranhaArguments` for this strategy.

        Exactly one configuration source must be available: either a rule
        graph containing rules, or a path to a configuration.

        :raises ValueError: if both sources, or neither of them, are provided.
        '''
        # Query each overridable hook exactly once so that the object that is
        # validated is the very same object that is passed on to piranha.
        rule_graph = self.get_rule_graph()
        path_to_conf = self.path_to_configuration()
        has_rules = bool(rule_graph.rules)

        if has_rules and path_to_conf:
            raise ValueError(
                "You have provided a rule graph and path to configurations. Do not provide both.")
        if not has_rules and not path_to_conf:
            raise ValueError("You have neither provided a rule graph nor path to configurations.")

        # Only the configuration source differs between the two valid cases.
        if has_rules:
            config_source = {"rule_graph": rule_graph}
        else:
            config_source = {"path_to_configurations": path_to_conf}

        return PiranhaArguments(
            language=self.language,
            paths_to_codebase=self.paths_to_codebase,
            substitutions=self.substitutions,
            cleanup_comments=True,
            dry_run=self.dry_run,
            allow_dirty_ast=self.allow_dirty_ast,
            **config_source
        )

    def get_matches(self, specified_rule: str) -> List[dict]:
        '''
        This function gets matches for a specified rule.

        Iterates over `self.summaries` (set by the most recent call of this
        instance) and collects the `.matches` attribute of every entry whose
        rule name equals `specified_rule`. Returns an empty list when the
        instance has not been called yet.
        '''
        return [
            rule_match.matches
            for summary in self.summaries
            for actual_rule, rule_match in summary.matches
            if actual_rule == specified_rule
        ]

    @abstractmethod
    def summaries_to_custom_dict(self, _) -> Dict[str, Any]:
        '''
        The overriding method should implement the logic for extracting out the
        useful information from the matches/rewrites reported by polyglot piranha into a dict.
        '''
        ...