plugins/spark_upgrade/main.py (72 lines of code) (raw):
# Copyright (c) 2023 Uber Technologies, Inc.
# <p>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
# <p>http://www.apache.org/licenses/LICENSE-2.0
# <p>Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import glob
from update_calendar_interval import UpdateCalendarInterval
from IDF_model_signature_change import IDFModelSignatureChange
from accessing_execution_plan import AccessingExecutionPlan
from gradient_boost_trees import GradientBoostTrees
from calculator_signature_change import CalculatorSignatureChange
from sql_new_execution import SQLNewExecutionChange
from query_test_check_answer_change import QueryTestCheckAnswerChange
from spark_config import SparkConfigChange
from java_spark_context import JavaSparkContextChange
from scala_session_builder import ScalaSessionBuilder
def _parse_args():
    """Parse the command-line options of the Spark upgrade script.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv`` with two
        attributes: ``path_to_codebase`` (required) and ``new_version``
        (defaults to ``"3.3"``).
    """
    cli = argparse.ArgumentParser(
        description="Updates the codebase to use a new version of `spark3`"
    )
    # Flag name -> keyword arguments forwarded to ``add_argument``.
    option_specs = {
        "--path_to_codebase": {
            "required": True,
            "help": "Path to the codebase directory.",
        },
        "--new_version": {
            "default": "3.3",
            "help": "Version of `Spark` to update to.",
        },
    }
    for flag, spec in option_specs.items():
        cli.add_argument(flag, **spec)
    return cli.parse_args()
# Layout of every log record: level, logger name, timestamp, source location
# and finally the message itself.
FORMAT = (
    "%(levelname)s %(name)s %(asctime)-15s "
    "%(filename)s:%(lineno)d %(message)s"
)
# Make the root logger emit everything from DEBUG upwards, then install the
# default handler that uses the layout above.
logging.getLogger().setLevel(logging.DEBUG)
logging.basicConfig(format=FORMAT)
def main():
    """Entry point: parse the CLI arguments and run the matching upgrade.

    Only ``--new_version 3.3`` has an upgrade routine. Any other value is
    reported through the logger instead of silently doing nothing, so the
    caller is not left believing the codebase was upgraded.
    """
    args = _parse_args()
    if args.new_version == "3.3":
        upgrade_to_spark_3_3(args.path_to_codebase)
    else:
        logging.error(
            "Unsupported Spark version %r: only '3.3' is supported; "
            "no files were changed.",
            args.new_version,
        )
def upgrade_to_spark_3_3(path_to_codebase: str):
    """Run the Spark 3.3 rewrites on every Scala and Java file under a directory.

    Each call to Piranha is wrapped in try/except so that a failure on a
    single file does not abort the whole run. We catch ``BaseException``
    because pyo3's ``PanicException`` extends it, but ``KeyboardInterrupt``
    and ``SystemExit`` are re-raised so the run can still be interrupted.

    Args:
        path_to_codebase: Root directory searched recursively for
            ``*.scala`` and ``*.java`` files.
    """
    # All Scala files are processed first, then all Java files.
    for extension in ("scala", "java"):
        pattern = f"{path_to_codebase}/**/*.{extension}"
        for source_file in glob.glob(pattern, recursive=True):
            try:
                update_file(source_file)
            except (KeyboardInterrupt, SystemExit):
                # Never swallow a user abort or an explicit exit request.
                raise
            except BaseException as e:
                logging.error(f"Error running for file {source_file}: {e}")
def update_file(file_path: str):
    """Apply every Spark upgrade rule, in a fixed order, to a single file.

    Each rule class is instantiated with a one-element list holding
    ``file_path`` (plus an explicit ``language`` keyword where listed below)
    and the resulting object is then called. Return values are discarded.

    Args:
        file_path: Path of the source file handed to each rule.
    """
    # (rule class, extra constructor keywords), listed in execution order.
    rule_steps = (
        (UpdateCalendarInterval, {}),
        (IDFModelSignatureChange, {}),
        (AccessingExecutionPlan, {}),
        (GradientBoostTrees, {}),
        (CalculatorSignatureChange, {}),
        (SQLNewExecutionChange, {}),
        (QueryTestCheckAnswerChange, {}),
        (SparkConfigChange, {}),
        (SparkConfigChange, {"language": "java"}),
        (JavaSparkContextChange, {"language": "java"}),
        (ScalaSessionBuilder, {"language": "scala"}),
    )
    for rule_cls, extra_kwargs in rule_steps:
        # A fresh list per rule, so no rule ever sees another rule's list object.
        rule = rule_cls([file_path], **extra_kwargs)
        rule()
# Run the upgrade only when this file is executed as a script; importing the
# module must not trigger it.
if __name__ == "__main__":
    main()