def _dag_templates()

in metaflow/plugins/argo/argo_workflows.py [0:0]


    def _dag_templates(self):
        """
        Build the Argo Workflows templates that encode the flow's DAG.

        The flow graph is walked recursively starting from the ``start`` step.
        The top-level DAG template is named after the flow and is returned
        first. Every for-each (including ``@parallel`` for-eaches) produces an
        additional sub-DAG template, which follows it in the returned list.
        The tasks produced by ``self._daemon_templates()`` are placed at the
        head of the top-level DAG's task list.

        Returns
        -------
        List[Template]
            The top-level DAG template followed by all for-each sub-DAG
            templates.

        Raises
        ------
        ArgoWorkflowsException
            If a step has a node type that is not supported by Argo Workflows.
        """

        def _visit(
            node,
            exit_node=None,
            templates=None,
            dag_tasks=None,
            parent_foreach=None,
        ):  # Returns Tuple[List[Template], List[DAGTask]]
            """
            Emit the DAGTask for `node`, then recurse into its successors.

            Parameters
            ----------
            node
                A node of ``self.graph`` to emit.
            exit_node : str, optional
                Name of the step at which traversal of the current (sub)DAG
                stops. That step itself is not emitted here. Callers pass the
                matching join of a split or for-each.
            templates : list, optional
                Accumulator for generated for-each sub-DAG templates; appended
                to in place.
            dag_tasks : list, optional
                Accumulator for the DAGTasks of the DAG currently being built;
                appended to in place.
            parent_foreach : str, optional
                Name of the enclosing for-each step, or None at the top level.

            Returns
            -------
            Tuple[List[Template], List[DAGTask]]
                The templates and the DAG tasks accumulated so far. Once the
                ``end`` step is reached, the templates list also starts with
                the top-level DAG template named after the flow.
            """
            # Every for-each node results in a separate subDAG and an equivalent
            # DAGTemplate rooted at the child of the for-each node. Each DAGTemplate
            # has a unique name - the top-level DAGTemplate is named as the name of
            # the flow and the subDAG DAGTemplates are named after the (only) descendant
            # of the for-each node.

            if dag_tasks is None:
                dag_tasks = []
            if templates is None:
                templates = []
            # Return if we have reached the end of the sub workflow. Step names are
            # compared by value: `is` on strings depends on interning and is not
            # guaranteed to hold for two equal names.
            if exit_node is not None and exit_node == node.name:
                return templates, dag_tasks
            if node.name == "start":
                # Start node has no dependencies.
                dag_task = DAGTask(self._sanitize(node.name)).template(
                    self._sanitize(node.name)
                )
            elif (
                node.is_inside_foreach
                and self.graph[node.in_funcs[0]].type == "foreach"
                and not self.graph[node.in_funcs[0]].parallel_foreach
                # We need to distinguish what is a "regular" foreach (i.e something that doesn't care about gang semantics)
                # vs what is a "num_parallel" based foreach (i.e. something that follows gang semantics.)
                # A `regular` foreach is basically any arbitrary kind of foreach.
            ):
                # Child of a foreach node needs input-paths as well as split-index
                # This child is the first node of the sub workflow and has no dependency

                parameters = [
                    Parameter("input-paths").value("{{inputs.parameters.input-paths}}"),
                    Parameter("split-index").value("{{inputs.parameters.split-index}}"),
                ]
                dag_task = (
                    DAGTask(self._sanitize(node.name))
                    .template(self._sanitize(node.name))
                    .arguments(Arguments().parameters(parameters))
                )
            elif node.parallel_step:
                # This is the step where the @parallel decorator is defined.
                # Since this DAGTask will call the for the `resource` [based templates]
                # (https://argo-workflows.readthedocs.io/en/stable/walk-through/kubernetes-resources/)
                # we have certain constraints on the way we can pass information inside the Jobset manifest
                # [All templates will have access](https://argo-workflows.readthedocs.io/en/stable/variables/#all-templates)
                # to the `inputs.parameters` so we will pass down ANY/ALL information using the
                # input parameters.
                # We define the usual parameters like input-paths/split-index etc. but we will also
                # define the following:
                # - `workerCount`:  parameter which will be used to determine the number of
                #                   parallel worker jobs
                # - `jobset-name`:  parameter which will be used to determine the name of the jobset.
                #                   This parameter needs to be dynamic so that when we have retries we don't
                #                   end up using the name of the jobset again (if we do, it will crash since k8s won't allow duplicated job names)
                # - `retryCount`:   parameter which will be used to determine the number of retries
                #                   This parameter will *only* be available within the container templates like we
                #                   have it for all other DAGTasks and NOT for custom kubernetes resource templates.
                #                   So as a work-around, we will set it as the `retryCount` parameter instead of
                #                   setting it as a {{ retries }} in the CLI code. Once set as a input parameter,
                #                   we can use it in the Jobset Manifest templates as `{{inputs.parameters.retryCount}}`
                # - `task-id-entropy`: This is a parameter which will help derive task-ids and jobset names. This parameter
                #                   contains the relevant amount of entropy to ensure that task-ids and jobset names
                #                   are uniquish. We will also use this in the join task to construct the task-ids of
                #                   all parallel tasks since the task-ids for parallel task are minted formulaically.
                parameters = [
                    Parameter("input-paths").value("{{inputs.parameters.input-paths}}"),
                    Parameter("num-parallel").value(
                        "{{inputs.parameters.num-parallel}}"
                    ),
                    Parameter("split-index").value("{{inputs.parameters.split-index}}"),
                    Parameter("task-id-entropy").value(
                        "{{inputs.parameters.task-id-entropy}}"
                    ),
                    # we can't just use hyphens with sprig.
                    # https://github.com/argoproj/argo-workflows/issues/10567#issuecomment-1452410948
                    Parameter("workerCount").value(
                        "{{=sprig.int(sprig.sub(sprig.int(inputs.parameters['num-parallel']),1))}}"
                    ),
                ]
                if any(d.name == "retry" for d in node.decorators):
                    parameters.extend(
                        [
                            Parameter("retryCount").value("{{retries}}"),
                            # The jobset-name needs to be unique for each retry
                            # and we cannot use the `generateName` field in the
                            # Jobset Manifest since we need to construct the subdomain
                            # and control pod domain name pre-hand. So we will use
                            # the retry count to ensure that the jobset name is unique
                            Parameter("jobset-name").value(
                                "js-{{inputs.parameters.task-id-entropy}}{{retries}}",
                            ),
                        ]
                    )
                else:
                    parameters.extend(
                        [
                            Parameter("jobset-name").value(
                                "js-{{inputs.parameters.task-id-entropy}}",
                            )
                        ]
                    )

                dag_task = (
                    DAGTask(self._sanitize(node.name))
                    .template(self._sanitize(node.name))
                    .arguments(Arguments().parameters(parameters))
                )
            else:
                # Every other node needs only input-paths
                parameters = [
                    Parameter("input-paths").value(
                        compress_list(
                            [
                                "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
                                % (n, self._sanitize(n))
                                for n in node.in_funcs
                            ],
                            # NOTE: We set zlibmin to infinite because zlib compression for the Argo input-paths breaks template value substitution.
                            zlibmin=inf,
                        )
                    )
                ]
                # NOTE: Due to limitations with Argo Workflows Parameter size we
                #       can not pass arbitrarily large lists of task id's to join tasks.
                #       Instead we ensure that task id's for foreach tasks can be
                #       deduced deterministically and pass the relevant information to
                #       the join task.
                #
                #       We need to add the split-index and root-input-path for the last
                #       step in any foreach scope and use these to generate the task id,
                #       as the join step uses the root and the cardinality of the
                #       foreach scope to generate the required id's.
                if (
                    node.is_inside_foreach
                    and self.graph[node.out_funcs[0]].type == "join"
                ):
                    if any(
                        self.graph[parent].matching_join
                        == self.graph[node.out_funcs[0]].name
                        and self.graph[parent].type == "foreach"
                        for parent in self.graph[node.out_funcs[0]].split_parents
                    ):
                        parameters.extend(
                            [
                                Parameter("split-index").value(
                                    "{{inputs.parameters.split-index}}"
                                ),
                                Parameter("root-input-path").value(
                                    "{{inputs.parameters.input-paths}}"
                                ),
                            ]
                        )

                dag_task = (
                    DAGTask(self._sanitize(node.name))
                    .dependencies(
                        [self._sanitize(in_func) for in_func in node.in_funcs]
                    )
                    .template(self._sanitize(node.name))
                    .arguments(Arguments().parameters(parameters))
                )

            dag_tasks.append(dag_task)
            # End the workflow if we have reached the end of the flow
            if node.type == "end":
                return [
                    Template(self.flow.name).dag(
                        DAGTemplate().fail_fast().tasks(dag_tasks)
                    )
                ] + templates, dag_tasks
            # For split nodes traverse all the children
            if node.type == "split":
                for n in node.out_funcs:
                    _visit(
                        self.graph[n],
                        node.matching_join,
                        templates,
                        dag_tasks,
                        parent_foreach,
                    )
                return _visit(
                    self.graph[node.matching_join],
                    exit_node,
                    templates,
                    dag_tasks,
                    parent_foreach,
                )
            # For foreach nodes generate a new sub DAGTemplate
            # We do this for "regular" foreaches (ie. `self.next(self.a, foreach=)`)
            elif node.type == "foreach":
                foreach_template_name = self._sanitize(
                    "%s-foreach-%s"
                    % (
                        node.name,
                        "parallel" if node.parallel_foreach else node.foreach_param,
                        # Since foreach's are derived based on `self.next(self.a, foreach="<varname>")`
                        # vs @parallel foreach are done based on `self.next(self.a, num_parallel="<some-number>")`,
                        # we need to ensure that `foreach_template_name` suffix is appropriately set based on the kind
                        # of foreach.
                    )
                )

                # There are two separate "DAGTask"s created for the foreach node.
                # - The first one is a "jump-off" DAGTask where we propagate the
                # input-paths and split-index. This thing doesn't create
                # any actual containers and is responsible for only propagating
                # the parameters.
                # - The DAGTask that follows first DAGTask is the one
                # that uses the ContainerTemplate. This DAGTask is named the same
                # thing as the foreach node. We will leverage a similar pattern for the
                # @parallel tasks.
                #
                foreach_task = (
                    DAGTask(foreach_template_name)
                    .dependencies([self._sanitize(node.name)])
                    .template(foreach_template_name)
                    .arguments(
                        Arguments().parameters(
                            [
                                Parameter("input-paths").value(
                                    "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
                                    % (node.name, self._sanitize(node.name))
                                ),
                                Parameter("split-index").value("{{item}}"),
                            ]
                            + (
                                [
                                    Parameter("root-input-path").value(
                                        "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
                                        % (node.name, self._sanitize(node.name))
                                    ),
                                ]
                                if parent_foreach
                                else []
                            )
                            + (
                                # Disambiguate parameters for a regular `foreach` vs a `@parallel` foreach
                                [
                                    Parameter("num-parallel").value(
                                        "{{tasks.%s.outputs.parameters.num-parallel}}"
                                        % self._sanitize(node.name)
                                    ),
                                    Parameter("task-id-entropy").value(
                                        "{{tasks.%s.outputs.parameters.task-id-entropy}}"
                                        % self._sanitize(node.name)
                                    ),
                                ]
                                if node.parallel_foreach
                                else []
                            )
                        )
                    )
                    .with_param(
                        # For @parallel workloads `num-splits` will be explicitly set to one so that
                        # we can piggyback on the current mechanism with which we leverage argo.
                        "{{tasks.%s.outputs.parameters.num-splits}}"
                        % self._sanitize(node.name)
                    )
                )
                dag_tasks.append(foreach_task)
                templates, dag_tasks_1 = _visit(
                    self.graph[node.out_funcs[0]],
                    node.matching_join,
                    templates,
                    [],
                    node.name,
                )

                # How do foreach's work on Argo:
                # Lets say you have the following dag: (start[sets `foreach="x"`]) --> (task-a [actual foreach]) --> (join) --> (end)
                # With argo we will :
                # (start [sets num-splits]) --> (task-a-foreach-(0,0) [dummy task]) --> (task-a) --> (join) --> (end)
                # The (task-a-foreach-(0,0) [dummy task]) propagates the values of the `split-index` and the input paths.
                # to the actual foreach task.
                templates.append(
                    Template(foreach_template_name)
                    .inputs(
                        Inputs().parameters(
                            [Parameter("input-paths"), Parameter("split-index")]
                            + ([Parameter("root-input-path")] if parent_foreach else [])
                            + (
                                [
                                    Parameter("num-parallel"),
                                    Parameter("task-id-entropy"),
                                    # Parameter("workerCount")
                                ]
                                if node.parallel_foreach
                                else []
                            )
                        )
                    )
                    .outputs(
                        Outputs().parameters(
                            [
                                # non @parallel tasks set task-ids as outputs
                                Parameter("task-id").valueFrom(
                                    {
                                        "parameter": "{{tasks.%s.outputs.parameters.task-id}}"
                                        % self._sanitize(
                                            self.graph[node.matching_join].in_funcs[0]
                                        )
                                    }
                                )
                            ]
                            if not node.parallel_foreach
                            else [
                                # @parallel tasks set `task-id-entropy` and `num-parallel`
                                # as outputs so task-ids can be derived in the join step.
                                # Both of these values should be propagated from the
                                # jobset labels.
                                Parameter("num-parallel").valueFrom(
                                    {
                                        "parameter": "{{tasks.%s.outputs.parameters.num-parallel}}"
                                        % self._sanitize(
                                            self.graph[node.matching_join].in_funcs[0]
                                        )
                                    }
                                ),
                                Parameter("task-id-entropy").valueFrom(
                                    {
                                        "parameter": "{{tasks.%s.outputs.parameters.task-id-entropy}}"
                                        % self._sanitize(
                                            self.graph[node.matching_join].in_funcs[0]
                                        )
                                    }
                                ),
                            ]
                        )
                    )
                    .dag(DAGTemplate().fail_fast().tasks(dag_tasks_1))
                )

                join_foreach_task = (
                    DAGTask(self._sanitize(self.graph[node.matching_join].name))
                    .template(self._sanitize(self.graph[node.matching_join].name))
                    .dependencies([foreach_template_name])
                    .arguments(
                        Arguments().parameters(
                            (
                                [
                                    Parameter("input-paths").value(
                                        "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
                                        % (node.name, self._sanitize(node.name))
                                    ),
                                    Parameter("split-cardinality").value(
                                        "{{tasks.%s.outputs.parameters.split-cardinality}}"
                                        % self._sanitize(node.name)
                                    ),
                                ]
                                if not node.parallel_foreach
                                else [
                                    Parameter("num-parallel").value(
                                        "{{tasks.%s.outputs.parameters.num-parallel}}"
                                        % self._sanitize(node.name)
                                    ),
                                    Parameter("task-id-entropy").value(
                                        "{{tasks.%s.outputs.parameters.task-id-entropy}}"
                                        % self._sanitize(node.name)
                                    ),
                                ]
                            )
                            + (
                                [
                                    Parameter("split-index").value(
                                        # TODO : Pass down these parameters to the jobset stuff.
                                        "{{inputs.parameters.split-index}}"
                                    ),
                                    Parameter("root-input-path").value(
                                        "{{inputs.parameters.input-paths}}"
                                    ),
                                ]
                                if parent_foreach
                                else []
                            )
                        )
                    )
                )
                dag_tasks.append(join_foreach_task)
                return _visit(
                    self.graph[self.graph[node.matching_join].out_funcs[0]],
                    exit_node,
                    templates,
                    dag_tasks,
                    parent_foreach,
                )
            # For linear nodes continue traversing to the next node
            if node.type in ("linear", "join", "start"):
                return _visit(
                    self.graph[node.out_funcs[0]],
                    exit_node,
                    templates,
                    dag_tasks,
                    parent_foreach,
                )
            else:
                raise ArgoWorkflowsException(
                    "Node type *%s* for step *%s* is not currently supported by "
                    "Argo Workflows." % (node.type, node.name)
                )

        # Generate daemon tasks; they seed the top-level DAG's task list.
        daemon_tasks = [
            DAGTask("%s-task" % daemon_template.name).template(daemon_template.name)
            for daemon_template in self._daemon_templates()
        ]

        templates, _ = _visit(node=self.graph["start"], dag_tasks=daemon_tasks)
        return templates