experimental/piranha_playground/rule_inference/rule_application.py (118 lines of code) (raw):
# Copyright (c) 2023 Uber Technologies, Inc.
#
# <p>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
# <p>http://www.apache.org/licenses/LICENSE-2.0
#
# <p>Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.
import logging
import multiprocessing
from typing import Dict, List, Optional, Tuple
import attr
import toml
from piranha_playground.rule_inference.utils.logger_formatter import CustomFormatter
from piranha_playground.rule_inference.utils.rule_utils import RawRuleGraph
from polyglot_piranha import (
PiranhaArguments,
PiranhaOutputSummary,
Rule,
RuleGraph,
execute_piranha,
)
# Handler that emits every record from DEBUG upwards, formatted by CustomFormatter.
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
ch.setFormatter(CustomFormatter())

# Module-wide logger used by the refactoring helpers defined in this file.
logger = logging.getLogger("CodebaseRefactorer")
logger.setLevel(logging.DEBUG)
logger.addHandler(ch)
class CodebaseRefactorerException(Exception):
    """Error raised by CodebaseRefactorer when a refactoring run fails."""
def enable_piranha_logs():
    """
    Configures the root logger so that Piranha's log records are shown.

    Installs a basic logging configuration with a verbose record format and
    then lowers the root logger's threshold to DEBUG.
    """
    record_format = (
        "%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s"
    )
    logging.basicConfig(format=record_format)
    # The level is set separately from basicConfig so that it is applied on
    # every call, independently of what basicConfig decides to do.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)
def flatten_dict_list(d: Dict):
    """
    Replaces single-element lists in a (possibly nested) dictionary by their only element.

    The dictionary is modified in place. Nested dictionaries are processed
    recursively, as are dictionaries found inside lists that do not have
    exactly one element. An element unwrapped from a single-element list is
    stored as-is and is not processed any further.

    :param d: Dict: The dictionary to flatten in place.
    :return: Dict: The same dictionary object that was passed in.
    """
    for key, value in d.items():
        if isinstance(value, dict):
            flatten_dict_list(value)
            continue
        if not isinstance(value, list):
            continue
        if len(value) == 1:
            # Only the value of an existing key is replaced, so the dictionary
            # does not change size while it is being iterated.
            d[key] = value[0]
            continue
        for element in value:
            if isinstance(element, dict):
                flatten_dict_list(element)
    return d
def _run_piranha_with_timeout_aux(
    source_code: str,
    language: str,
    raw_graph: RawRuleGraph,
    timeout: int = 0,
    **kwargs,
):
    """
    Private helper that performs one dry-run Piranha execution on a code snippet.

    Any exception raised while building the arguments or executing Piranha is
    turned into a ``(message, False)`` result instead of being propagated.

    :param source_code: str: The source code to be refactored.
    :param language: str: The language of the source code.
    :param raw_graph: RawRuleGraph: The rule graph to be used for refactoring.
    :param timeout: int: Not read anywhere in this function.
    :param kwargs: Extra keyword arguments forwarded unchanged to PiranhaArguments.
    :return: Tuple[str, bool]: On success, the content of the first Piranha result (or the
        unchanged source code when Piranha returns no results) and True. On failure, the
        exception message and False.
    """
    try:
        graph = raw_graph.to_graph()
        piranha_args = PiranhaArguments(
            code_snippet=source_code,
            language=language,
            rule_graph=graph,
            dry_run=True,
            **kwargs,
        )
        summaries = execute_piranha(piranha_args)
        # No summaries means there is nothing to report, so the input is handed back as-is.
        rewritten = summaries[0].content if summaries else source_code
        return rewritten, True
    except BaseException as e:
        # NOTE(review): BaseException rather than Exception is caught here, presumably so
        # that non-Exception errors coming from the native Piranha bindings are also
        # reported as a failure result -- confirm before narrowing.
        return str(e), False
def run_piranha_with_timeout(
    source_code: str,
    language: str,
    raw_graph: RawRuleGraph,
    substitutions: Optional[dict] = None,
    timeout: Optional[int] = 10,
    **kwargs,
) -> Tuple[str, bool]:
    """
    Executes Piranha on a code snippet in a separate process, bounded by a timeout.

    The execution is delegated to ``_run_piranha_with_timeout_aux`` inside a
    single-worker process pool. Leaving the ``with`` block shuts the pool down,
    so a worker that did not answer in time does not keep running.

    :param source_code: str: The source code to be refactored.
    :param language: str: The language of the source code.
    :param raw_graph: RawRuleGraph: The rule graph to be used for refactoring.
    :param substitutions: Optional[dict]: The substitutions to be made during refactoring.
        Default is None, in which case no substitutions are forwarded.
    :param timeout: Optional[int]: The timeout for the execution in seconds. Default is 10 seconds.
        None waits without a time limit.
    :param kwargs: Additional keyword arguments forwarded unchanged to the helper, which passes
        them on to PiranhaArguments.
    :return: Tuple[str, bool]: The refactored code and a boolean indicating if the execution was successful.
    :raises multiprocessing.TimeoutError: If no result is available within ``timeout`` seconds.
    """
    # The substitutions must travel as a keyword argument: the helper's fourth
    # positional parameter is ``timeout``, so passing them positionally would
    # bind them to that unused parameter and they would never reach Piranha.
    forwarded_kwargs = dict(kwargs)
    if substitutions is not None:
        forwarded_kwargs["substitutions"] = substitutions

    with multiprocessing.Pool(processes=1) as pool:
        async_result = pool.apply_async(
            _run_piranha_with_timeout_aux,
            (source_code, language, raw_graph),
            forwarded_kwargs,
        )
        return async_result.get(timeout=timeout)
@attr.s
class CodebaseRefactorer:
    """
    Applies Piranha rules, given as TOML text, to a codebase on disk or to a single snippet.
    """

    # Language passed to Piranha for the codebase.
    language = attr.ib(type=str)
    # Path of the codebase handed to Piranha.
    path_to_codebase = attr.ib(type=str)
    # Rules as TOML text (parsed with toml.loads).
    rules = attr.ib(type=str)
    # NOTE(review): include_paths / exclude_paths are not read by any method of this class.
    include_paths = attr.ib(type=List[str], default=None)
    exclude_paths = attr.ib(type=List[str], default=None)

    def refactor_codebase(self, dry_run: bool = True) -> List[PiranhaOutputSummary]:
        """
        Runs Piranha with this instance's rules over the configured codebase.

        The first entry of the optional ``arguments`` table of the rules is flattened
        and forwarded to PiranhaArguments as extra keyword arguments. The path of every
        returned summary is logged.

        :param dry_run: bool: Forwarded to PiranhaArguments as ``dry_run``. Default is True.
        :return: List[PiranhaOutputSummary]: The summaries returned by Piranha.
        :raises CodebaseRefactorerException: If anything fails while parsing the rules or running Piranha.
        """
        try:
            parsed_rules = toml.loads(self.rules)
            graph = RawRuleGraph.from_toml(parsed_rules)
            extra_arguments = flatten_dict_list(
                parsed_rules.get("arguments", [{}])[0]
            )
            summaries = execute_piranha(
                PiranhaArguments(
                    language=self.language,
                    paths_to_codebase=[self.path_to_codebase],
                    rule_graph=graph.to_graph(),
                    dry_run=dry_run,
                    **extra_arguments,
                )
            )
            logger.info("Changed files:")
            for changed in summaries:
                logger.info(changed.path)
            return summaries
        except BaseException as e:
            raise CodebaseRefactorerException(str(e)) from e

    @staticmethod
    def refactor_snippet(source_code: str, language: str, rules: str) -> str:
        """
        Runs Piranha on a single code snippet, giving up after 5 seconds.

        :param source_code: str: The source code to be refactored.
        :param language: str: The language of the source code.
        :param rules: str: The refactoring rules as TOML text.
        :return: str: The first element of the result of run_piranha_with_timeout, i.e. the
            refactored code or an error message. The success flag is not inspected.
        :raises CodebaseRefactorerException: If the execution times out or any other exception occurs.
        """
        try:
            parsed_rules = toml.loads(rules)
            # The arguments are flattened before the rule graph is built from the same dict.
            extra_arguments = flatten_dict_list(
                parsed_rules.get("arguments", [{}])[0]
            )
            refactored_code, _ = run_piranha_with_timeout(
                source_code,
                language,
                RawRuleGraph.from_toml(parsed_rules),
                timeout=5,
                **extra_arguments,
            )
            return refactored_code
        except multiprocessing.context.TimeoutError as e:
            raise CodebaseRefactorerException(
                "Piranha is likely in an infinite loop. Please check your rules."
            ) from e
        except Exception as e:
            raise CodebaseRefactorerException(str(e)) from e