def _container_templates()

in metaflow/plugins/argo/argo_workflows.py


    def _container_templates(self):
        try:
            # Kubernetes is a soft dependency for generating Argo objects.
            # We can very well remove this dependency for Argo with the downside of
            # adding a bunch more json bloat classes (looking at you... V1Container)
            from kubernetes import client as kubernetes_sdk
        except (NameError, ImportError):
            raise MetaflowException(
                "Could not import Python package 'kubernetes'. Install kubernetes "
                "sdk (https://pypi.org/project/kubernetes/) first."
            )
        for node in self.graph:
            # Resolve entry point for pod container.
            script_name = os.path.basename(sys.argv[0])
            executable = self.environment.executable(node.name)
            # TODO: Support R someday. Quite a few people will be happy.
            entrypoint = [executable, script_name]

            # The values with curly braces '{{}}' are made available by Argo
            # Workflows. Unfortunately, there are a few bugs in Argo which prevent
            # us from accessing these values as liberally as we would like to - e.g.,
            # within inline templates - so we are forced to generate container templates
            run_id = "argo-{{workflow.name}}"

            # Unfortunately, we don't have any easy access to unique ids that remain
            # stable across task attempts through Argo Workflows. So, we are forced to
            # stitch them together ourselves. The task ids are a function of step name,
            # split index and the parent task id (available from input path name).
            # Ideally, we would like these task ids to be the same as node name
            # (modulo retry suffix) on Argo Workflows but that doesn't seem feasible
            # right now.

            task_idx = ""
            input_paths = ""
            root_input = None
            # export input_paths as it is used multiple times in the container script
            # and we do not want to repeat the values.
            input_paths_expr = "export INPUT_PATHS=''"
            # Set input paths only if the node is neither a start step nor a
            # @parallel join. Start nodes have no input paths, and @parallel joins
            # derive their input paths formulaically from `num-parallel` and
            # `task-id-entropy`.
            if not (
                node.name == "start"
                or (node.type == "join" and self.graph[node.in_funcs[0]].parallel_step)
            ):
                # (@parallel joins don't receive INPUT_PATHS at all - their input
                # paths are constructed dynamically inside the join task.)
                input_paths_expr = (
                    "export INPUT_PATHS={{inputs.parameters.input-paths}}"
                )
                input_paths = "$(echo $INPUT_PATHS)"
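            # For illustration (not part of the generated script): for a typical
            # linear step, the upstream task passes something like
            # "argo-{{workflow.name}}/start/t-2d67e0a7" as `input-paths`, i.e. the
            # usual Metaflow "<run-id>/<step-name>/<task-id>" triplet. The exact
            # value is only known at run time, so here we only wire up the
            # indirection through the INPUT_PATHS environment variable.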
            if any(self.graph[n].type == "foreach" for n in node.in_funcs):
                task_idx = "{{inputs.parameters.split-index}}"
            if node.is_inside_foreach and self.graph[node.out_funcs[0]].type == "join":
                if any(
                    self.graph[parent].matching_join
                    == self.graph[node.out_funcs[0]].name
                    for parent in self.graph[node.out_funcs[0]].split_parents
                    if self.graph[parent].type == "foreach"
                ) and any(not self.graph[f].type == "foreach" for f in node.in_funcs):
                    # we need to propagate the split-index and root-input-path info for
                    # the last step inside a foreach for correctly joining nested
                    # foreaches
                    task_idx = "{{inputs.parameters.split-index}}"
                    root_input = "{{inputs.parameters.root-input-path}}"

            # Task string to be hashed into an ID
            task_str = "-".join(
                [
                    node.name,
                    "{{workflow.creationTimestamp}}",
                    root_input or input_paths,
                    task_idx,
                ]
            )
            if node.parallel_step:
                task_str = "-".join(
                    [
                        "$TASK_ID_PREFIX",
                        "{{inputs.parameters.task-id-entropy}}",
                        "$TASK_ID_SUFFIX",
                    ]
                )
            else:
                # Generated task_ids need to be non-numeric - see register_task_id in
                # service.py. We do so by prefixing `t-`
                _task_id_base = (
                    "$(echo %s | md5sum | cut -d ' ' -f 1 | tail -c 9)" % task_str
                )
                task_str = "(t-%s)" % _task_id_base

            task_id_expr = "export METAFLOW_TASK_ID=%s" % task_str
            task_id = "$METAFLOW_TASK_ID"
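            # For illustration only - for the `start` step (empty input paths, no
            # split index), the exported expression above resolves to roughly:
            #
            #   export METAFLOW_TASK_ID=(t-$(echo start-{{workflow.creationTimestamp}}-- \
            #       | md5sum | cut -d ' ' -f 1 | tail -c 9))
            #
            # i.e. a stable, non-numeric id derived from the step name and the
            # workflow's creation timestamp.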

            # Resolve retry strategy.
            max_user_code_retries = 0
            max_error_retries = 0
            minutes_between_retries = "2"
            for decorator in node.decorators:
                if decorator.name == "retry":
                    minutes_between_retries = decorator.attributes.get(
                        "minutes_between_retries", minutes_between_retries
                    )
                user_code_retries, error_retries = decorator.step_task_retry_count()
                max_user_code_retries = max(max_user_code_retries, user_code_retries)
                max_error_retries = max(max_error_retries, error_retries)

            user_code_retries = max_user_code_retries
            total_retries = max_user_code_retries + max_error_retries
            # {{retries}} is only available if retryStrategy is specified
            # For custom kubernetes manifests, we will pass the retryCount as a parameter
            # and use that in the manifest.
            retry_count = (
                (
                    "{{retries}}"
                    if not node.parallel_step
                    else "{{inputs.parameters.retryCount}}"
                )
                if total_retries
                else 0
            )

            minutes_between_retries = int(minutes_between_retries)
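            # A hedged example of how these values line up, assuming a single
            # @retry(times=3, minutes_between_retries="10") decorator that reports
            # its `times` as user-code retries: user_code_retries=3, total_retries=3,
            # retry_count="{{retries}}" ("{{inputs.parameters.retryCount}}" for
            # @parallel steps), and minutes_between_retries=10.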

            # Configure log capture.
            mflog_expr = export_mflog_env_vars(
                datastore_type=self.flow_datastore.TYPE,
                stdout_path="$PWD/.logs/mflog_stdout",
                stderr_path="$PWD/.logs/mflog_stderr",
                flow_name=self.flow.name,
                run_id=run_id,
                step_name=node.name,
                task_id=task_id,
                retry_count=retry_count,
            )

            init_cmds = " && ".join(
                [
                    # For supporting sandboxes, ensure that a custom script is executed
                    # before anything else is executed. The script is passed in as an
                    # env var.
                    '${METAFLOW_INIT_SCRIPT:+eval \\"${METAFLOW_INIT_SCRIPT}\\"}',
                    "mkdir -p $PWD/.logs",
                    input_paths_expr,
                    task_id_expr,
                    mflog_expr,
                ]
                + self.environment.get_package_commands(
                    self.code_package_url, self.flow_datastore.TYPE
                )
            )
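            # Shape of the assembled init command (illustrative; the code package
            # commands come from the environment and vary by datastore):
            #
            #   ${METAFLOW_INIT_SCRIPT:+eval \"${METAFLOW_INIT_SCRIPT}\"} \
            #     && mkdir -p $PWD/.logs \
            #     && export INPUT_PATHS=... \
            #     && export METAFLOW_TASK_ID=... \
            #     && <mflog env var exports> \
            #     && <code package download and extraction commands>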
            step_cmds = self.environment.bootstrap_commands(
                node.name, self.flow_datastore.TYPE
            )

            top_opts_dict = {
                "with": [
                    decorator.make_decorator_spec()
                    for decorator in node.decorators
                    if not decorator.statically_defined
                ]
            }
            # FlowDecorators can define their own top-level options. They are
            # responsible for adding their own top-level options and values through
            # the get_top_level_options() hook. See similar logic in runtime.py.
            for deco in flow_decorators(self.flow):
                top_opts_dict.update(deco.get_top_level_options())

            top_level = list(dict_to_cli_options(top_opts_dict)) + [
                "--quiet",
                "--metadata=%s" % self.metadata.TYPE,
                "--environment=%s" % self.environment.TYPE,
                "--datastore=%s" % self.flow_datastore.TYPE,
                "--datastore-root=%s" % self.flow_datastore.datastore_root,
                "--event-logger=%s" % self.event_logger.TYPE,
                "--monitor=%s" % self.monitor.TYPE,
                "--no-pylint",
                "--with=argo_workflows_internal:auto-emit-argo-events=%i"
                % self.auto_emit_argo_events,
            ]
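            # Illustrative resolution with assumed values (service metadata, local
            # environment, S3 datastore):
            #
            #   --quiet --metadata=service --environment=local --datastore=s3
            #   --datastore-root=s3://<bucket>/metaflow --event-logger=<logger-type>
            #   --monitor=<monitor-type> --no-pylint
            #   --with=argo_workflows_internal:auto-emit-argo-events=1
            #
            # preceded by one `--with=<decorator-spec>` per non-static decorator,
            # courtesy of dict_to_cli_options.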

            if node.name == "start":
                # Execute `init` before any step of the workflow executes
                task_id_params = "%s-params" % task_id
                init = (
                    entrypoint
                    + top_level
                    + [
                        "init",
                        "--run-id %s" % run_id,
                        "--task-id %s" % task_id_params,
                    ]
                    + [
                        # Parameter names can be hyphenated, hence we use
                        # {{foo.bar['param_name']}}.
                        # https://argoproj.github.io/argo-events/tutorials/02-parameterization/
                        # http://masterminds.github.io/sprig/strings.html
                        "--%s={{workflow.parameters.%s}}"
                        % (parameter["name"], parameter["name"])
                        for parameter in self.parameters.values()
                    ]
                )
                if self.tags:
                    init.extend("--tag %s" % tag for tag in self.tags)
                # if the start step gets retried, we must be careful
                # not to regenerate multiple parameters tasks. Hence,
                # we check first if _parameters exists already.
                exists = entrypoint + [
                    "dump",
                    "--max-value-size=0",
                    "%s/_parameters/%s" % (run_id, task_id_params),
                ]
                step_cmds.extend(
                    [
                        "if ! %s >/dev/null 2>/dev/null; then %s; fi"
                        % (" ".join(exists), " ".join(init))
                    ]
                )
                input_paths = "%s/_parameters/%s" % (run_id, task_id_params)
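                # Putting it together (illustrative; "flow.py" is an assumed script
                # name), the guard appended above looks like:
                #
                #   if ! python flow.py dump --max-value-size=0 \
                #       argo-{{workflow.name}}/_parameters/$METAFLOW_TASK_ID-params \
                #       >/dev/null 2>/dev/null; then \
                #     python flow.py <top-level opts> init --run-id argo-{{workflow.name}} \
                #       --task-id $METAFLOW_TASK_ID-params \
                #       --<param>={{workflow.parameters.<param>}} ...; fi
                #
                # so a retried start step reuses the existing _parameters task
                # instead of creating another one.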
            elif (
                node.type == "join"
                and self.graph[node.split_parents[-1]].type == "foreach"
            ):
                # Set aggregated input-paths for a for-each join
                foreach_step = next(
                    n for n in node.in_funcs if self.graph[n].is_inside_foreach
                )
                if not self.graph[node.split_parents[-1]].parallel_foreach:
                    input_paths = (
                        "$(python -m metaflow.plugins.argo.generate_input_paths %s {{workflow.creationTimestamp}} %s {{inputs.parameters.split-cardinality}})"
                        % (
                            foreach_step,
                            input_paths,
                        )
                    )
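                    # Illustrative resolution - for a foreach body step assumed to
                    # be named `process`, the expression above becomes:
                    #
                    #   $(python -m metaflow.plugins.argo.generate_input_paths process \
                    #       {{workflow.creationTimestamp}} <parent input-paths> \
                    #       {{inputs.parameters.split-cardinality}})
                    #
                    # which re-derives the hashed task ids of all splits at run time
                    # instead of shipping a long list of task ids through Argo.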
                else:
                    # Handle @parallel where output from volume mount isn't accessible
                    input_paths = (
                        "$(python -m metaflow.plugins.argo.jobset_input_paths %s %s {{inputs.parameters.task-id-entropy}} {{inputs.parameters.num-parallel}})"
                        % (
                            run_id,
                            foreach_step,
                        )
                    )
            step = [
                "step",
                node.name,
                "--run-id %s" % run_id,
                "--task-id %s" % task_id,
                "--retry-count %s" % retry_count,
                "--max-user-code-retries %d" % user_code_retries,
                "--input-paths %s" % input_paths,
            ]
            if node.parallel_step:
                step.append(
                    "--split-index ${MF_CONTROL_INDEX:-$((MF_WORKER_REPLICA_INDEX + 1))}"
                )
                # This is needed for setting the value of the UBF context in the CLI.
                step.append("--ubf-context $UBF_CONTEXT")

            elif any(self.graph[n].type == "foreach" for n in node.in_funcs):
                # Pass split-index to a foreach task
                step.append("--split-index {{inputs.parameters.split-index}}")
            if self.tags:
                step.extend("--tag %s" % tag for tag in self.tags)
            if self.namespace is not None:
                step.append("--namespace=%s" % self.namespace)

            step_cmds.extend([" ".join(entrypoint + top_level + step)])

            cmd_str = "%s; c=$?; %s; exit $c" % (
                " && ".join([init_cmds, bash_capture_logs(" && ".join(step_cmds))]),
                BASH_SAVE_LOGS,
            )
            cmds = shlex.split('bash -c "%s"' % cmd_str)
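            # The final container command therefore has the shape (illustrative):
            #
            #   bash -c "<init_cmds> && <log-captured step_cmds>; c=$?; <save logs>; exit $c"
            #
            # which persists logs even on failure while preserving the task's exit
            # code.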

            # Resolve resource requirements.
            resources = dict(
                [deco for deco in node.decorators if deco.name == "kubernetes"][
                    0
                ].attributes
            )

            if (
                resources["namespace"]
                and resources["namespace"] != KUBERNETES_NAMESPACE
            ):
                raise ArgoWorkflowsException(
                    "Multi-namespace Kubernetes execution of flows in Argo Workflows "
                    "is not currently supported. \nStep *%s* is trying to override "
                    "the default Kubernetes namespace *%s*."
                    % (node.name, KUBERNETES_NAMESPACE)
                )

            run_time_limit = [
                deco for deco in node.decorators if deco.name == "kubernetes"
            ][0].run_time_limit

            # Resolve @environment decorator. We set three classes of environment
            # variables -
            #   (1) User-specified environment variables through @environment
            #   (2) Metaflow runtime specific environment variables
            #   (3) @kubernetes, @argo_workflows_internal bookkeeping environment
            #       variables
            env = dict(
                [deco for deco in node.decorators if deco.name == "environment"][
                    0
                ].attributes["vars"]
            )

            # Temporary passing of *some* environment variables. Do not rely on this
            # mechanism as it will be removed in the near future
            env.update(
                {
                    k: v
                    for k, v in config_values()
                    if k.startswith("METAFLOW_CONDA_")
                    or k.startswith("METAFLOW_DEBUG_")
                }
            )
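            # For example (hedged sketch with assumed config values): if the local
            # config contains METAFLOW_DEBUG_SUBCOMMAND=1 and METAFLOW_PROFILE=dev,
            # only the former is forwarded here since it matches the
            # METAFLOW_DEBUG_ prefix.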

            env.update(
                {
                    **{
                        # These values are needed by Metaflow to set its internal
                        # state appropriately.
                        "METAFLOW_CODE_URL": self.code_package_url,
                        "METAFLOW_CODE_SHA": self.code_package_sha,
                        "METAFLOW_CODE_DS": self.flow_datastore.TYPE,
                        "METAFLOW_SERVICE_URL": SERVICE_INTERNAL_URL,
                        "METAFLOW_SERVICE_HEADERS": json.dumps(SERVICE_HEADERS),
                        "METAFLOW_USER": "argo-workflows",
                        "METAFLOW_DATASTORE_SYSROOT_S3": DATASTORE_SYSROOT_S3,
                        "METAFLOW_DATATOOLS_S3ROOT": DATATOOLS_S3ROOT,
                        "METAFLOW_DEFAULT_DATASTORE": self.flow_datastore.TYPE,
                        "METAFLOW_DEFAULT_METADATA": DEFAULT_METADATA,
                        "METAFLOW_CARD_S3ROOT": CARD_S3ROOT,
                        "METAFLOW_KUBERNETES_WORKLOAD": 1,
                        "METAFLOW_KUBERNETES_FETCH_EC2_METADATA": KUBERNETES_FETCH_EC2_METADATA,
                        "METAFLOW_RUNTIME_ENVIRONMENT": "kubernetes",
                        "METAFLOW_OWNER": self.username,
                    },
                    **{
                        # Configuration for Argo Events. Keep these in sync with the
                        # environment variables for @kubernetes decorator.
                        "METAFLOW_ARGO_EVENTS_EVENT": ARGO_EVENTS_EVENT,
                        "METAFLOW_ARGO_EVENTS_EVENT_BUS": ARGO_EVENTS_EVENT_BUS,
                        "METAFLOW_ARGO_EVENTS_EVENT_SOURCE": ARGO_EVENTS_EVENT_SOURCE,
                        "METAFLOW_ARGO_EVENTS_SERVICE_ACCOUNT": ARGO_EVENTS_SERVICE_ACCOUNT,
                        "METAFLOW_ARGO_EVENTS_WEBHOOK_URL": ARGO_EVENTS_INTERNAL_WEBHOOK_URL,
                        "METAFLOW_ARGO_EVENTS_WEBHOOK_AUTH": ARGO_EVENTS_WEBHOOK_AUTH,
                    },
                    **{
                        # Some optional values for bookkeeping
                        "METAFLOW_FLOW_FILENAME": os.path.basename(sys.argv[0]),
                        "METAFLOW_FLOW_NAME": self.flow.name,
                        "METAFLOW_STEP_NAME": node.name,
                        "METAFLOW_RUN_ID": run_id,
                        # "METAFLOW_TASK_ID": task_id,
                        "METAFLOW_RETRY_COUNT": retry_count,
                        "METAFLOW_PRODUCTION_TOKEN": self.production_token,
                        "ARGO_WORKFLOW_TEMPLATE": self.name,
                        "ARGO_WORKFLOW_NAME": "{{workflow.name}}",
                        "ARGO_WORKFLOW_NAMESPACE": KUBERNETES_NAMESPACE,
                    },
                    **self.metadata.get_runtime_environment("argo-workflows"),
                }
            )
            # add METAFLOW_S3_ENDPOINT_URL
            env["METAFLOW_S3_ENDPOINT_URL"] = S3_ENDPOINT_URL

            # support Metaflow sandboxes
            env["METAFLOW_INIT_SCRIPT"] = KUBERNETES_SANDBOX_INIT_SCRIPT
            env["METAFLOW_KUBERNETES_SANDBOX_INIT_SCRIPT"] = (
                KUBERNETES_SANDBOX_INIT_SCRIPT
            )

            # support for @secret
            env["METAFLOW_DEFAULT_SECRETS_BACKEND_TYPE"] = DEFAULT_SECRETS_BACKEND_TYPE
            env["METAFLOW_AWS_SECRETS_MANAGER_DEFAULT_REGION"] = (
                AWS_SECRETS_MANAGER_DEFAULT_REGION
            )
            env["METAFLOW_GCP_SECRET_MANAGER_PREFIX"] = GCP_SECRET_MANAGER_PREFIX
            env["METAFLOW_AZURE_KEY_VAULT_PREFIX"] = AZURE_KEY_VAULT_PREFIX

            # support for Azure
            env["METAFLOW_AZURE_STORAGE_BLOB_SERVICE_ENDPOINT"] = (
                AZURE_STORAGE_BLOB_SERVICE_ENDPOINT
            )
            env["METAFLOW_DATASTORE_SYSROOT_AZURE"] = DATASTORE_SYSROOT_AZURE
            env["METAFLOW_CARD_AZUREROOT"] = CARD_AZUREROOT
            env["METAFLOW_ARGO_WORKFLOWS_KUBERNETES_SECRETS"] = (
                ARGO_WORKFLOWS_KUBERNETES_SECRETS
            )
            env["METAFLOW_ARGO_WORKFLOWS_ENV_VARS_TO_SKIP"] = (
                ARGO_WORKFLOWS_ENV_VARS_TO_SKIP
            )

            # support for GCP
            env["METAFLOW_DATASTORE_SYSROOT_GS"] = DATASTORE_SYSROOT_GS
            env["METAFLOW_CARD_GSROOT"] = CARD_GSROOT

            # Map Argo Events payload (if any) to environment variables
            if self.triggers:
                for event in self.triggers:
                    env[
                        "METAFLOW_ARGO_EVENT_PAYLOAD_%s_%s"
                        % (event["type"], event["sanitized_name"])
                    ] = ("{{workflow.parameters.%s}}" % event["sanitized_name"])

            # Map S3 upload headers to environment variables
            if S3_SERVER_SIDE_ENCRYPTION is not None:
                env["METAFLOW_S3_SERVER_SIDE_ENCRYPTION"] = S3_SERVER_SIDE_ENCRYPTION

            metaflow_version = self.environment.get_environment_info()
            metaflow_version["flow_name"] = self.graph.name
            metaflow_version["production_token"] = self.production_token
            env["METAFLOW_VERSION"] = json.dumps(metaflow_version)

            # map config values
            cfg_env = {
                param["name"]: param["kv_name"] for param in self.config_parameters
            }
            if cfg_env:
                env["METAFLOW_FLOW_CONFIG_VALUE"] = json.dumps(cfg_env)

            # Set the template inputs and outputs for passing state. Very simply,
            # the container template takes in input-paths as input and outputs
            # the task-id (which feeds in as input-paths to the subsequent task).
            # In addition to that, if the parent of the node under consideration
            # is a for-each node, then we take the split-index as an additional
            # input. Analogously, if the node under consideration is a foreach
            # node, then we emit split cardinality as an extra output. I would like
            # to thank the designers of Argo Workflows for making this so
            # straightforward! Things become a bit more complicated to support very
            # wide foreaches where we have to resort to passing a root-input-path
            # so that we can compute the task ids for each parent task of a for-each
            # join task deterministically inside the join task without resorting to
            # passing a rather long list of (albeit compressed) task ids.
            inputs = []
            # To set the input-paths as a parameter, we need to ensure that the node
            # is not (a start node or a parallel join node). Start nodes will have no
            # input paths and parallel join will derive input paths based on a
            # formulaic approach.
            if not (
                node.name == "start"
                or (node.type == "join" and self.graph[node.in_funcs[0]].parallel_step)
            ):
                inputs.append(Parameter("input-paths"))
            if any(self.graph[n].type == "foreach" for n in node.in_funcs):
                # Fetch split-index from parent
                inputs.append(Parameter("split-index"))

            if (
                node.type == "join"
                and self.graph[node.split_parents[-1]].type == "foreach"
            ):
                # @parallel join tasks require `num-parallel` and `task-id-entropy`
                # to construct the input paths, so we pass them down as input parameters.
                if self.graph[node.split_parents[-1]].parallel_foreach:
                    inputs.extend(
                        [Parameter("num-parallel"), Parameter("task-id-entropy")]
                    )
                else:
                    # append this only for joins of foreaches, not static splits
                    inputs.append(Parameter("split-cardinality"))
            # check if the node is a @parallel node.
            elif node.parallel_step:
                inputs.extend(
                    [
                        Parameter("num-parallel"),
                        Parameter("task-id-entropy"),
                        Parameter("jobset-name"),
                        Parameter("workerCount"),
                    ]
                )
                if any(d.name == "retry" for d in node.decorators):
                    inputs.append(Parameter("retryCount"))

            if node.is_inside_foreach and self.graph[node.out_funcs[0]].type == "join":
                if any(
                    self.graph[parent].matching_join
                    == self.graph[node.out_funcs[0]].name
                    for parent in self.graph[node.out_funcs[0]].split_parents
                    if self.graph[parent].type == "foreach"
                ) and any(not self.graph[f].type == "foreach" for f in node.in_funcs):
                    # we need to propagate the split-index and root-input-path info for
                    # the last step inside a foreach for correctly joining nested
                    # foreaches
                    if not any(self.graph[n].type == "foreach" for n in node.in_funcs):
                        # Don't add duplicate split index parameters.
                        inputs.append(Parameter("split-index"))
                    inputs.append(Parameter("root-input-path"))

            outputs = []
            # @parallel steps will not have a task-id as an output parameter since task-ids
            # are derived at runtime.
            if not (node.name == "end" or node.parallel_step):
                outputs = [Parameter("task-id").valueFrom({"path": "/mnt/out/task_id"})]
            if node.type == "foreach":
                # Emit split cardinality from foreach task
                outputs.append(
                    Parameter("num-splits").valueFrom({"path": "/mnt/out/splits"})
                )
                outputs.append(
                    Parameter("split-cardinality").valueFrom(
                        {"path": "/mnt/out/split_cardinality"}
                    )
                )

            if node.parallel_foreach:
                outputs.extend(
                    [
                        Parameter("num-parallel").valueFrom(
                            {"path": "/mnt/out/num_parallel"}
                        ),
                        Parameter("task-id-entropy").valueFrom(
                            {"path": "/mnt/out/task_id_entropy"}
                        ),
                    ]
                )
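            # A hedged, concrete pairing for a plain (non-@parallel) foreach and its
            # join, as wired up above: the foreach step takes `input-paths` in and
            # emits `task-id`, `num-splits` and `split-cardinality`; each iteration
            # step additionally takes `split-index` in; and the join takes
            # `input-paths` plus `split-cardinality` in and emits `task-id`.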
            # Outputs should be defined over here and not in the _dag_template for @parallel.

            # It makes no sense to set env vars to None (shows up as "None" string)
            # Also we skip some env vars (e.g. in case we want to pull them from KUBERNETES_SECRETS)
            env = {
                k: v
                for k, v in env.items()
                if v is not None
                and k not in set(ARGO_WORKFLOWS_ENV_VARS_TO_SKIP.split(","))
            }
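            # For instance (toy values): with ARGO_WORKFLOWS_ENV_VARS_TO_SKIP set to
            # "METAFLOW_SERVICE_HEADERS", an env of
            # {"METAFLOW_SERVICE_HEADERS": "...", "METAFLOW_CARD_GSROOT": None,
            #  "METAFLOW_OWNER": "user"} is reduced to {"METAFLOW_OWNER": "user"}.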

            # Tmpfs variables
            use_tmpfs = resources["use_tmpfs"]
            tmpfs_size = resources["tmpfs_size"]
            tmpfs_path = resources["tmpfs_path"]
            tmpfs_tempdir = resources["tmpfs_tempdir"]
            # Set shared_memory to 0 if it isn't specified. This results
            # in Kubernetes using its default value when the pod is created.
            shared_memory = resources.get("shared_memory", 0)
            port = resources.get("port", None)
            if port:
                port = int(port)

            tmpfs_enabled = use_tmpfs or (tmpfs_size and not use_tmpfs)

            if tmpfs_enabled and tmpfs_tempdir:
                env["METAFLOW_TEMPDIR"] = tmpfs_path
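            # In other words: tmpfs is enabled when either use_tmpfs is set or a
            # tmpfs_size is given, and METAFLOW_TEMPDIR is pointed at the tmpfs
            # mount only when tmpfs_tempdir is also requested.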

            qos_requests, qos_limits = qos_requests_and_limits(
                resources["qos"],
                resources["cpu"],
                resources["memory"],
                resources["disk"],
            )

            security_context = resources.get("security_context", None)
            _security_context = {}
            if security_context is not None and len(security_context) > 0:
                _security_context = {
                    "security_context": kubernetes_sdk.V1SecurityContext(
                        **security_context
                    )
                }

            # Create a ContainerTemplate for this node. Ideally, we would have
            # liked to inline this ContainerTemplate and avoid scanning the workflow
            # twice, but due to issues with variable substitution, we will have to
            # live with this routine.
            if node.parallel_step:
                jobset_name = "{{inputs.parameters.jobset-name}}"
                jobset = KubernetesArgoJobSet(
                    kubernetes_sdk=kubernetes_sdk,
                    name=jobset_name,
                    flow_name=self.flow.name,
                    run_id=run_id,
                    step_name=self._sanitize(node.name),
                    task_id=task_id,
                    attempt=retry_count,
                    user=self.username,
                    subdomain=jobset_name,
                    command=cmds,
                    namespace=resources["namespace"],
                    image=resources["image"],
                    image_pull_policy=resources["image_pull_policy"],
                    service_account=resources["service_account"],
                    secrets=(
                        [
                            k
                            for k in (
                                list(
                                    []
                                    if not resources.get("secrets")
                                    else (
                                        [resources.get("secrets")]
                                        if isinstance(resources.get("secrets"), str)
                                        else resources.get("secrets")
                                    )
                                )
                                + KUBERNETES_SECRETS.split(",")
                                + ARGO_WORKFLOWS_KUBERNETES_SECRETS.split(",")
                            )
                            if k
                        ]
                    ),
                    node_selector=resources.get("node_selector"),
                    cpu=str(resources["cpu"]),
                    memory=str(resources["memory"]),
                    disk=str(resources["disk"]),
                    gpu=resources["gpu"],
                    gpu_vendor=str(resources["gpu_vendor"]),
                    tolerations=resources["tolerations"],
                    use_tmpfs=use_tmpfs,
                    tmpfs_tempdir=tmpfs_tempdir,
                    tmpfs_size=tmpfs_size,
                    tmpfs_path=tmpfs_path,
                    timeout_in_seconds=run_time_limit,
                    persistent_volume_claims=resources["persistent_volume_claims"],
                    shared_memory=shared_memory,
                    port=port,
                    qos=resources["qos"],
                    security_context=security_context,
                )

                for k, v in env.items():
                    jobset.environment_variable(k, v)

                # Set labels. Do not allow user-specified task labels to override internal ones.
                #
                # Explicitly add the task_id_entropy label. This is important because
                # this label is returned as an output parameter of this step and is
                # subsequently used as an input in the join step.
                kubernetes_labels = {
                    "task_id_entropy": "{{inputs.parameters.task-id-entropy}}",
                    "num_parallel": "{{inputs.parameters.num-parallel}}",
                    "metaflow/argo-workflows-name": "{{workflow.name}}",
                    "workflows.argoproj.io/workflow": "{{workflow.name}}",
                }
                jobset.labels(
                    {
                        **resources["labels"],
                        **self._base_labels,
                        **kubernetes_labels,
                    }
                )

                jobset.environment_variable(
                    "MF_MASTER_ADDR", jobset.jobset_control_addr
                )
                jobset.environment_variable("MF_MASTER_PORT", str(port))
                jobset.environment_variable(
                    "MF_WORLD_SIZE", "{{inputs.parameters.num-parallel}}"
                )
                # We need this task-id set so that all the nodes are aware of the control
                # task's task-id. These "MF_" variables populate the `current.parallel` namedtuple
                jobset.environment_variable(
                    "MF_PARALLEL_CONTROL_TASK_ID",
                    "control-{{inputs.parameters.task-id-entropy}}-0",
                )
                jobset.environment_variables_from_selectors(
                    {
                        "MF_WORKER_REPLICA_INDEX": "metadata.annotations['jobset.sigs.k8s.io/job-index']",
                        "JOBSET_RESTART_ATTEMPT": "metadata.annotations['jobset.sigs.k8s.io/restart-attempt']",
                        "METAFLOW_KUBERNETES_JOBSET_NAME": "metadata.annotations['jobset.sigs.k8s.io/jobset-name']",
                        "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
                        "METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
                        "METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
                        "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
                        "METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
                        "TASK_ID_SUFFIX": "metadata.annotations['jobset.sigs.k8s.io/job-index']",
                    }
                )

                # Set annotations. Do not allow user-specified task-specific annotations to override internal ones.
                annotations = {
                    # setting annotations explicitly as they won't be
                    # passed down from the WorkflowTemplate level
                    "metaflow/step_name": node.name,
                    "metaflow/attempt": str(retry_count),
                    "metaflow/run_id": run_id,
                }

                jobset.annotations(
                    {
                        **resources["annotations"],
                        **self._base_annotations,
                        **annotations,
                    }
                )

                jobset.control.replicas(1)
                jobset.worker.replicas("{{=asInt(inputs.parameters.workerCount)}}")
                jobset.control.environment_variable("UBF_CONTEXT", UBF_CONTROL)
                jobset.worker.environment_variable("UBF_CONTEXT", UBF_TASK)
                jobset.control.environment_variable("MF_CONTROL_INDEX", "0")
                # `TASK_ID_PREFIX` needs to explicitly be `control` or `worker`
                # because the join task uses a formulaic approach to infer the task-ids
                jobset.control.environment_variable("TASK_ID_PREFIX", "control")
                jobset.worker.environment_variable("TASK_ID_PREFIX", "worker")
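                # Net effect (illustrative): together with the TASK_ID_SUFFIX selector
                # above (the jobset job-index annotation), the control pod exports
                # METAFLOW_TASK_ID=control-{{inputs.parameters.task-id-entropy}}-0 and
                # each worker exports
                # METAFLOW_TASK_ID=worker-{{inputs.parameters.task-id-entropy}}-<job-index>,
                # matching the formulaic ids the @parallel join reconstructs.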

                yield (
                    Template(ArgoWorkflows._sanitize(node.name))
                    .resource(
                        "create",
                        jobset.dump(),
                        "status.terminalState == Completed",
                        "status.terminalState == Failed",
                    )
                    .inputs(Inputs().parameters(inputs))
                    .outputs(
                        Outputs().parameters(
                            [
                                Parameter("task-id-entropy").valueFrom(
                                    {"jsonPath": "{.metadata.labels.task_id_entropy}"}
                                ),
                                Parameter("num-parallel").valueFrom(
                                    {"jsonPath": "{.metadata.labels.num_parallel}"}
                                ),
                            ]
                        )
                    )
                    .retry_strategy(
                        times=total_retries,
                        minutes_between_retries=minutes_between_retries,
                    )
                )
            else:
                yield (
                    Template(self._sanitize(node.name))
                    # Set @timeout values
                    .active_deadline_seconds(run_time_limit)
                    # Set service account
                    .service_account_name(resources["service_account"])
                    # Configure template input
                    .inputs(Inputs().parameters(inputs))
                    # Configure template output
                    .outputs(Outputs().parameters(outputs))
                    # Fail fast!
                    .fail_fast()
                    # Set @retry/@catch values
                    .retry_strategy(
                        times=total_retries,
                        minutes_between_retries=minutes_between_retries,
                    )
                    .metadata(
                        ObjectMeta()
                        .annotation("metaflow/step_name", node.name)
                        # Unfortunately, we can't set the task_id since it is generated
                        # inside the pod. However, it can be inferred from the annotation
                        # set by argo-workflows - `workflows.argoproj.io/outputs` - refer to
                        # the field 'task-id' in 'parameters'
                        # .annotation("metaflow/task_id", ...)
                        .annotation("metaflow/attempt", retry_count)
                        .annotations(resources["annotations"])
                        .labels(resources["labels"])
                    )
                    # Set emptyDir volume for state management
                    .empty_dir_volume("out")
                    # Set tmpfs emptyDir volume if enabled
                    .empty_dir_volume(
                        "tmpfs-ephemeral-volume",
                        medium="Memory",
                        size_limit=tmpfs_size if tmpfs_enabled else 0,
                    )
                    .empty_dir_volume("dhsm", medium="Memory", size_limit=shared_memory)
                    .pvc_volumes(resources.get("persistent_volume_claims"))
                    # Set node selectors
                    .node_selectors(resources.get("node_selector"))
                    # Set tolerations
                    .tolerations(resources.get("tolerations"))
                    # Set container
                    .container(
                        # TODO: Unify the logic with kubernetes.py
                        # Important note - Unfortunately, V1Container uses snake_case while
                        # Argo Workflows uses camelCase. For most of the attributes, both
                        # cases are indistinguishable, but unfortunately not for all - (
                        # env_from, value_from, etc.) - so we need to handle the conversion
                        # ourselves using to_camelcase. We need to be vigilant about
                        # resources attributes in particular, where the keys may be
                        # user-defined.
                        to_camelcase(
                            kubernetes_sdk.V1Container(
                                name=self._sanitize(node.name),
                                command=cmds,
                                termination_message_policy="FallbackToLogsOnError",
                                ports=(
                                    [
                                        kubernetes_sdk.V1ContainerPort(
                                            container_port=port
                                        )
                                    ]
                                    if port
                                    else None
                                ),
                                env=[
                                    kubernetes_sdk.V1EnvVar(name=k, value=str(v))
                                    for k, v in env.items()
                                ]
                                # Add environment variables for book-keeping.
                                # https://argoproj.github.io/argo-workflows/fields/#fields_155
                                + [
                                    kubernetes_sdk.V1EnvVar(
                                        name=k,
                                        value_from=kubernetes_sdk.V1EnvVarSource(
                                            field_ref=kubernetes_sdk.V1ObjectFieldSelector(
                                                field_path=str(v)
                                            )
                                        ),
                                    )
                                    for k, v in {
                                        "METAFLOW_KUBERNETES_NAMESPACE": "metadata.namespace",
                                        "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
                                        "METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
                                        "METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
                                        "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
                                        "METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
                                    }.items()
                                ],
                                image=resources["image"],
                                image_pull_policy=resources["image_pull_policy"],
                                resources=kubernetes_sdk.V1ResourceRequirements(
                                    requests=qos_requests,
                                    limits={
                                        **qos_limits,
                                        **{
                                            "%s.com/gpu".lower()
                                            % resources["gpu_vendor"]: str(
                                                resources["gpu"]
                                            )
                                            for k in [0]
                                            if resources["gpu"] is not None
                                        },
                                    },
                                ),
                                # Configure secrets
                                env_from=[
                                    kubernetes_sdk.V1EnvFromSource(
                                        secret_ref=kubernetes_sdk.V1SecretEnvSource(
                                            name=str(k),
                                            # optional=True
                                        )
                                    )
                                    for k in list(
                                        []
                                        if not resources.get("secrets")
                                        else (
                                            [resources.get("secrets")]
                                            if isinstance(resources.get("secrets"), str)
                                            else resources.get("secrets")
                                        )
                                    )
                                    + KUBERNETES_SECRETS.split(",")
                                    + ARGO_WORKFLOWS_KUBERNETES_SECRETS.split(",")
                                    if k
                                ],
                                volume_mounts=[
                                    # Assign a volume mount to pass state to the next task.
                                    kubernetes_sdk.V1VolumeMount(
                                        name="out", mount_path="/mnt/out"
                                    )
                                ]
                                # Support tmpfs.
                                + (
                                    [
                                        kubernetes_sdk.V1VolumeMount(
                                            name="tmpfs-ephemeral-volume",
                                            mount_path=tmpfs_path,
                                        )
                                    ]
                                    if tmpfs_enabled
                                    else []
                                )
                                # Support shared_memory
                                + (
                                    [
                                        kubernetes_sdk.V1VolumeMount(
                                            name="dhsm",
                                            mount_path="/dev/shm",
                                        )
                                    ]
                                    if shared_memory
                                    else []
                                )
                                # Support persistent volume claims.
                                + (
                                    [
                                        kubernetes_sdk.V1VolumeMount(
                                            name=claim, mount_path=path
                                        )
                                        for claim, path in resources.get(
                                            "persistent_volume_claims"
                                        ).items()
                                    ]
                                    if resources.get("persistent_volume_claims")
                                    is not None
                                    else []
                                ),
                                **_security_context,
                            ).to_dict()
                        )
                    )
                )