# plugins/spark_upgrade/java_spark_context/__init__.py
# Imports used by this method (defined at module top in the original file).
from polyglot_piranha import (
    Filter, Match, OutgoingEdges, PiranhaArguments,
    PiranhaOutputSummary, Rule, RuleGraph, execute_piranha,
)

# _NEW_SPARK_CONF_CHAIN_QUERY, _JAVASPARKCONTEXT_OCE_QUERY and _ADD_IMPORT_RULE
# are module-level constants defined elsewhere in this module (not shown here).

def __call__(self) -> dict[str, bool]:
    # This step only applies to Java codebases.
    if self.language != "java":
        return {}
    # First pass: run the plugin's match rules and collect per-file summaries.
    piranha_args = self.get_piranha_arguments()
    summaries: list[PiranhaOutputSummary] = execute_piranha(piranha_args)
    assert summaries is not None
    for summary in summaries:
        file_path: str = summary.path
        # summary.matches holds (rule_name, Match) pairs.
        match: tuple[str, Match]
        for match in summary.matches:
            if match[0] == "java_match_rule":
                matched_str = match[1].matched_string
                # Rewrite the matched SparkConf chain into the equivalent
                # SparkSession builder chain.
                replace_str = matched_str.replace(
                    "new SparkConf()",
                    'SparkSession.builder().config("spark.sql.legacy.allowUntypedScalaUDF", "true")',
                )
                replace_str = replace_str.replace(".setAppName(", ".appName(")
                replace_str = replace_str.replace(".setMaster(", ".master(")
                replace_str = replace_str.replace(".set(", ".config(")
                replace_str += ".getOrCreate().sparkContext()"
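                # For illustration (hypothetical input), a matched string such as
                #   new SparkConf().setAppName("demo").setMaster("local[2]")
                # is rewritten to (one line, wrapped here for readability):
                #   SparkSession.builder()
                #       .config("spark.sql.legacy.allowUntypedScalaUDF", "true")
                #       .appName("demo").master("local[2]")
                #       .getOrCreate().sparkContext()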
                # Assumes there is only one match in the file.
                # Replace the node captured as `mi` by the chain query, restricted
                # to occurrences enclosed by the JavaSparkContext object-creation
                # expression matched by _JAVASPARKCONTEXT_OCE_QUERY.
                rewrite_rule = Rule(
                    name="rewrite_rule",
                    query=_NEW_SPARK_CONF_CHAIN_QUERY,
                    replace_node="mi",
                    replace=replace_str,
                    filters={
                        Filter(enclosing_node=_JAVASPARKCONTEXT_OCE_QUERY),
                    },
                )
                # Chain the rewrite to the import-adding rule at File scope, so
                # the SparkSession import is added wherever the rewrite fires.
                rule_graph = RuleGraph(
                    rules=[rewrite_rule, _ADD_IMPORT_RULE],
                    edges=[
                        OutgoingEdges(
                            "rewrite_rule",
                            to=["add_import_rule"],
                            scope="File",
                        )
                    ],
                )
                # Second pass: apply the rewrite graph to just this file.
                execute_piranha(
                    PiranhaArguments(
                        language=self.language,
                        rule_graph=rule_graph,
                        paths_to_codebase=[file_path],
                    )
                )
    # Report whether this step matched (and therefore rewrote) anything.
    if not summaries:
        return {self.step_name(): False}
    return {self.step_name(): True}
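
A minimal driver sketch for this step. The class name JavaSparkContextChange, its constructor keywords, and the returned step key are assumptions for illustration; only the __call__ contract above (a dict mapping the step name to a success flag) comes from the excerpt.

# Hypothetical usage; class name and constructor kwargs are assumptions.
from plugins.spark_upgrade.java_spark_context import JavaSparkContextChange

step = JavaSparkContextChange(
    paths_to_codebase=["/path/to/java/codebase"],  # illustrative path
    language="java",
)
result = step()  # e.g. {"<step name>": True} if a SparkConf chain was rewritten
print(result)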