# plugins/spark_upgrade/java_spark_context/__init__.py
# Copyright (c) 2024 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any, List, Dict
from execute_piranha import ExecutePiranha
from polyglot_piranha import (
execute_piranha,
Filter,
OutgoingEdges,
Rule,
PiranhaOutputSummary,
Match,
PiranhaArguments,
RuleGraph,
)
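
# Tree-sitter query matching an object-creation expression whose type is
# `JavaSparkContext` (i.e. `new JavaSparkContext(...)`). It is used below as an
# `enclosing_node` filter so that only `new SparkConf()` chains inside that
# constructor call are considered.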
_JAVASPARKCONTEXT_OCE_QUERY = """(
(object_creation_expression
type: (_) @oce_typ
(#eq? @oce_typ "JavaSparkContext")
) @oce
)"""
_NEW_SPARK_CONF_CHAIN_QUERY = """(
(argument_list
.
(method_invocation) @mi
.
        (#match? @mi "^new SparkConf\\(\\)\\.")
)
)""" # matches a chain of method invocations starting with `new SparkConf().`; the chain is the only argument of an argument_list (indicated by the surrounding anchors `.`).
# Note that we don't remove the unused `SparkConf` import; that will be automated somewhere else.
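# For example (illustrative code, not from any target codebase), this plugin targets:
#   new JavaSparkContext(
#       new SparkConf().setAppName("my-app").setMaster("local[*]").set("k", "v"))
# and rewrites the inner chain into a SparkSession builder chain in `__call__` below.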
_ADD_IMPORT_RULE = Rule(
name="add_import_rule",
query="""(
(program
(import_declaration) @imp_decl
)
)""", # matches the last import
replace_node="imp_decl",
replace="@imp_decl\nimport org.apache.spark.sql.SparkSession;",
is_seed_rule=False,
filters={
Filter( # avoids infinite loop
enclosing_node="((program) @unit)",
            not_contains=["cs import org.apache.spark.sql.SparkSession;"],
),
},
)
class JavaSparkContextChange(ExecutePiranha):
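    """Replaces `new JavaSparkContext(new SparkConf()...)` construction.

    The `new SparkConf()` builder chain is rewritten into a `SparkSession.builder()`
    chain (with `spark.sql.legacy.allowUntypedScalaUDF` set to "true"), and the
    resulting session's `sparkContext()` is passed to `JavaSparkContext` instead.
    The `SparkSession` import is added to each rewritten file.
    """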
def __init__(self, paths_to_codebase: List[str], language: str = "java"):
super().__init__(
paths_to_codebase=paths_to_codebase,
substitutions={
"spark_conf": "SparkConf",
},
language=language,
)
def __call__(self) -> dict[str, bool]:
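        """Runs the match-only seed rule, then rewrites each matched file.

        For every `java_match_rule` match, the matched `new SparkConf()` chain is
        translated into the equivalent `SparkSession.builder()` chain, and a per-file
        rule graph applies the rewrite together with the `SparkSession` import.
        """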
if self.language != "java":
return {}
piranha_args = self.get_piranha_arguments()
summaries: list[PiranhaOutputSummary] = execute_piranha(piranha_args)
assert summaries is not None
for summary in summaries:
file_path: str = summary.path
match: tuple[str, Match]
for match in summary.matches:
if match[0] == "java_match_rule":
matched_str = match[1].matched_string
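                    # Map the SparkConf builder chain onto the SparkSession builder API.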
replace_str = matched_str.replace(
"new SparkConf()",
'SparkSession.builder().config("spark.sql.legacy.allowUntypedScalaUDF", "true")',
)
replace_str = replace_str.replace(".setAppName(", ".appName(")
replace_str = replace_str.replace(".setMaster(", ".master(")
replace_str = replace_str.replace(".set(", ".config(")
replace_str += ".getOrCreate().sparkContext()"
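                    # `.getOrCreate().sparkContext()` returns the underlying SparkContext,
                    # so the enclosing `new JavaSparkContext(...)` still gets a valid argument.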
                    # Assumes there is only one such match in the file.
rewrite_rule = Rule(
name="rewrite_rule",
query=_NEW_SPARK_CONF_CHAIN_QUERY,
replace_node="mi",
replace=replace_str,
filters={
Filter(enclosing_node=_JAVASPARKCONTEXT_OCE_QUERY),
},
)
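                    # The File-scoped edge triggers `add_import_rule` only in files
                    # where `rewrite_rule` actually applied.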
rule_graph = RuleGraph(
rules=[rewrite_rule, _ADD_IMPORT_RULE],
edges=[
OutgoingEdges(
"rewrite_rule",
to=["add_import_rule"],
scope="File",
)
],
)
execute_piranha(
PiranhaArguments(
language=self.language,
rule_graph=rule_graph,
paths_to_codebase=[file_path],
)
)
if not summaries:
return {self.step_name(): False}
return {self.step_name(): True}
def step_name(self) -> str:
return "JavaSparkContext Change"
def get_rules(self) -> List[Rule]:
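        """Returns the match-only seed rule; the replacement string is computed in `__call__`."""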
if self.language != "java":
return []
java_match_rule = Rule(
name="java_match_rule",
query=_NEW_SPARK_CONF_CHAIN_QUERY,
filters={
Filter(enclosing_node=_JAVASPARKCONTEXT_OCE_QUERY),
},
)
return [java_match_rule]
def get_edges(self) -> List[OutgoingEdges]:
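        """No static edges; the rewrite-to-import edge is built per file in `__call__`."""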
return []
def summaries_to_custom_dict(self, _) -> Dict[str, Any]:
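        """No custom summary details are reported for this step."""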
return {}
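

# Illustrative usage (hypothetical path; normally this step is driven by the
# surrounding spark_upgrade plugin rather than run directly):
#
#   JavaSparkContextChange(paths_to_codebase=["/path/to/java/codebase"])()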