in metaflow/plugins/argo/argo_workflows.py [0:0]
def _dag_templates(self):
def _visit(
node,
exit_node=None,
templates=None,
dag_tasks=None,
parent_foreach=None,
): # Returns Tuple[List[Template], List[DAGTask]]
"""Recursively visit `node` and its descendants, collecting a DAGTask
per node and a sub-DAGTemplate per foreach scope, stopping at
`exit_node` when one is given."""
# Every foreach node results in a separate subDAG and an equivalent
# DAGTemplate rooted at the child of the foreach node. Each DAGTemplate
# has a unique name - the top-level DAGTemplate is named after the
# flow and the subDAG DAGTemplates are named after the foreach node
# and its foreach parameter (see `foreach_template_name` below).
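# For example, for a flow
#   (start [foreach="x"]) --> (a) --> (join) --> (end)
# we emit one top-level DAGTemplate named after the flow plus a
# sub-DAGTemplate "start-foreach-x" rooted at "a" (names illustrative).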
if dag_tasks is None:
dag_tasks = []
if templates is None:
templates = []
# Emit if we have reached the end of the sub workflow
if exit_node is not None and exit_node == node.name:
return templates, dag_tasks
if node.name == "start":
# Start node has no dependencies.
dag_task = DAGTask(self._sanitize(node.name)).template(
self._sanitize(node.name)
)
elif (
node.is_inside_foreach
and self.graph[node.in_funcs[0]].type == "foreach"
and not self.graph[node.in_funcs[0]].parallel_foreach
# We need to distinguish a "regular" foreach (i.e. one that doesn't
# care about gang semantics) from a `num_parallel` based foreach
# (i.e. one that follows gang semantics). A "regular" foreach is
# basically any arbitrary kind of foreach.
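# e.g. `self.next(self.a, foreach="chunks")` declares a regular foreach,
# while `self.next(self.a, num_parallel=4)` declares a gang-scheduled
# @parallel foreach (values illustrative).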
):
# The child of a foreach node needs input-paths as well as a split-index.
# This child is the first node of the sub workflow and has no dependencies.
parameters = [
Parameter("input-paths").value("{{inputs.parameters.input-paths}}"),
Parameter("split-index").value("{{inputs.parameters.split-index}}"),
]
dag_task = (
DAGTask(self._sanitize(node.name))
.template(self._sanitize(node.name))
.arguments(Arguments().parameters(parameters))
)
elif node.parallel_step:
# This is the step where the @parallel decorator is defined.
# Since this DAGTask will use the `resource` based templates
# (https://argo-workflows.readthedocs.io/en/stable/walk-through/kubernetes-resources/),
# we have certain constraints on how we can pass information into the
# Jobset manifest.
# [All templates have access](https://argo-workflows.readthedocs.io/en/stable/variables/#all-templates)
# to `inputs.parameters`, so we pass down any and all information via
# the input parameters.
# We define the usual parameters like input-paths/split-index etc., but
# we also define the following:
# - `workerCount`: used to determine the number of parallel worker jobs.
# - `jobset-name`: used to determine the name of the Jobset. This
#   parameter needs to be dynamic so that retries don't reuse a previous
#   Jobset name (if they did, the run would crash since Kubernetes won't
#   allow duplicate job names).
# - `retryCount`: used to determine the number of retries. This
#   parameter is *only* available within container templates, as it is
#   for all other DAGTasks, and NOT within custom Kubernetes resource
#   templates. As a workaround, we set it as the `retryCount` input
#   parameter instead of setting it as {{ retries }} in the CLI code.
#   Once set as an input parameter, we can use it in the Jobset manifest
#   templates as `{{inputs.parameters.retryCount}}`.
# - `task-id-entropy`: helps derive task ids and Jobset names. This
#   parameter contains enough entropy to ensure that task ids and Jobset
#   names are sufficiently unique. We also use it in the join task to
#   construct the task ids of all parallel tasks, since the task ids for
#   parallel tasks are minted formulaically.
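# Illustratively (not the literal manifest), the Jobset templates can
# then reference these values as e.g. `{{inputs.parameters.jobset-name}}`,
# `{{inputs.parameters.workerCount}}` or `{{inputs.parameters.retryCount}}`.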
parameters = [
Parameter("input-paths").value("{{inputs.parameters.input-paths}}"),
Parameter("num-parallel").value(
"{{inputs.parameters.num-parallel}}"
),
Parameter("split-index").value("{{inputs.parameters.split-index}}"),
Parameter("task-id-entropy").value(
"{{inputs.parameters.task-id-entropy}}"
),
# We can't use hyphenated parameter names directly in sprig expressions:
# https://github.com/argoproj/argo-workflows/issues/10567#issuecomment-1452410948
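# e.g. num-parallel = 4 evaluates to workerCount = 3; presumably the
# control task accounts for the remaining slot.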
Parameter("workerCount").value(
"{{=sprig.int(sprig.sub(sprig.int(inputs.parameters['num-parallel']),1))}}"
),
]
if any(d.name == "retry" for d in node.decorators):
parameters.extend(
[
Parameter("retryCount").value("{{retries}}"),
# The jobset name needs to be unique for each retry and we cannot
# use the `generateName` field in the Jobset manifest since we need
# to construct the subdomain and control-pod domain name beforehand.
# So we use the retry count to ensure that the jobset name is unique.
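# e.g. an (illustrative) task-id-entropy of "ab12cd" at retry 1
# yields the jobset-name "js-ab12cd1".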
Parameter("jobset-name").value(
"js-{{inputs.parameters.task-id-entropy}}{{retries}}",
),
]
)
else:
parameters.extend(
[
Parameter("jobset-name").value(
"js-{{inputs.parameters.task-id-entropy}}",
)
]
)
dag_task = (
DAGTask(self._sanitize(node.name))
.template(self._sanitize(node.name))
.arguments(Arguments().parameters(parameters))
)
else:
# Every other node needs only input-paths
parameters = [
Parameter("input-paths").value(
compress_list(
[
"argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
% (n, self._sanitize(n))
for n in node.in_funcs
],
# NOTE: We set zlibmin to infinity because zlib compression of the Argo input-paths breaks template value substitution.
zlibmin=inf,
)
)
]
# NOTE: Due to limitations on Argo Workflows parameter sizes, we
# cannot pass arbitrarily large lists of task ids to join tasks.
# Instead we ensure that task ids for foreach tasks can be
# deduced deterministically and pass the relevant information to
# the join task.
#
# We need to add the split-index and root-input-path for the last
# step in any foreach scope and use these to generate the task id,
# as the join step uses the root and the cardinality of the
# foreach scope to generate the required ids.
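# For example, a join over a 100-wide foreach receives a single root
# input path plus the split cardinality instead of a 100-element list
# of task ids.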
if (
node.is_inside_foreach
and self.graph[node.out_funcs[0]].type == "join"
):
if any(
self.graph[parent].matching_join
== self.graph[node.out_funcs[0]].name
and self.graph[parent].type == "foreach"
for parent in self.graph[node.out_funcs[0]].split_parents
):
parameters.extend(
[
Parameter("split-index").value(
"{{inputs.parameters.split-index}}"
),
Parameter("root-input-path").value(
"{{inputs.parameters.input-paths}}"
),
]
)
dag_task = (
DAGTask(self._sanitize(node.name))
.dependencies(
[self._sanitize(in_func) for in_func in node.in_funcs]
)
.template(self._sanitize(node.name))
.arguments(Arguments().parameters(parameters))
)
dag_tasks.append(dag_task)
# End the workflow if we have reached the end of the flow
if node.type == "end":
return [
Template(self.flow.name).dag(
DAGTemplate().fail_fast().tasks(dag_tasks)
)
] + templates, dag_tasks
# For split nodes traverse all the children
if node.type == "split":
for n in node.out_funcs:
_visit(
self.graph[n],
node.matching_join,
templates,
dag_tasks,
parent_foreach,
)
return _visit(
self.graph[node.matching_join],
exit_node,
templates,
dag_tasks,
parent_foreach,
)
# For foreach nodes generate a new sub-DAGTemplate.
# We do this for "regular" foreaches (i.e. `self.next(self.a, foreach=...)`)
elif node.type == "foreach":
foreach_template_name = self._sanitize(
"%s-foreach-%s"
% (
node.name,
"parallel" if node.parallel_foreach else node.foreach_param,
# Since regular foreaches are declared via `self.next(self.a, foreach="<varname>")`
# while @parallel foreaches are declared via `self.next(self.a, num_parallel=<some-number>)`,
# we need to ensure that the `foreach_template_name` suffix is set
# appropriately based on the kind of foreach.
)
)
# There are two separate DAGTasks created for the foreach node:
# - The first is a "jump-off" DAGTask through which we propagate the
#   input-paths and split-index. It doesn't create any actual
#   containers and is responsible only for propagating the parameters.
# - The DAGTask that follows it is the one that uses the
#   ContainerTemplate. This DAGTask is named the same as the foreach
#   node. We leverage a similar pattern for @parallel tasks.
#
foreach_task = (
DAGTask(foreach_template_name)
.dependencies([self._sanitize(node.name)])
.template(foreach_template_name)
.arguments(
Arguments().parameters(
[
Parameter("input-paths").value(
"argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
% (node.name, self._sanitize(node.name))
),
Parameter("split-index").value("{{item}}"),
]
+ (
[
Parameter("root-input-path").value(
"argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
% (node.name, self._sanitize(node.name))
),
]
if parent_foreach
else []
)
+ (
# Disambiguate parameters for a regular `foreach` vs a `@parallel` foreach
[
Parameter("num-parallel").value(
"{{tasks.%s.outputs.parameters.num-parallel}}"
% self._sanitize(node.name)
),
Parameter("task-id-entropy").value(
"{{tasks.%s.outputs.parameters.task-id-entropy}}"
% self._sanitize(node.name)
),
]
if node.parallel_foreach
else []
)
)
)
.with_param(
# For @parallel workloads `num-splits` will be explicitly set to one so
# that we can piggyback on the mechanism we already use with Argo.
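# (i.e. the resolved list has a single element, so Argo materializes
# just one task and the JobSet performs the actual fan-out.)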
"{{tasks.%s.outputs.parameters.num-splits}}"
% self._sanitize(node.name)
)
)
dag_tasks.append(foreach_task)
templates, dag_tasks_1 = _visit(
self.graph[node.out_funcs[0]],
node.matching_join,
templates,
[],
node.name,
)
# How foreaches work on Argo:
# Let's say you have the following DAG: (start [sets `foreach="x"`]) --> (task-a [actual foreach]) --> (join) --> (end)
# With Argo we will have:
# (start [sets num-splits]) --> (task-a-foreach-(0,0) [dummy task]) --> (task-a) --> (join) --> (end)
# The dummy task (task-a-foreach-(0,0)) propagates the values of the
# `split-index` and the input paths to the actual foreach task.
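# Illustratively (step names hypothetical), the jump-off DAGTask renders
# to roughly the following in the workflow spec:
#   - name: start-foreach-x
#     template: start-foreach-x
#     dependencies: [start]
#     withParam: "{{tasks.start.outputs.parameters.num-splits}}"
#     arguments:
#       parameters:
#         - name: input-paths
#           value: "argo-{{workflow.name}}/start/{{tasks.start.outputs.parameters.task-id}}"
#         - name: split-index
#           value: "{{item}}"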
templates.append(
Template(foreach_template_name)
.inputs(
Inputs().parameters(
[Parameter("input-paths"), Parameter("split-index")]
+ ([Parameter("root-input-path")] if parent_foreach else [])
+ (
[
Parameter("num-parallel"),
Parameter("task-id-entropy"),
# Parameter("workerCount")
]
if node.parallel_foreach
else []
)
)
)
.outputs(
Outputs().parameters(
[
# Non-@parallel tasks set `task-id` as an output.
Parameter("task-id").valueFrom(
{
"parameter": "{{tasks.%s.outputs.parameters.task-id}}"
% self._sanitize(
self.graph[node.matching_join].in_funcs[0]
)
}
)
]
if not node.parallel_foreach
else [
# @parallel tasks set `task-id-entropy` and `num-parallel`
# as outputs so task-ids can be derived in the join step.
# Both of these values should be propagated from the
# jobset labels.
Parameter("num-parallel").valueFrom(
{
"parameter": "{{tasks.%s.outputs.parameters.num-parallel}}"
% self._sanitize(
self.graph[node.matching_join].in_funcs[0]
)
}
),
Parameter("task-id-entropy").valueFrom(
{
"parameter": "{{tasks.%s.outputs.parameters.task-id-entropy}}"
% self._sanitize(
self.graph[node.matching_join].in_funcs[0]
)
}
),
]
)
)
.dag(DAGTemplate().fail_fast().tasks(dag_tasks_1))
)
join_foreach_task = (
DAGTask(self._sanitize(self.graph[node.matching_join].name))
.template(self._sanitize(self.graph[node.matching_join].name))
.dependencies([foreach_template_name])
.arguments(
Arguments().parameters(
(
[
Parameter("input-paths").value(
"argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
% (node.name, self._sanitize(node.name))
),
Parameter("split-cardinality").value(
"{{tasks.%s.outputs.parameters.split-cardinality}}"
% self._sanitize(node.name)
),
]
if not node.parallel_foreach
else [
Parameter("num-parallel").value(
"{{tasks.%s.outputs.parameters.num-parallel}}"
% self._sanitize(node.name)
),
Parameter("task-id-entropy").value(
"{{tasks.%s.outputs.parameters.task-id-entropy}}"
% self._sanitize(node.name)
),
]
)
+ (
[
Parameter("split-index").value(
# TODO: Pass these parameters down to the Jobset as well.
"{{inputs.parameters.split-index}}"
),
Parameter("root-input-path").value(
"{{inputs.parameters.input-paths}}"
),
]
if parent_foreach
else []
)
)
)
)
dag_tasks.append(join_foreach_task)
return _visit(
self.graph[self.graph[node.matching_join].out_funcs[0]],
exit_node,
templates,
dag_tasks,
parent_foreach,
)
# For linear/join/start nodes continue traversing to the next node
if node.type in ("linear", "join", "start"):
return _visit(
self.graph[node.out_funcs[0]],
exit_node,
templates,
dag_tasks,
parent_foreach,
)
else:
raise ArgoWorkflowsException(
"Node type *%s* for step *%s* is not currently supported by "
"Argo Workflows." % (node.type, node.name)
)
# Generate daemon tasks
daemon_tasks = [
DAGTask("%s-task" % daemon_template.name).template(daemon_template.name)
for daemon_template in self._daemon_templates()
]
templates, _ = _visit(node=self.graph["start"], dag_tasks=daemon_tasks)
return templates