in metaflow/plugins/argo/argo_workflows.py [0:0]
def _container_templates(self):
try:
# Kubernetes is a soft dependency for generating Argo objects.
# We can very well remove this dependency for Argo with the downside of
# adding a bunch more json bloat classes (looking at you... V1Container)
from kubernetes import client as kubernetes_sdk
except (NameError, ImportError):
raise MetaflowException(
"Could not import Python package 'kubernetes'. Install kubernetes "
"sdk (https://pypi.org/project/kubernetes/) first."
)
for node in self.graph:
# Resolve entry point for pod container.
script_name = os.path.basename(sys.argv[0])
executable = self.environment.executable(node.name)
# TODO: Support R someday. Quite a few people will be happy.
entrypoint = [executable, script_name]
# The values with curly braces '{{}}' are made available by Argo
# Workflows. Unfortunately, a few bugs in Argo prevent us from accessing
# these values as liberally as we would like to - e.g., within inline
# templates - so we are forced to generate a container template per step.
run_id = "argo-{{workflow.name}}"
# Unfortunately, we don't have any easy access to unique ids that remain
# stable across task attempts through Argo Workflows. So, we are forced to
# stitch them together ourselves. The task ids are a function of step name,
# split index and the parent task id (available from input path name).
# Ideally, we would like these task ids to be the same as node name
# (modulo retry suffix) on Argo Workflows but that doesn't seem feasible
# right now.
task_idx = ""
input_paths = ""
root_input = None
# Export input_paths since it is used multiple times in the container
# script and we do not want to repeat the value.
input_paths_expr = "export INPUT_PATHS=''"
# Set the input paths unless the node is a start step or a @parallel join.
# Start steps have no input paths, and @parallel joins derive their input
# paths formulaically from `num-parallel` and `task-id-entropy`.
if not (
node.name == "start"
or (node.type == "join" and self.graph[node.in_funcs[0]].parallel_step)
):
# (@parallel joins skip this branch entirely since their input paths are
# constructed dynamically at runtime.)
input_paths_expr = (
"export INPUT_PATHS={{inputs.parameters.input-paths}}"
)
input_paths = "$(echo $INPUT_PATHS)"
if any(self.graph[n].type == "foreach" for n in node.in_funcs):
task_idx = "{{inputs.parameters.split-index}}"
if node.is_inside_foreach and self.graph[node.out_funcs[0]].type == "join":
if any(
self.graph[parent].matching_join
== self.graph[node.out_funcs[0]].name
for parent in self.graph[node.out_funcs[0]].split_parents
if self.graph[parent].type == "foreach"
) and any(not self.graph[f].type == "foreach" for f in node.in_funcs):
# we need to propagate the split-index and root-input-path info for
# the last step inside a foreach for correctly joining nested
# foreaches
task_idx = "{{inputs.parameters.split-index}}"
root_input = "{{inputs.parameters.root-input-path}}"
# Task string to be hashed into an ID
task_str = "-".join(
[
node.name,
"{{workflow.creationTimestamp}}",
root_input or input_paths,
task_idx,
]
)
if node.parallel_step:
task_str = "-".join(
[
"$TASK_ID_PREFIX",
"{{inputs.parameters.task-id-entropy}}",
"$TASK_ID_SUFFIX",
]
)
else:
# Generated task_ids need to be non-numeric - see register_task_id in
# service.py. We do so by prefixing `t-`
_task_id_base = (
"$(echo %s | md5sum | cut -d ' ' -f 1 | tail -c 9)" % task_str
)
task_str = "(t-%s)" % _task_id_base
task_id_expr = "export METAFLOW_TASK_ID=" "%s" % task_str
task_id = "$METAFLOW_TASK_ID"
# Resolve retry strategy.
max_user_code_retries = 0
max_error_retries = 0
minutes_between_retries = "2"
for decorator in node.decorators:
if decorator.name == "retry":
minutes_between_retries = decorator.attributes.get(
"minutes_between_retries", minutes_between_retries
)
user_code_retries, error_retries = decorator.step_task_retry_count()
max_user_code_retries = max(max_user_code_retries, user_code_retries)
max_error_retries = max(max_error_retries, error_retries)
user_code_retries = max_user_code_retries
total_retries = max_user_code_retries + max_error_retries
# {{retries}} is only available if retryStrategy is specified
# For custom kubernetes manifests, we will pass the retryCount as a parameter
# and use that in the manifest.
retry_count = (
(
"{{retries}}"
if not node.parallel_step
else "{{inputs.parameters.retryCount}}"
)
if total_retries
else 0
)
minutes_between_retries = int(minutes_between_retries)
# Configure log capture.
mflog_expr = export_mflog_env_vars(
datastore_type=self.flow_datastore.TYPE,
stdout_path="$PWD/.logs/mflog_stdout",
stderr_path="$PWD/.logs/mflog_stderr",
flow_name=self.flow.name,
run_id=run_id,
step_name=node.name,
task_id=task_id,
retry_count=retry_count,
)
init_cmds = " && ".join(
[
# For supporting sandboxes, ensure that a custom script is executed
# before anything else is executed. The script is passed in as an
# env var.
'${METAFLOW_INIT_SCRIPT:+eval \\"${METAFLOW_INIT_SCRIPT}\\"}',
"mkdir -p $PWD/.logs",
input_paths_expr,
task_id_expr,
mflog_expr,
]
+ self.environment.get_package_commands(
self.code_package_url, self.flow_datastore.TYPE
)
)
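# For reference, the joined init_cmds expand to a single '&&'-chained shell
# fragment of roughly the form:
#   ${METAFLOW_INIT_SCRIPT:+eval \"${METAFLOW_INIT_SCRIPT}\"} && mkdir -p $PWD/.logs
#     && export INPUT_PATHS=... && export METAFLOW_TASK_ID=...
#     && <mflog env exports> && <code package download/bootstrap commands>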
step_cmds = self.environment.bootstrap_commands(
node.name, self.flow_datastore.TYPE
)
top_opts_dict = {
"with": [
decorator.make_decorator_spec()
for decorator in node.decorators
if not decorator.statically_defined
]
}
# FlowDecorators can define their own top-level options. They are
# responsible for adding their own top-level options and values through
# the get_top_level_options() hook. See similar logic in runtime.py.
for deco in flow_decorators(self.flow):
top_opts_dict.update(deco.get_top_level_options())
top_level = list(dict_to_cli_options(top_opts_dict)) + [
"--quiet",
"--metadata=%s" % self.metadata.TYPE,
"--environment=%s" % self.environment.TYPE,
"--datastore=%s" % self.flow_datastore.TYPE,
"--datastore-root=%s" % self.flow_datastore.datastore_root,
"--event-logger=%s" % self.event_logger.TYPE,
"--monitor=%s" % self.monitor.TYPE,
"--no-pylint",
"--with=argo_workflows_internal:auto-emit-argo-events=%i"
% self.auto_emit_argo_events,
]
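# For illustration, top_level renders to CLI options of roughly the form:
#   --with=... --quiet --metadata=service --environment=conda --datastore=s3
#   --datastore-root=s3://... --event-logger=... --monitor=... --no-pylint
#   --with=argo_workflows_internal:auto-emit-argo-events=1
# (the concrete values depend on the deployment configuration).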
if node.name == "start":
# Execute `init` before any step of the workflow executes
task_id_params = "%s-params" % task_id
init = (
entrypoint
+ top_level
+ [
"init",
"--run-id %s" % run_id,
"--task-id %s" % task_id_params,
]
+ [
# Parameter names can be hyphenated, hence we use
# {{foo.bar['param_name']}}.
# https://argoproj.github.io/argo-events/tutorials/02-parameterization/
# http://masterminds.github.io/sprig/strings.html
"--%s={{workflow.parameters.%s}}"
% (parameter["name"], parameter["name"])
for parameter in self.parameters.values()
]
)
if self.tags:
init.extend("--tag %s" % tag for tag in self.tags)
# if the start step gets retried, we must be careful
# not to regenerate multiple parameters tasks. Hence,
# we check first if _parameters exists already.
exists = entrypoint + [
"dump",
"--max-value-size=0",
"%s/_parameters/%s" % (run_id, task_id_params),
]
step_cmds.extend(
[
"if ! %s >/dev/null 2>/dev/null; then %s; fi"
% (" ".join(exists), " ".join(init))
]
)
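# The resulting guard is, roughly:
#   if ! <entrypoint> dump --max-value-size=0 <run-id>/_parameters/<task-id>-params >/dev/null 2>/dev/null; then
#       <entrypoint> <top_level> init --run-id ... --task-id ...-params --<param>={{workflow.parameters.<param>}} ...
#   fi
# so a retried start step reuses the existing parameters task instead of creating a new one.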
input_paths = "%s/_parameters/%s" % (run_id, task_id_params)
elif (
node.type == "join"
and self.graph[node.split_parents[-1]].type == "foreach"
):
# Set aggregated input-paths for a for-each join
foreach_step = next(
n for n in node.in_funcs if self.graph[n].is_inside_foreach
)
if not self.graph[node.split_parents[-1]].parallel_foreach:
input_paths = (
"$(python -m metaflow.plugins.argo.generate_input_paths %s {{workflow.creationTimestamp}} %s {{inputs.parameters.split-cardinality}})"
% (
foreach_step,
input_paths,
)
)
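# generate_input_paths re-derives every parent task id inside the join by
# applying the same hashing scheme used above (step name, workflow creation
# timestamp, input path and split index), using split-cardinality to know how
# many siblings to enumerate - so the full list never flows through Argo.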
else:
# Handle @parallel joins, where the output from the volume mount isn't accessible.
input_paths = (
"$(python -m metaflow.plugins.argo.jobset_input_paths %s %s {{inputs.parameters.task-id-entropy}} {{inputs.parameters.num-parallel}})"
% (
run_id,
foreach_step,
)
)
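# jobset_input_paths instead reconstructs the @parallel task ids formulaically
# from the run id, step name, task-id-entropy and num-parallel (the control
# task plus one task per worker replica), mirroring the control-/worker-
# prefixed ids generated for @parallel steps above.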
step = [
"step",
node.name,
"--run-id %s" % run_id,
"--task-id %s" % task_id,
"--retry-count %s" % retry_count,
"--max-user-code-retries %d" % user_code_retries,
"--input-paths %s" % input_paths,
]
if node.parallel_step:
step.append(
"--split-index ${MF_CONTROL_INDEX:-$((MF_WORKER_REPLICA_INDEX + 1))}"
)
# This is needed for setting the value of the UBF context in the CLI.
step.append("--ubf-context $UBF_CONTEXT")
elif any(self.graph[n].type == "foreach" for n in node.in_funcs):
# Pass split-index to a foreach task
step.append("--split-index {{inputs.parameters.split-index}}")
if self.tags:
step.extend("--tag %s" % tag for tag in self.tags)
if self.namespace is not None:
step.append("--namespace=%s" % self.namespace)
step_cmds.extend([" ".join(entrypoint + top_level + step)])
cmd_str = "%s; c=$?; %s; exit $c" % (
" && ".join([init_cmds, bash_capture_logs(" && ".join(step_cmds))]),
BASH_SAVE_LOGS,
)
cmds = shlex.split('bash -c "%s"' % cmd_str)
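# The final command handed to the pod is therefore, roughly:
#   bash -c "<init_cmds> && <log-captured step_cmds>; c=$?; <save logs to datastore>; exit $c"
# capturing the step's exit code so logs are persisted even when the step fails.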
# Resolve resource requirements.
resources = dict(
[deco for deco in node.decorators if deco.name == "kubernetes"][
0
].attributes
)
if (
resources["namespace"]
and resources["namespace"] != KUBERNETES_NAMESPACE
):
raise ArgoWorkflowsException(
"Multi-namespace Kubernetes execution of flows in Argo Workflows "
"is not currently supported. \nStep *%s* is trying to override "
"the default Kubernetes namespace *%s*."
% (node.name, KUBERNETES_NAMESPACE)
)
run_time_limit = [
deco for deco in node.decorators if deco.name == "kubernetes"
][0].run_time_limit
# Resolve @environment decorator. We set three classes of environment
# variables -
# (1) User-specified environment variables through @environment
# (2) Metaflow runtime specific environment variables
# (3) @kubernetes, @argo_workflows_internal bookkeeping environment
# variables
env = dict(
[deco for deco in node.decorators if deco.name == "environment"][
0
].attributes["vars"]
)
# Temporary passing of *some* environment variables. Do not rely on this
# mechanism as it will be removed in the near future.
env.update(
{
k: v
for k, v in config_values()
if k.startswith("METAFLOW_CONDA_")
or k.startswith("METAFLOW_DEBUG_")
}
)
env.update(
{
**{
# These values are needed by Metaflow to set its internal
# state appropriately.
"METAFLOW_CODE_URL": self.code_package_url,
"METAFLOW_CODE_SHA": self.code_package_sha,
"METAFLOW_CODE_DS": self.flow_datastore.TYPE,
"METAFLOW_SERVICE_URL": SERVICE_INTERNAL_URL,
"METAFLOW_SERVICE_HEADERS": json.dumps(SERVICE_HEADERS),
"METAFLOW_USER": "argo-workflows",
"METAFLOW_DATASTORE_SYSROOT_S3": DATASTORE_SYSROOT_S3,
"METAFLOW_DATATOOLS_S3ROOT": DATATOOLS_S3ROOT,
"METAFLOW_DEFAULT_DATASTORE": self.flow_datastore.TYPE,
"METAFLOW_DEFAULT_METADATA": DEFAULT_METADATA,
"METAFLOW_CARD_S3ROOT": CARD_S3ROOT,
"METAFLOW_KUBERNETES_WORKLOAD": 1,
"METAFLOW_KUBERNETES_FETCH_EC2_METADATA": KUBERNETES_FETCH_EC2_METADATA,
"METAFLOW_RUNTIME_ENVIRONMENT": "kubernetes",
"METAFLOW_OWNER": self.username,
},
**{
# Configuration for Argo Events. Keep these in sync with the
# environment variables for @kubernetes decorator.
"METAFLOW_ARGO_EVENTS_EVENT": ARGO_EVENTS_EVENT,
"METAFLOW_ARGO_EVENTS_EVENT_BUS": ARGO_EVENTS_EVENT_BUS,
"METAFLOW_ARGO_EVENTS_EVENT_SOURCE": ARGO_EVENTS_EVENT_SOURCE,
"METAFLOW_ARGO_EVENTS_SERVICE_ACCOUNT": ARGO_EVENTS_SERVICE_ACCOUNT,
"METAFLOW_ARGO_EVENTS_WEBHOOK_URL": ARGO_EVENTS_INTERNAL_WEBHOOK_URL,
"METAFLOW_ARGO_EVENTS_WEBHOOK_AUTH": ARGO_EVENTS_WEBHOOK_AUTH,
},
**{
# Some optional values for bookkeeping
"METAFLOW_FLOW_FILENAME": os.path.basename(sys.argv[0]),
"METAFLOW_FLOW_NAME": self.flow.name,
"METAFLOW_STEP_NAME": node.name,
"METAFLOW_RUN_ID": run_id,
# "METAFLOW_TASK_ID": task_id,
"METAFLOW_RETRY_COUNT": retry_count,
"METAFLOW_PRODUCTION_TOKEN": self.production_token,
"ARGO_WORKFLOW_TEMPLATE": self.name,
"ARGO_WORKFLOW_NAME": "{{workflow.name}}",
"ARGO_WORKFLOW_NAMESPACE": KUBERNETES_NAMESPACE,
},
**self.metadata.get_runtime_environment("argo-workflows"),
}
)
# add METAFLOW_S3_ENDPOINT_URL
env["METAFLOW_S3_ENDPOINT_URL"] = S3_ENDPOINT_URL
# support Metaflow sandboxes
env["METAFLOW_INIT_SCRIPT"] = KUBERNETES_SANDBOX_INIT_SCRIPT
env["METAFLOW_KUBERNETES_SANDBOX_INIT_SCRIPT"] = (
KUBERNETES_SANDBOX_INIT_SCRIPT
)
# support for @secret
env["METAFLOW_DEFAULT_SECRETS_BACKEND_TYPE"] = DEFAULT_SECRETS_BACKEND_TYPE
env["METAFLOW_AWS_SECRETS_MANAGER_DEFAULT_REGION"] = (
AWS_SECRETS_MANAGER_DEFAULT_REGION
)
env["METAFLOW_GCP_SECRET_MANAGER_PREFIX"] = GCP_SECRET_MANAGER_PREFIX
env["METAFLOW_AZURE_KEY_VAULT_PREFIX"] = AZURE_KEY_VAULT_PREFIX
# support for Azure
env["METAFLOW_AZURE_STORAGE_BLOB_SERVICE_ENDPOINT"] = (
AZURE_STORAGE_BLOB_SERVICE_ENDPOINT
)
env["METAFLOW_DATASTORE_SYSROOT_AZURE"] = DATASTORE_SYSROOT_AZURE
env["METAFLOW_CARD_AZUREROOT"] = CARD_AZUREROOT
env["METAFLOW_ARGO_WORKFLOWS_KUBERNETES_SECRETS"] = (
ARGO_WORKFLOWS_KUBERNETES_SECRETS
)
env["METAFLOW_ARGO_WORKFLOWS_ENV_VARS_TO_SKIP"] = (
ARGO_WORKFLOWS_ENV_VARS_TO_SKIP
)
# support for GCP
env["METAFLOW_DATASTORE_SYSROOT_GS"] = DATASTORE_SYSROOT_GS
env["METAFLOW_CARD_GSROOT"] = CARD_GSROOT
# Map Argo Events payload (if any) to environment variables
if self.triggers:
for event in self.triggers:
env[
"METAFLOW_ARGO_EVENT_PAYLOAD_%s_%s"
% (event["type"], event["sanitized_name"])
] = ("{{workflow.parameters.%s}}" % event["sanitized_name"])
# Map S3 upload headers to environment variables
if S3_SERVER_SIDE_ENCRYPTION is not None:
env["METAFLOW_S3_SERVER_SIDE_ENCRYPTION"] = S3_SERVER_SIDE_ENCRYPTION
metaflow_version = self.environment.get_environment_info()
metaflow_version["flow_name"] = self.graph.name
metaflow_version["production_token"] = self.production_token
env["METAFLOW_VERSION"] = json.dumps(metaflow_version)
# map config values
cfg_env = {
param["name"]: param["kv_name"] for param in self.config_parameters
}
if cfg_env:
env["METAFLOW_FLOW_CONFIG_VALUE"] = json.dumps(cfg_env)
# Set the template inputs and outputs for passing state. Very simply,
# the container template takes in input-paths as input and outputs
# the task-id (which feeds in as input-paths to the subsequent task).
# In addition to that, if the parent of the node under consideration
# is a for-each node, then we take the split-index as an additional
# input. Analogously, if the node under consideration is a foreach
# node, then we emit split cardinality as an extra output. I would like
# to thank the designers of Argo Workflows for making this so
# straightforward! Things become a bit more complicated to support very
# wide foreaches where we have to resort to passing a root-input-path
# so that we can compute the task ids for each parent task of a for-each
# join task deterministically inside the join task without resorting to
# passing a rather long list of (albeit compressed) input paths.
inputs = []
# To set the input-paths as a parameter, we need to ensure that the node
# is not (a start node or a parallel join node). Start nodes will have no
# input paths and parallel join will derive input paths based on a
# formulaic approach.
if not (
node.name == "start"
or (node.type == "join" and self.graph[node.in_funcs[0]].parallel_step)
):
inputs.append(Parameter("input-paths"))
if any(self.graph[n].type == "foreach" for n in node.in_funcs):
# Fetch split-index from parent
inputs.append(Parameter("split-index"))
if (
node.type == "join"
and self.graph[node.split_parents[-1]].type == "foreach"
):
# @parallel join tasks require `num-parallel` and `task-id-entropy`
# to construct the input paths, so we pass them down as input parameters.
if self.graph[node.split_parents[-1]].parallel_foreach:
inputs.extend(
[Parameter("num-parallel"), Parameter("task-id-entropy")]
)
else:
# append this only for joins of foreaches, not static splits
inputs.append(Parameter("split-cardinality"))
# check if the node is a @parallel node.
elif node.parallel_step:
inputs.extend(
[
Parameter("num-parallel"),
Parameter("task-id-entropy"),
Parameter("jobset-name"),
Parameter("workerCount"),
]
)
if any(d.name == "retry" for d in node.decorators):
inputs.append(Parameter("retryCount"))
if node.is_inside_foreach and self.graph[node.out_funcs[0]].type == "join":
if any(
self.graph[parent].matching_join
== self.graph[node.out_funcs[0]].name
for parent in self.graph[node.out_funcs[0]].split_parents
if self.graph[parent].type == "foreach"
) and any(not self.graph[f].type == "foreach" for f in node.in_funcs):
# we need to propagate the split-index and root-input-path info for
# the last step inside a foreach for correctly joining nested
# foreaches
if not any(self.graph[n].type == "foreach" for n in node.in_funcs):
# Don't add duplicate split index parameters.
inputs.append(Parameter("split-index"))
inputs.append(Parameter("root-input-path"))
outputs = []
# @parallel steps will not have a task-id as an output parameter since task-ids
# are derived at runtime.
if not (node.name == "end" or node.parallel_step):
outputs = [Parameter("task-id").valueFrom({"path": "/mnt/out/task_id"})]
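# The task id is written from within the pod to /mnt/out/task_id on the shared
# emptyDir volume (mounted below as 'out'), which is how Argo surfaces it as an
# output parameter feeding downstream input-paths.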
if node.type == "foreach":
# Emit split cardinality from foreach task
outputs.append(
Parameter("num-splits").valueFrom({"path": "/mnt/out/splits"})
)
outputs.append(
Parameter("split-cardinality").valueFrom(
{"path": "/mnt/out/split_cardinality"}
)
)
if node.parallel_foreach:
outputs.extend(
[
Parameter("num-parallel").valueFrom(
{"path": "/mnt/out/num_parallel"}
),
Parameter("task-id-entropy").valueFrom(
{"path": "/mnt/out/task_id_entropy"}
),
]
)
# Outputs should be defined over here and not in the _dag_template for @parallel.
# It makes no sense to set env vars to None (they show up as the string "None").
# We also skip some env vars (e.g. when we want to pull them from KUBERNETES_SECRETS instead).
env = {
k: v
for k, v in env.items()
if v is not None
and k not in set(ARGO_WORKFLOWS_ENV_VARS_TO_SKIP.split(","))
}
# Tmpfs variables
use_tmpfs = resources["use_tmpfs"]
tmpfs_size = resources["tmpfs_size"]
tmpfs_path = resources["tmpfs_path"]
tmpfs_tempdir = resources["tmpfs_tempdir"]
# Set shared_memory to 0 if it isn't specified. This results
# in Kubernetes using its default value when the pod is created.
shared_memory = resources.get("shared_memory", 0)
port = resources.get("port", None)
if port:
port = int(port)
tmpfs_enabled = use_tmpfs or (tmpfs_size and not use_tmpfs)
if tmpfs_enabled and tmpfs_tempdir:
env["METAFLOW_TEMPDIR"] = tmpfs_path
qos_requests, qos_limits = qos_requests_and_limits(
resources["qos"],
resources["cpu"],
resources["memory"],
resources["disk"],
)
security_context = resources.get("security_context", None)
_security_context = {}
if security_context is not None and len(security_context) > 0:
_security_context = {
"security_context": kubernetes_sdk.V1SecurityContext(
**security_context
)
}
# Create a ContainerTemplate for this node. Ideally, we would have
# liked to inline this ContainerTemplate and avoid scanning the workflow
# twice, but due to issues with variable substitution, we will have to
# live with this routine.
if node.parallel_step:
jobset_name = "{{inputs.parameters.jobset-name}}"
jobset = KubernetesArgoJobSet(
kubernetes_sdk=kubernetes_sdk,
name=jobset_name,
flow_name=self.flow.name,
run_id=run_id,
step_name=self._sanitize(node.name),
task_id=task_id,
attempt=retry_count,
user=self.username,
subdomain=jobset_name,
command=cmds,
namespace=resources["namespace"],
image=resources["image"],
image_pull_policy=resources["image_pull_policy"],
service_account=resources["service_account"],
secrets=(
[
k
for k in (
list(
[]
if not resources.get("secrets")
else (
[resources.get("secrets")]
if isinstance(resources.get("secrets"), str)
else resources.get("secrets")
)
)
+ KUBERNETES_SECRETS.split(",")
+ ARGO_WORKFLOWS_KUBERNETES_SECRETS.split(",")
)
if k
]
),
node_selector=resources.get("node_selector"),
cpu=str(resources["cpu"]),
memory=str(resources["memory"]),
disk=str(resources["disk"]),
gpu=resources["gpu"],
gpu_vendor=str(resources["gpu_vendor"]),
tolerations=resources["tolerations"],
use_tmpfs=use_tmpfs,
tmpfs_tempdir=tmpfs_tempdir,
tmpfs_size=tmpfs_size,
tmpfs_path=tmpfs_path,
timeout_in_seconds=run_time_limit,
persistent_volume_claims=resources["persistent_volume_claims"],
shared_memory=shared_memory,
port=port,
qos=resources["qos"],
security_context=security_context,
)
for k, v in env.items():
jobset.environment_variable(k, v)
# Set labels. Do not allow user-specified task labels to override internal ones.
#
# Explicitly add the task-id-hint label. This is important because this label
# is returned as an output parameter of this step and is subsequently used as
# an input in the join step.
kubernetes_labels = {
"task_id_entropy": "{{inputs.parameters.task-id-entropy}}",
"num_parallel": "{{inputs.parameters.num-parallel}}",
"metaflow/argo-workflows-name": "{{workflow.name}}",
"workflows.argoproj.io/workflow": "{{workflow.name}}",
}
jobset.labels(
{
**resources["labels"],
**self._base_labels,
**kubernetes_labels,
}
)
jobset.environment_variable(
"MF_MASTER_ADDR", jobset.jobset_control_addr
)
jobset.environment_variable("MF_MASTER_PORT", str(port))
jobset.environment_variable(
"MF_WORLD_SIZE", "{{inputs.parameters.num-parallel}}"
)
# We need this task-id set so that all the nodes are aware of the control
# task's task-id. These "MF_" variables populate the `current.parallel` namedtuple
jobset.environment_variable(
"MF_PARALLEL_CONTROL_TASK_ID",
"control-{{inputs.parameters.task-id-entropy}}-0",
)
jobset.environment_variables_from_selectors(
{
"MF_WORKER_REPLICA_INDEX": "metadata.annotations['jobset.sigs.k8s.io/job-index']",
"JOBSET_RESTART_ATTEMPT": "metadata.annotations['jobset.sigs.k8s.io/restart-attempt']",
"METAFLOW_KUBERNETES_JOBSET_NAME": "metadata.annotations['jobset.sigs.k8s.io/jobset-name']",
"METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
"METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
"METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
"METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
"METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
"TASK_ID_SUFFIX": "metadata.annotations['jobset.sigs.k8s.io/job-index']",
}
)
# Set annotations. Do not allow user-specified task-specific annotations to override internal ones.
annotations = {
# Setting annotations explicitly as they won't be
# passed down from the WorkflowTemplate level.
"metaflow/step_name": node.name,
"metaflow/attempt": str(retry_count),
"metaflow/run_id": run_id,
}
jobset.annotations(
{
**resources["annotations"],
**self._base_annotations,
**annotations,
}
)
jobset.control.replicas(1)
jobset.worker.replicas("{{=asInt(inputs.parameters.workerCount)}}")
jobset.control.environment_variable("UBF_CONTEXT", UBF_CONTROL)
jobset.worker.environment_variable("UBF_CONTEXT", UBF_TASK)
jobset.control.environment_variable("MF_CONTROL_INDEX", "0")
# `TASK_ID_PREFIX` needs to explicitly be `control` or `worker`
# because the join task uses a formulaic approach to infer the task-ids
jobset.control.environment_variable("TASK_ID_PREFIX", "control")
jobset.worker.environment_variable("TASK_ID_PREFIX", "worker")
yield (
Template(ArgoWorkflows._sanitize(node.name))
.resource(
"create",
jobset.dump(),
"status.terminalState == Completed",
"status.terminalState == Failed",
)
.inputs(Inputs().parameters(inputs))
.outputs(
Outputs().parameters(
[
Parameter("task-id-entropy").valueFrom(
{"jsonPath": "{.metadata.labels.task_id_entropy}"}
),
Parameter("num-parallel").valueFrom(
{"jsonPath": "{.metadata.labels.num_parallel}"}
),
]
)
)
.retry_strategy(
times=total_retries,
minutes_between_retries=minutes_between_retries,
)
)
else:
yield (
Template(self._sanitize(node.name))
# Set @timeout values
.active_deadline_seconds(run_time_limit)
# Set service account
.service_account_name(resources["service_account"])
# Configure template input
.inputs(Inputs().parameters(inputs))
# Configure template output
.outputs(Outputs().parameters(outputs))
# Fail fast!
.fail_fast()
# Set @retry/@catch values
.retry_strategy(
times=total_retries,
minutes_between_retries=minutes_between_retries,
)
.metadata(
ObjectMeta()
.annotation("metaflow/step_name", node.name)
# Unfortunately, we can't set the task_id since it is generated
# inside the pod. However, it can be inferred from the annotation
# set by argo-workflows - `workflows.argoproj.io/outputs` - refer
# the field 'task-id' in 'parameters'
# .annotation("metaflow/task_id", ...)
.annotation("metaflow/attempt", retry_count)
.annotations(resources["annotations"])
.labels(resources["labels"])
)
# Set emptyDir volume for state management
.empty_dir_volume("out")
# Set tmpfs emptyDir volume if enabled
.empty_dir_volume(
"tmpfs-ephemeral-volume",
medium="Memory",
size_limit=tmpfs_size if tmpfs_enabled else 0,
)
.empty_dir_volume("dhsm", medium="Memory", size_limit=shared_memory)
.pvc_volumes(resources.get("persistent_volume_claims"))
# Set node selectors
.node_selectors(resources.get("node_selector"))
# Set tolerations
.tolerations(resources.get("tolerations"))
# Set container
.container(
# TODO: Unify the logic with kubernetes.py
# Important note - Unfortunately, V1Container uses snake_case while
# Argo Workflows uses camelCase. For most of the attributes, both cases
# are indistinguishable, but unfortunately, not for all - (
# env_from, value_from, etc.) - so we need to handle the conversion
# ourselves using to_camelcase. We need to be vigilant about
# resources attributes in particular, where the keys may be user
# defined.
to_camelcase(
kubernetes_sdk.V1Container(
name=self._sanitize(node.name),
command=cmds,
termination_message_policy="FallbackToLogsOnError",
ports=(
[
kubernetes_sdk.V1ContainerPort(
container_port=port
)
]
if port
else None
),
env=[
kubernetes_sdk.V1EnvVar(name=k, value=str(v))
for k, v in env.items()
]
# Add environment variables for book-keeping.
# https://argoproj.github.io/argo-workflows/fields/#fields_155
+ [
kubernetes_sdk.V1EnvVar(
name=k,
value_from=kubernetes_sdk.V1EnvVarSource(
field_ref=kubernetes_sdk.V1ObjectFieldSelector(
field_path=str(v)
)
),
)
for k, v in {
"METAFLOW_KUBERNETES_NAMESPACE": "metadata.namespace",
"METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
"METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
"METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
"METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
"METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
}.items()
],
image=resources["image"],
image_pull_policy=resources["image_pull_policy"],
resources=kubernetes_sdk.V1ResourceRequirements(
requests=qos_requests,
limits={
**qos_limits,
**{
"%s.com/gpu".lower()
% resources["gpu_vendor"]: str(
resources["gpu"]
)
for k in [0]
if resources["gpu"] is not None
},
},
),
# Configure secrets
env_from=[
kubernetes_sdk.V1EnvFromSource(
secret_ref=kubernetes_sdk.V1SecretEnvSource(
name=str(k),
# optional=True
)
)
for k in list(
[]
if not resources.get("secrets")
else (
[resources.get("secrets")]
if isinstance(resources.get("secrets"), str)
else resources.get("secrets")
)
)
+ KUBERNETES_SECRETS.split(",")
+ ARGO_WORKFLOWS_KUBERNETES_SECRETS.split(",")
if k
],
volume_mounts=[
# Assign a volume mount to pass state to the next task.
kubernetes_sdk.V1VolumeMount(
name="out", mount_path="/mnt/out"
)
]
# Support tmpfs.
+ (
[
kubernetes_sdk.V1VolumeMount(
name="tmpfs-ephemeral-volume",
mount_path=tmpfs_path,
)
]
if tmpfs_enabled
else []
)
# Support shared_memory
+ (
[
kubernetes_sdk.V1VolumeMount(
name="dhsm",
mount_path="/dev/shm",
)
]
if shared_memory
else []
)
# Support persistent volume claims.
+ (
[
kubernetes_sdk.V1VolumeMount(
name=claim, mount_path=path
)
for claim, path in resources.get(
"persistent_volume_claims"
).items()
]
if resources.get("persistent_volume_claims")
is not None
else []
),
**_security_context,
).to_dict()
)
)
)