plugins/spark_upgrade/scala_session_builder/__init__.py
# Copyright (c) 2024 Uber Technologies, Inc.
# <p>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
# <p>http://www.apache.org/licenses/LICENSE-2.0
# <p>Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any, List, Dict
from execute_piranha import ExecutePiranha
from polyglot_piranha import (
execute_piranha,
Filter,
OutgoingEdges,
Rule,
PiranhaOutputSummary,
Match,
Edit,
PiranhaArguments,
RuleGraph,
)
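
# Tree-sitter query matching a `val <id>: SparkSession = spy(...)` definition
# (the spied SparkSession builder definition this plugin rewrites).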
VAL_DEF_QUERY = """(
(val_definition
pattern: (identifier) @val_id
type: (type_identifier) @type_id
value: (call_expression
function: (identifier) @func_call
)
(#eq? @type_id "SparkSession")
(#eq? @func_call "spy")
) @val_def
)"""
QUERY = f"""(
(function_definition
body: (block
{VAL_DEF_QUERY}
.
(call_expression
function: (field_expression
value: (field_expression
value: (identifier) @lhs
field: (identifier) @rhs
)
field: (identifier) @call_name
(#eq? @lhs @val_id)
(#eq? @rhs "sqlContext")
(#eq? @call_name "setConf")
)
)+ @calls
)
) @func_def
)"""


class ScalaSessionBuilder(ExecutePiranha):
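    """Folds trailing `<session>.sqlContext.setConf(k, v)` calls into chained
    `.config(k, v)` calls placed just before `.getOrCreate()` in the spied
    SparkSession definition.

    Illustrative (simplified) Scala, before:
        val spark: SparkSession = spy(SparkSession.builder().getOrCreate())
        spark.sqlContext.setConf("k", "v")
    after:
        val spark: SparkSession = spy(SparkSession.builder().config("k", "v")
        .getOrCreate())
    """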
def __init__(self, paths_to_codebase: List[str], language: str = "scala"):
super().__init__(
paths_to_codebase=paths_to_codebase,
substitutions={},
language=language,
)

    def __call__(self) -> dict[str, bool]:
if self.language != "scala":
return {}
piranha_args = self.get_piranha_arguments()
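        # First pass: `delete_calls_query` (built in get_rules) removes the
        # trailing setConf calls and reports them in the output summaries.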
summaries: list[PiranhaOutputSummary] = execute_piranha(piranha_args)
assert summaries is not None
for summary in summaries:
file_path: str = summary.path
edit: Edit
if len(summary.rewrites) == 0:
continue
print(f"rewrites: {len(summary.rewrites)}")
calls_to_add_str = ""
            # the first rewrite's edit carries the full `calls` capture (all matched setConf calls)
edit = summary.rewrites[0]
if edit.matched_rule == "delete_calls_query":
match: Match = edit.p_match
val_id = match.matches["val_id"]
calls = match.matches["calls"]
print(f"calls: {calls}")
calls_to_add_str = calls.replace(
f"{val_id}.sqlContext.setConf", ".config"
)
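            # Re-read the matched val_definition and splice the rewritten
            # `.config(...)` calls in front of `.getOrCreate()`.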
match = summary.rewrites[0].p_match
val_def = match.matches["val_def"]
assert isinstance(val_def, str)
assert "getOrCreate()" in val_def
replace_str = calls_to_add_str + "\n.getOrCreate()"
new_val_def = val_def.replace(".getOrCreate()", replace_str)
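            # Second-pass rule: replace the original val_definition with the
            # updated one; the filter skips definitions that already contain a
            # `config` identifier, so already-rewritten code is left alone.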
replace_val_def_rule = Rule(
name="replace_val_def_rule",
query=VAL_DEF_QUERY,
replace_node="val_def",
replace=new_val_def,
filters={
Filter(
enclosing_node="(val_definition) @_vl_def",
not_contains=(
[
"""(
(identifier) @conf_id
(#eq? @conf_id "config")
)"""
]
),
)
},
)
rule_graph = RuleGraph(
rules=[replace_val_def_rule],
edges=[],
)
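            # Second pass: apply the replacement rule to this file only.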
execute_piranha(
PiranhaArguments(
language=self.language,
rule_graph=rule_graph,
paths_to_codebase=[file_path],
)
)
if not summaries:
return {self.step_name(): False}
return {self.step_name(): True}

    def step_name(self) -> str:
return "Spark spy SessionBuilder"

    def get_rules(self) -> List[Rule]:
if self.language != "scala":
return []
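        # Matches QUERY and deletes the `@calls` span; the captured text is
        # reused in __call__ to build the chained `.config(...)` calls.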
delete_calls_query = Rule(
name="delete_calls_query",
query=QUERY,
replace_node="calls",
replace="",
)
return [delete_calls_query]

    def get_edges(self) -> List[OutgoingEdges]:
return []

    def summaries_to_custom_dict(self, _) -> Dict[str, Any]:
return {}