def _queue_task_join()

in metaflow/runtime.py [0:0]


    def _queue_task_join(self, task, next_steps):
        """Queue the join step that follows *task* once all its inputs finished.

        The finished ids of all tasks that must feed the join are looked up in
        ``self._finished``. The join is pushed through ``self._queue_push``
        only when every one of those lookups yields a truthy path; otherwise
        this call queues nothing.

        Three situations are handled:

        * unbounded foreach (``_unbounded_foreach`` is set in the task
          results): the control task's ``_control_mapper_tasks`` artifact
          defines the mapper tasks to join;
        * regular foreach: all split-siblings of *task* must be finished;
        * static split: every step in the join's ``in_funcs`` must be
          finished for the same foreach stack.

        Parameters
        ----------
        task
            The task that just finished.
        next_steps
            The steps *task* transitions to. It is expected to hold a single
            entry, the join step.

        Raises
        ------
        MetaflowInternalError
            If *next_steps* holds more than one step, or if an unbounded
            foreach control task did not provide ``_control_mapper_tasks`` as
            a non-empty list whose first element is a string.
        """
        # if the next step is a join, we need to check that
        # all input tasks for the join have finished before queuing it.

        # CHECK: this condition should be enforced by the linter but
        # let's assert that the assumption holds
        if len(next_steps) > 1:
            msg = (
                "Step *{step}* transitions to a join and another "
                "step. The join must be the only transition."
            )
            # The message is passed as the sole argument, like the other
            # MetaflowInternalError raises in this method.
            raise MetaflowInternalError(msg.format(step=task.step))
        next_step = next_steps[0]

        unbounded_foreach = not task.results.is_none("_unbounded_foreach")

        if unbounded_foreach:
            # Before we queue the join, do some post-processing of runtime state
            # (_finished, _is_cloned) for the (sibling) mapper tasks.
            # Update state of (sibling) mapper tasks for control task.
            if task.ubf_context == UBF_CONTROL:
                mapper_tasks = task.results.get("_control_mapper_tasks")
                is_list = isinstance(mapper_tasks, list)
                if not mapper_tasks:
                    msg = (
                        "Step *{step}* has a control task which didn't "
                        "specify the artifact *_control_mapper_tasks* for "
                        "the subsequent *{join}* step."
                    )
                    raise MetaflowInternalError(
                        msg.format(step=task.step, join=next_step)
                    )
                elif not (is_list and isinstance(mapper_tasks[0], unicode_type)):
                    msg = (
                        "Step *{step}* has a control task which didn't "
                        "specify the artifact *_control_mapper_tasks* as a "
                        "list of strings but instead specified it as {typ} "
                        "with elements of {elem_typ}."
                    )
                    # Only index into the artifact when it really is a list:
                    # other truthy values (dict, int, ...) would otherwise
                    # raise an unrelated KeyError/TypeError here and hide
                    # the actual problem.
                    elem_typ = type(mapper_tasks[0]) if is_list else "unknown type"
                    raise MetaflowInternalError(
                        msg.format(
                            step=task.step,
                            typ=type(mapper_tasks),
                            elem_typ=elem_typ,
                        )
                    )
                num_splits = len(mapper_tasks)
                self._control_num_splits[task.path] = num_splits

                # If the control task is cloned, all mapper tasks should have been cloned
                # as well, so we no longer need to handle cloning of mapper tasks in runtime.

                # Update _finished if we are not cloned. If we were cloned, we already
                # updated _finished with the new tasks. Note that the *value* of mapper
                # tasks is incorrect and contains the pathspec of the *cloned* run
                # but we don't use it for anything. We could look to clean it up though
                if not task.is_cloned:
                    _, foreach_stack = task.finished_id
                    top = foreach_stack[-1]
                    bottom = list(foreach_stack[:-1])
                    for i, mapper_path in enumerate(mapper_tasks):
                        s = tuple(bottom + [top._replace(index=i)])
                        self._finished[(task.step, s)] = mapper_path
                        self._is_cloned[mapper_path] = False

            # Find and check status of control task and retrieve its pathspec
            # for retrieving unbounded foreach cardinality.
            _, foreach_stack = task.finished_id
            top = foreach_stack[-1]
            bottom = list(foreach_stack[:-1])

            # UBF control can also be the first task of the list. Then
            # it will have index=0 instead of index=None.
            if task.results.get("_control_task_is_mapper_zero", False):
                control_index = 0
            else:
                control_index = None
            s = tuple(bottom + [top._replace(index=control_index)])
            control_path = self._finished.get((task.step, s))
            if control_path:
                # Control task was successful.
                # Additionally check the state of (sibling) mapper tasks as well
                # (for the sake of resume) before queueing join task.
                num_splits = self._control_num_splits[control_path]
                required_tasks = [
                    self._finished.get(
                        (task.step, tuple(bottom + [top._replace(index=i)]))
                    )
                    for i in range(num_splits)
                ]

                if all(required_tasks):
                    index = self._translate_index(task, next_step, "join")
                    # all tasks to be joined are ready. Schedule the next join step.
                    self._queue_push(
                        next_step,
                        {"input_paths": required_tasks, "join_type": "foreach"},
                        index,
                    )
        else:
            # matching_split is the split-parent of the finished task
            matching_split = self._graph[self._graph[next_step].split_parents[-1]]
            _, foreach_stack = task.finished_id
            if matching_split.type == "foreach":
                # next step is a foreach join
                # required tasks are all split-siblings of the finished task,
                # i.e. the same foreach stack with every possible top index
                top = foreach_stack[-1]
                bottom = list(foreach_stack[:-1])
                required_tasks = [
                    self._finished.get(
                        (task.step, tuple(bottom + [top._replace(index=i)]))
                    )
                    for i in range(top.num_splits)
                ]
                join_type = "foreach"
                index = self._translate_index(task, next_step, "join")
            else:
                # next step is a split
                # required tasks are all branches joined by the next step
                required_tasks = [
                    self._finished.get((step, foreach_stack))
                    for step in self._graph[next_step].in_funcs
                ]
                join_type = "linear"
                index = self._translate_index(task, next_step, "linear")

            if all(required_tasks):
                # all tasks to be joined are ready. Schedule the next join step.
                self._queue_push(
                    next_step,
                    {"input_paths": required_tasks, "join_type": join_type},
                    index,
                )