liminal/runners/airflow/tasks/spark.py:

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from itertools import chain

from flatdict import FlatDict

from liminal.runners.airflow.tasks import containerable, hadoop


class SparkTask(hadoop.HadoopTask, containerable.ContainerTask):
    """
    Executes a Spark application.
    """

    def __init__(
        self, task_id, dag, parent, trigger_rule, liminal_config, pipeline_config, task_config, variables=None
    ):
        # default image/cmd and make Spark resolve the driver hostname locally
        task_config['image'] = task_config.get('image', '')
        task_config['cmd'] = task_config.get('cmd', [])
        task_config['env_vars'] = {'SPARK_LOCAL_HOSTNAME': 'localhost'}
        super().__init__(task_id, dag, parent, trigger_rule, liminal_config, pipeline_config, task_config, variables)

    def get_runnable_command(self):
        """
        Return spark-submit runnable command.
        """
        return self.__generate_spark_submit()

    def __generate_spark_submit(self):
        spark_submit = ['spark-submit']
        spark_arguments = self.__spark_args()
        application_arguments = self.__additional_arguments()

        spark_submit.extend(spark_arguments)
        spark_submit.extend(application_arguments)

        return [str(x) for x in spark_submit]

    def __spark_args(self):
        # reformat spark conf: every flattened entry becomes '--conf key=value'
        flat_conf_args = list()

        spark_arguments = {
            'master': self.task_config.get('master', None),
            'class': self.task_config.get('class', None),
        }
        source_code = self.task_config.get("application_source")

        for conf_arg in [f'{k}={v}' for (k, v) in FlatDict(self.task_config.get('conf', {})).items()]:
            flat_conf_args.append('--conf')
            flat_conf_args.append(conf_arg)

        spark_conf = self.__parse_spark_arguments(spark_arguments)
        spark_conf.extend(flat_conf_args)
        spark_conf.extend([source_code])

        return spark_conf

    def __additional_arguments(self):
        application_arguments = self.task_config.get('application_arguments', {})

        if isinstance(application_arguments, list):
            return application_arguments

        return self.__interleaving(application_arguments.keys(), application_arguments.values())

    def __parse_spark_arguments(self, spark_arguments):
        # drop empty values, then interleave '--key' flags with their values
        spark_arguments = {x[0]: x[1] for x in spark_arguments.items() if x[1]}

        return self.__interleaving(
            [f'--{k}' for k in spark_arguments.keys() if spark_arguments[k]], spark_arguments.values()
        )

    @staticmethod
    def __interleaving(keys, values):
        # zip keys and values into a flat [k1, v1, k2, v2, ...] list
        return list(chain.from_iterable(zip(keys, values)))

    def _kubernetes_cmds_and_arguments(self):
        return self.__generate_spark_submit(), []
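
# --- Usage illustration (not part of the module) ----------------------------
# A minimal sketch of how a hypothetical task_config maps onto the
# spark-submit command that SparkTask assembles above. The config values
# ('org.example.MyApp', 'my_app.jar', the input/output paths) are invented
# for illustration only; the assembly below mirrors __spark_args(),
# __additional_arguments() and __generate_spark_submit().
from itertools import chain

from flatdict import FlatDict

task_config = {
    'master': 'yarn',
    'class': 'org.example.MyApp',            # hypothetical main class
    'conf': {'spark.executor.memory': '2g'},
    'application_source': 'my_app.jar',      # hypothetical application artifact
    'application_arguments': {'--input': '/data/in', '--output': '/data/out'},
}

# --master/--class flags, keeping only keys that actually have a value
named = {k: v for k, v in {'master': task_config.get('master'),
                           'class': task_config.get('class')}.items() if v}
flags = list(chain.from_iterable(zip([f'--{k}' for k in named], named.values())))

# every flattened 'conf' entry becomes a '--conf key=value' pair
conf_args = list(chain.from_iterable(
    ('--conf', f'{k}={v}') for k, v in FlatDict(task_config.get('conf', {})).items()))

# dict-style application arguments are interleaved into [k1, v1, k2, v2, ...]
app_args = list(chain.from_iterable(zip(task_config['application_arguments'].keys(),
                                        task_config['application_arguments'].values())))

command = ['spark-submit'] + flags + conf_args + [task_config['application_source']] + app_args
print(command)
# ['spark-submit', '--master', 'yarn', '--class', 'org.example.MyApp',
#  '--conf', 'spark.executor.memory=2g', 'my_app.jar',
#  '--input', '/data/in', '--output', '/data/out']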