o2a/converter/workflow.py (90 lines of code) (raw):
# -*- coding: utf-8 -*-
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Workflow"""
import os
from collections import OrderedDict
from typing import Set, Dict, Type, List
from o2a.converter.constants import HDFS_FOLDER, LIB_FOLDER
from o2a.converter.oozie_node import OozieNode
from o2a.converter.relation import Relation
from o2a.converter.task_group import TaskGroup
from o2a.utils.file_utils import get_lib_files
class Workflow:
"""Class for Workflow"""
def __init__(
self,
input_directory_path: str,
output_directory_path: str,
dag_name: str,
task_group_relations: Set[Relation] = None,
nodes: Dict[str, OozieNode] = None,
task_groups: Dict[str, TaskGroup] = None,
dependencies: Set[str] = None,
) -> None:
self.input_directory_path = input_directory_path
self.output_directory_path = output_directory_path
self.dag_name = dag_name
self.task_group_relations = task_group_relations or set()
# Dictionary is ordered purely for output being somewhat ordered the
# same as how Oozie workflow was parsed.
self.nodes = nodes or OrderedDict()
self.task_groups = task_groups or OrderedDict()
# These are the general dependencies required that every operator
# requires.
self.dependencies = dependencies or {
"import shlex",
"import datetime",
"from o2a_lib.property_utils import PropertySet",
"from o2a_lib import functions",
"from airflow import models",
"from airflow.utils.trigger_rule import TriggerRule",
"from airflow.utils import dates",
"from airflow.operators import bash, empty",
}
self.library_folder = os.path.join(self.input_directory_path, HDFS_FOLDER, LIB_FOLDER)
self.jar_files = get_lib_files(self.library_folder, extension=".jar")
def get_nodes_by_type(self, mapper_type: Type):
return [node for node in self.nodes.values() if isinstance(node.mapper, mapper_type)]
def find_upstream_nodes(self, target_node):
result = []
for node in self.nodes.values():
if target_node.name in node.downstream_names or target_node.name == node.error_downstream_name:
result.append(node)
return result
def find_upstream_task_group(self, target_task_group) -> List[TaskGroup]:
result = []
for task_group in self.task_groups.values():
if (
target_task_group.name in task_group.downstream_names
or target_task_group.name == task_group.error_downstream_name
):
result.append(task_group)
return result
def get_task_group_without_upstream(self) -> List[TaskGroup]:
task_groups = []
for task_group in self.task_groups.values():
upstream_task_group = self.find_upstream_task_group(task_group)
if not upstream_task_group:
task_groups.append(task_group)
return task_groups
def get_task_group_without_ok_downstream(self):
task_groups = []
for task_group in self.task_groups.values():
if not task_group.downstream_names:
task_groups.append(task_group)
return task_groups
def get_task_group_without_error_downstream(self):
task_groups = []
for task_group in self.task_groups.values():
if not task_group.error_downstream_name:
task_groups.append(task_group)
return task_groups
def remove_node(self, node_to_delete: OozieNode):
del self.nodes[node_to_delete.name]
for node in self.nodes.values():
if node_to_delete.name in node.downstream_names:
node.downstream_names.remove(node_to_delete.name)
if node.error_downstream_name == node_to_delete.name:
node.error_downstream_name = None
def __repr__(self) -> str:
return (
f'Workflow(dag_name="{self.dag_name}", input_directory_path="{self.input_directory_path}", '
f'output_directory_path="{self.output_directory_path}", relations={self.task_group_relations}, '
f"nodes={self.nodes.keys()}, dependencies={self.dependencies})"
)
def __eq__(self, other):
if isinstance(other, self.__class__):
return self.__dict__ == other.__dict__
return False