in metaflow/runtime.py [0:0]
def _queue_task_join(self, task, next_steps):
    """
    Queue the join step that follows `task` once all of its inputs have finished.

    `task` is expected to transition to exactly one step, the join. The join is
    pushed onto the run queue (via `self._queue_push`) only when every task it
    joins has an entry in `self._finished`; otherwise this call queues nothing.

    Three kinds of joins are handled:

    - unbounded-foreach joins (`task` has a non-None `_unbounded_foreach`
      artifact): the inputs are the mapper tasks recorded by the control task;
    - foreach joins (the last entry of the join's `split_parents` is a
      `foreach` node): the inputs are all split-siblings of `task`;
    - linear joins (any other split type): the inputs are the finished tasks of
      every step in the join's `in_funcs`, at the same foreach stack as `task`.

    Parameters
    ----------
    task : object
        The task that just finished. Its `step`, `finished_id` and `results`
        are read; for unbounded foreach also `ubf_context`, `path` and
        `is_cloned`.
    next_steps : list
        Names of the steps `task` transitions to.

    Raises
    ------
    MetaflowInternalError
        If `task` transitions to more than one step, or if an unbounded-foreach
        control task did not record `_control_mapper_tasks` as a non-empty list
        of strings (only the first element's type is checked).
    """
    # CHECK: this condition should be enforced by the linter but
    # let's assert that the assumption holds.
    if len(next_steps) > 1:
        msg = (
            "Step *{step}* transitions to a join and another "
            "step. The join must be the only transition."
        )
        # Pass only the formatted message, like the other raises in this
        # method (the task object used to be passed as the first argument).
        raise MetaflowInternalError(msg.format(step=task.step))
    next_step = next_steps[0]

    def _sibling_key(foreach_stack, split_index):
        # Key into self._finished for the sibling of `task` whose innermost
        # foreach frame carries `split_index` (None/0 for a UBF control task).
        top = foreach_stack[-1]
        bottom = list(foreach_stack[:-1])
        return (task.step, tuple(bottom + [top._replace(index=split_index)]))

    if not task.results.is_none("_unbounded_foreach"):
        _, foreach_stack = task.finished_id
        # Before we queue the join, post-process the runtime state
        # (_finished, _is_cloned, _control_num_splits) for the (sibling)
        # mapper tasks. Only the control task knows the list of mappers.
        if task.ubf_context == UBF_CONTROL:
            mapper_tasks = task.results.get("_control_mapper_tasks")
            if not mapper_tasks:
                msg = (
                    "Step *{step}* has a control task which didn't "
                    "specify the artifact *_control_mapper_tasks* for "
                    "the subsequent *{join}* step."
                )
                raise MetaflowInternalError(
                    msg.format(step=task.step, join=next_step)
                )
            if not (
                isinstance(mapper_tasks, list)
                and isinstance(mapper_tasks[0], unicode_type)
            ):
                try:
                    elem_typ = type(mapper_tasks[0])
                except (TypeError, LookupError):
                    # Not indexable by 0 (e.g. a number or a mapping): still
                    # raise the intended error below instead of this one.
                    elem_typ = None
                msg = (
                    "Step *{step}* has a control task which didn't "
                    "specify the artifact *_control_mapper_tasks* as a "
                    "list of strings but instead specified it as {typ} "
                    "with elements of {elem_typ}."
                )
                raise MetaflowInternalError(
                    msg.format(
                        step=task.step,
                        typ=type(mapper_tasks),
                        elem_typ=elem_typ,
                    )
                )
            num_splits = len(mapper_tasks)
            self._control_num_splits[task.path] = num_splits
            # If the control task is cloned, all mapper tasks should have been
            # cloned as well and _finished was already updated with them, so
            # only update _finished when not cloned. Note that the *value* of
            # mapper_tasks in the cloned case contains pathspecs of the
            # *cloned* run, but it is not used for anything.
            if not task.is_cloned:
                for split_index, mapper_path in enumerate(mapper_tasks):
                    self._finished[_sibling_key(foreach_stack, split_index)] = (
                        mapper_path
                    )
                    self._is_cloned[mapper_path] = False

        # Find the control task's pathspec: it is recorded with index=None,
        # or with index=0 when the control task is also the first mapper.
        if task.results.get("_control_task_is_mapper_zero", False):
            control_index = 0
        else:
            control_index = None
        control_path = self._finished.get(_sibling_key(foreach_stack, control_index))
        if not control_path:
            # Control task has not finished (successfully) yet.
            return
        # The control task was successful. Additionally require every
        # (sibling) mapper task to be finished (for the sake of resume)
        # before queueing the join task.
        num_splits = self._control_num_splits[control_path]
        required_tasks = [
            self._finished.get(_sibling_key(foreach_stack, split_index))
            for split_index in range(num_splits)
        ]
        if all(required_tasks):
            index = self._translate_index(task, next_step, "join")
            # All tasks to be joined are ready. Schedule the next join step.
            self._queue_push(
                next_step,
                {"input_paths": required_tasks, "join_type": "foreach"},
                index,
            )
        return

    # matching_split is the split-parent of the finished task
    matching_split = self._graph[self._graph[next_step].split_parents[-1]]
    _, foreach_stack = task.finished_id
    if matching_split.type == "foreach":
        # Foreach join: the required tasks are all split-siblings of `task`.
        required_tasks = [
            self._finished.get(_sibling_key(foreach_stack, split_index))
            for split_index in range(foreach_stack[-1].num_splits)
        ]
        join_type = "foreach"
        index = self._translate_index(task, next_step, "join")
    else:
        # Linear join of a non-foreach split: the required tasks are the
        # finished tasks of all branches joined by the next step.
        required_tasks = [
            self._finished.get((step, foreach_stack))
            for step in self._graph[next_step].in_funcs
        ]
        join_type = "linear"
        index = self._translate_index(task, next_step, "linear")
    if all(required_tasks):
        # All tasks to be joined are ready. Schedule the next join step.
        self._queue_push(
            next_step,
            {"input_paths": required_tasks, "join_type": join_type},
            index,
        )