o2a/mappers/spark_mapper.py
# -*- coding: utf-8 -*-
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Maps Spark action to Airflow Dag"""
from typing import Dict, List, Optional, Set
import xml.etree.ElementTree as ET
from o2a.converter.exceptions import ParseException
from o2a.converter.task import Task
from o2a.converter.relation import Relation
from o2a.mappers.action_mapper import ActionMapper
from o2a.mappers.extensions.prepare_mapper_extension import PrepareMapperExtension
from o2a.o2a_libs.src.o2a_lib.property_utils import PropertySet
from o2a.utils import xml_utils
from o2a.utils.file_archive_extractors import FileExtractor, ArchiveExtractor
from o2a.utils.xml_utils import get_tag_el_text

SPARK_TAG_VALUE = "value"
SPARK_TAG_NAME = "name"
SPARK_TAG_ARG = "arg"
SPARK_TAG_OPTS = "spark-opts"
SPARK_TAG_JOB_NAME = "name"
SPARK_TAG_CLASS = "class"
SPARK_TAG_JAR = "jar"


# pylint: disable=too-many-instance-attributes
class SparkMapper(ActionMapper):
    """Maps an Oozie Spark action node to Airflow tasks."""

def __init__(self, oozie_node: ET.Element, name: str, props: PropertySet, **kwargs):
ActionMapper.__init__(self, oozie_node=oozie_node, name=name, props=props, **kwargs)
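        # Job configuration below is extracted from the action's XML in on_parse_node()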
self.java_class: Optional[str] = None
self.java_jar: Optional[str] = None
self.job_name: Optional[str] = None
self.jars: List[str] = []
self.application_args: List[str] = []
self.file_extractor = FileExtractor(oozie_node=oozie_node, props=self.props)
self.archive_extractor = ArchiveExtractor(oozie_node=oozie_node, props=self.props)
self.hdfs_files: List[str] = []
self.hdfs_archives: List[str] = []
self.dataproc_jars: List[str] = []
self.spark_opts: Dict[str, str] = {}
self.prepare_extension: PrepareMapperExtension = PrepareMapperExtension(self)

    def on_parse_node(self):
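        """Extracts files, archives, jar/class, job name, spark-opts and args from the action node."""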
super().on_parse_node()
_, self.hdfs_files = self.file_extractor.parse_node()
_, self.hdfs_archives = self.archive_extractor.parse_node()
self.java_jar = get_tag_el_text(self.oozie_node, tag=SPARK_TAG_JAR)
self.java_class = get_tag_el_text(self.oozie_node, tag=SPARK_TAG_CLASS)
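        # When both <class> and <jar> are given, ship the jar via jar_file_uris and keep the
        # class as the entry point: Dataproc's SparkJob takes either main_class or main_jar_file_uri.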
if self.java_class and self.java_jar:
self.dataproc_jars = [self.java_jar]
self.java_jar = None
self.job_name = get_tag_el_text(self.oozie_node, tag=SPARK_TAG_JOB_NAME)
spark_opts = xml_utils.find_nodes_by_tag(self.oozie_node, SPARK_TAG_OPTS)
if spark_opts:
self.spark_opts.update(self._parse_spark_opts(spark_opts[0]))
self.application_args = xml_utils.get_tags_el_array_from_text(self.oozie_node, tag=SPARK_TAG_ARG)

    @staticmethod
    def _parse_spark_opts(spark_opts_node: ET.Element) -> Dict[str, str]:
"""
Some examples of the spark-opts element:
--conf key1=value
--conf key2="value1 value2"
"""
conf: Dict[str, str] = {}
        if not spark_opts_node.text:
            raise ParseException(f"Spark opts node has no text: {spark_opts_node}")
        spark_opts = spark_opts_node.text.split("--")[1:]
clean_opts = [opt.strip() for opt in spark_opts]
clean_opts_split = [opt.split(maxsplit=1) for opt in clean_opts]
for spark_opt in clean_opts_split:
# Can have multiple "--conf" in spark_opts
if spark_opt[0] == "conf":
key, _, value = spark_opt[1].partition("=")
# Value is required
if not value:
raise ParseException(
f"Incorrect parameter format. Expected format: key=value. Current value: {spark_opt}"
)
# Delete surrounding quotes
                if len(value) > 2 and value[0] in ("'", '"') and value[-1] == value[0]:
value = value[1:-1]
conf[key] = value
# TODO: parse also other options (like --executor-memory 20G --num-executors 50 and many more)
# see: https://oozie.apache.org/docs/5.1.0/DG_SparkActionExtension.html#PySpark_with_Spark_Action
return conf

    def to_tasks_and_relations(self):
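        """Builds the Spark action task (rendered from spark.tpl) and, when a prepare
        node is present, prepends the generated prepare task."""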
action_task = Task(
task_id=self.name,
template_name="spark.tpl",
template_params=dict(
job_name=self.job_name,
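                # The spark_job keys follow the fields of Dataproc's SparkJob message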
spark_job=dict(
args=self.application_args,
jar_file_uris=self.dataproc_jars,
file_uris=self.hdfs_files,
archive_uris=self.hdfs_archives,
properties=self.spark_opts,
main_jar_file_uri=self.java_jar,
main_class=self.java_class
)
),
)
tasks = [action_task]
relations: List[Relation] = []
prepare_task = self.prepare_extension.get_prepare_task()
if prepare_task:
tasks, relations = self.prepend_task(prepare_task, tasks, relations)
return tasks, relations

    def required_imports(self) -> Set[str]:
        # The bash import is needed for the potential prepare task
return {
"from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitJobOperator",
"from airflow.operators import bash",
"from airflow.operators import empty",
}