o2a/converter/task_group.py (88 lines of code) (raw):
# -*- coding: utf-8 -*-
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class for Airflow's task group"""
# noinspection PyPackageRequirements
from typing import List, Optional
from airflow.utils.trigger_rule import TriggerRule
from o2a.converter.exceptions import O2AException
from o2a.converter.task import Task
from o2a.converter.relation import Relation
class TaskGroup:
"""Airflow's tasks group
It is created as a result of converting the ParsedActionNode object. It contains all of its
information except the mapper.
Additionally, it contains an Airflow tasks, relations and python imports statements.
"""
def __init__(
self,
name,
tasks,
relations=None,
downstream_names=None,
error_downstream_name=None,
dependencies=None,
):
self.name = name
self.tasks: List[Task] = tasks or []
self.relations: List[Relation] = relations or []
self.downstream_names: List[str] = downstream_names or []
self.error_downstream_name: Optional[str] = error_downstream_name
self.dependencies: List[str] = dependencies or []
self.error_handler_task: Optional[Task] = None
self.ok_handler_task: Optional[Task] = None
@property
def first_task_id(self) -> str:
"""
Returns task_id of first task in group
"""
return self.tasks[0].task_id
@property
def last_task_id_of_ok_flow(self) -> str:
"""
Returns task_id of last task in group
"""
if self.ok_handler_task:
return self.ok_handler_task.task_id
return self.tasks[-1].task_id
@property
def last_task_id_of_error_flow(self) -> str:
"""
Returns task_id of last task in group
"""
if not self.error_handler_task:
raise O2AException(
"Unsupported state. The Error handler task ID was requested before it was created."
)
return self.error_handler_task.task_id
def add_state_handler_if_needed(self):
"""
Add additional tasks and relations to handle error and ok flow.
If the error path is specified, additional relations and task are added to handle
the error state.
If the error path and the ok path is specified, additional relations and task are added
to handle the ok path and the error path.
If the error path and the ok path is not-specified, no action is performed.
"""
if not self.error_downstream_name:
return
error_handler_task_id = self.name + "_error"
error_handler = Task(
task_id=error_handler_task_id, template_name="dummy.tpl", trigger_rule=TriggerRule.ONE_FAILED
)
self.error_handler_task = error_handler
new_relations = (
Relation(from_task_id=t.task_id, to_task_id=error_handler_task_id, is_error=True)
for t in self.tasks
)
self.relations.extend(new_relations)
if not self.downstream_names:
return
ok_handler_task_id = self.name + "_ok"
ok_handler = Task(
task_id=ok_handler_task_id, template_name="dummy.tpl", trigger_rule=TriggerRule.ONE_SUCCESS
)
self.ok_handler_task = ok_handler
self.relations.append(Relation(from_task_id=self.tasks[-1].task_id, to_task_id=ok_handler_task_id))
@property
def all_tasks(self):
all_tasks = [*self.tasks]
if self.error_handler_task:
all_tasks.append(self.error_handler_task)
if self.ok_handler_task:
all_tasks.append(self.ok_handler_task)
return all_tasks
def __repr__(self) -> str:
return (
f"TaskGroup(name={self.name}, "
f"downstream_names={self.downstream_names}, "
f"error_downstream_name={self.error_downstream_name}, "
f"tasks={self.tasks}, relations={self.relations})"
)
def __eq__(self, other):
if isinstance(other, self.__class__):
return self.__dict__ == other.__dict__
return False
class ActionTaskGroup(TaskGroup):
pass
class ControlTaskGroup(TaskGroup):
pass
class NotificationTaskGroup(TaskGroup):
pass
class StatusNotificationTaskGroup(NotificationTaskGroup):
pass
class TransitionNotificationTaskGroup(NotificationTaskGroup):
pass