def __call__()

in plugins/spark_upgrade/java_spark_context/__init__.py [0:0]


    def __call__(self) -> dict[str, bool]:
        """Rewrite SparkConf builder chains into SparkSession builder calls.

        Does nothing for non-Java code. Otherwise a first Piranha pass
        (configured by ``get_piranha_arguments``) finds candidate chains. For
        every ``java_match_rule`` match, the matched source text is converted
        textually into a
        ``SparkSession.builder()...getOrCreate().sparkContext()`` expression.
        A second Piranha pass, limited to the matching file, then writes that
        text back via ``rewrite_rule``. This pass uses
        ``_NEW_SPARK_CONF_CHAIN_QUERY`` filtered to nodes enclosed by
        ``_JAVASPARKCONTEXT_OCE_QUERY``, and chains into ``add_import_rule``
        at file scope.

        Returns:
            ``{}`` when ``self.language`` is not ``"java"``. Otherwise
            ``{self.step_name(): flag}``, where ``flag`` is ``True`` exactly
            when the first pass returned a non-empty list of summaries.
        """
        if self.language != "java":
            return {}

        summaries: list[PiranhaOutputSummary] = execute_piranha(
            self.get_piranha_arguments()
        )
        # Narrows the declared return type; execute_piranha is expected to
        # hand back a list here.
        assert summaries is not None

        # (old, new) textual substitutions, applied in this exact order to the
        # matched SparkConf chain to produce a SparkSession builder chain.
        substitutions = (
            (
                "new SparkConf()",
                'SparkSession.builder().config("spark.sql.legacy.allowUntypedScalaUDF", "true")',
            ),
            (".setAppName(", ".appName("),
            (".setMaster(", ".master("),
            (".set(", ".config("),
        )

        for summary in summaries:
            target_file: str = summary.path
            for rule_name, matched in summary.matches:
                if rule_name != "java_match_rule":
                    continue

                builder_chain = matched.matched_string
                for old, new in substitutions:
                    builder_chain = builder_chain.replace(old, new)
                builder_chain = builder_chain + ".getOrCreate().sparkContext()"

                # NOTE(review): the rewrite query is not tied to this specific
                # match, so this assumes one SparkConf chain per file. With
                # several chains in one file, they would presumably all get
                # this match's replacement — TODO confirm against Piranha's
                # rule-application semantics.
                graph = RuleGraph(
                    rules=[
                        Rule(
                            name="rewrite_rule",
                            query=_NEW_SPARK_CONF_CHAIN_QUERY,
                            replace_node="mi",
                            replace=builder_chain,
                            filters={
                                Filter(
                                    enclosing_node=_JAVASPARKCONTEXT_OCE_QUERY
                                ),
                            },
                        ),
                        _ADD_IMPORT_RULE,
                    ],
                    edges=[
                        OutgoingEdges(
                            "rewrite_rule",
                            to=["add_import_rule"],
                            scope="File",
                        )
                    ],
                )
                execute_piranha(
                    PiranhaArguments(
                        language=self.language,
                        rule_graph=graph,
                        paths_to_codebase=[target_file],
                    )
                )

        return {self.step_name(): bool(summaries)}