in metaflow/plugins/kubernetes/kubernetes_job.py
def create_job_spec(self):
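    """Build and return the Kubernetes ``V1JobSpec`` for this step: a
    single-container pod template plus the volumes, environment, secrets,
    and resource settings derived from the step's keyword arguments."""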
client = self._client.get()
# tmpfs variables
use_tmpfs = self._kwargs["use_tmpfs"]
tmpfs_size = self._kwargs["tmpfs_size"]
    # tmpfs is active when explicitly enabled or when a size was given.
    tmpfs_enabled = bool(use_tmpfs or tmpfs_size)
shared_memory = (
int(self._kwargs["shared_memory"])
if self._kwargs["shared_memory"]
else None
)
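    # Translate the requested cpu/memory/disk and the configured QoS class
    # into Kubernetes resource requests and limits.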
qos_requests, qos_limits = qos_requests_and_limits(
self._kwargs["qos"],
self._kwargs["cpu"],
self._kwargs["memory"],
self._kwargs["disk"],
)
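    # Wrap a user-supplied security context so it can be splatted onto the
    # container below; an empty dict leaves the container defaults intact.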
security_context = self._kwargs.get("security_context", {})
_security_context = {}
    if security_context:
_security_context = {
"security_context": client.V1SecurityContext(**security_context)
}
return client.V1JobSpec(
        # Retries are handled by Metaflow when it executes the flow itself;
        # the responsibility moves to Kubernetes when Argo Workflows drives
        # the execution.
        backoff_limit=self._kwargs.get("retries", 0),
        completions=self._kwargs.get("completions", 1),
        # Remove the job a week after it finishes. TODO: Make this configurable.
        ttl_seconds_after_finished=7 * 24 * 60 * 60,
template=client.V1PodTemplateSpec(
metadata=client.V1ObjectMeta(
annotations=self._kwargs.get("annotations", {}),
labels=self._kwargs.get("labels", {}),
namespace=self._kwargs["namespace"],
),
spec=client.V1PodSpec(
# Timeout is set on the pod and not the job (important!)
active_deadline_seconds=self._kwargs["timeout_in_seconds"],
# TODO (savin): Enable affinities for GPU scheduling.
# affinity=?,
containers=[
client.V1Container(
command=self._kwargs["command"],
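                        # Use the tail of the container log as the
                        # termination message when the container doesn't
                        # write one itself.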
termination_message_policy="FallbackToLogsOnError",
ports=(
[]
if self._kwargs["port"] is None
else [
client.V1ContainerPort(
container_port=int(self._kwargs["port"])
)
]
),
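                        # User-supplied environment variables come first.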
env=[
client.V1EnvVar(name=k, value=str(v))
for k, v in self._kwargs.get(
"environment_variables", {}
).items()
]
                        # Downward API magic: the (key, value) pairs below
                        # expose pod metadata to the container as
                        # environment variables.
+ [
client.V1EnvVar(
name=k,
value_from=client.V1EnvVarSource(
field_ref=client.V1ObjectFieldSelector(
field_path=str(v)
)
),
)
for k, v in {
"METAFLOW_KUBERNETES_NAMESPACE": "metadata.namespace",
"METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
"METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
"METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
"METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
"METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
}.items()
]
+ [
client.V1EnvVar(name=k, value=str(v))
for k, v in inject_tracing_vars({}).items()
],
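                        # Expose step-level and globally configured
                        # Kubernetes secrets to the container as
                        # environment variables.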
env_from=[
client.V1EnvFromSource(
secret_ref=client.V1SecretEnvSource(
name=str(k),
# optional=True
)
)
for k in list(self._kwargs.get("secrets", []))
+ KUBERNETES_SECRETS.split(",")
if k
],
image=self._kwargs["image"],
image_pull_policy=self._kwargs["image_pull_policy"],
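                        # Kubernetes object names must be DNS-1123
                        # compliant, so swap underscores for hyphens.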
name=self._kwargs["step_name"].replace("_", "-"),
                        resources=client.V1ResourceRequirements(
                            requests=qos_requests,
                            limits={
                                **qos_limits,
                                # Don't set GPU limits if gpu isn't specified.
                                **(
                                    {
                                        "%s.com/gpu"
                                        % self._kwargs["gpu_vendor"].lower(): str(
                                            self._kwargs["gpu"]
                                        )
                                    }
                                    if self._kwargs["gpu"] is not None
                                    else {}
                                ),
                            },
                        ),
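                        # Mount the tmpfs volume, the shared-memory volume,
                        # and any persistent volume claims; each mount is
                        # backed by a volume declared further below.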
volume_mounts=(
[
client.V1VolumeMount(
mount_path=self._kwargs.get("tmpfs_path"),
name="tmpfs-ephemeral-volume",
)
]
if tmpfs_enabled
else []
)
+ (
[
client.V1VolumeMount(
mount_path="/dev/shm", name="dhsm"
)
]
if shared_memory
else []
)
+ (
[
client.V1VolumeMount(mount_path=path, name=claim)
for claim, path in self._kwargs[
"persistent_volume_claims"
].items()
]
if self._kwargs["persistent_volume_claims"] is not None
else []
),
**_security_context,
)
],
node_selector=self._kwargs.get("node_selector"),
# TODO (savin): Support image_pull_secrets
# image_pull_secrets=?,
# TODO (savin): Support preemption policies
# preemption_policy=?,
                #
                # A container in a pod may fail for a number of reasons,
                # e.g. because its process exited with a non-zero exit code
                # or because it was OOM-killed. If this happens, fail the
                # pod and let Metaflow handle the retries.
restart_policy="Never",
service_account_name=self._kwargs["service_account"],
# Terminate the container immediately on SIGTERM
termination_grace_period_seconds=0,
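                # Apply user-specified tolerations so pods can be scheduled
                # onto matching tainted nodes.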
tolerations=[
client.V1Toleration(**toleration)
for toleration in self._kwargs.get("tolerations") or []
],
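                # Declare the volumes backing the container mounts above.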
volumes=(
[
client.V1Volume(
name="tmpfs-ephemeral-volume",
empty_dir=client.V1EmptyDirVolumeSource(
medium="Memory",
                            # Append the Mi unit explicitly; Metaflow sizes
                            # are in MiB, while a bare number is read as
                            # bytes by Kubernetes.
size_limit="{}Mi".format(tmpfs_size),
),
)
]
if tmpfs_enabled
else []
)
+ (
[
client.V1Volume(
name="dhsm",
empty_dir=client.V1EmptyDirVolumeSource(
medium="Memory",
size_limit="{}Mi".format(shared_memory),
),
)
]
if shared_memory
else []
)
+ (
[
client.V1Volume(
name=claim,
persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
claim_name=claim
),
)
for claim in self._kwargs["persistent_volume_claims"].keys()
]
if self._kwargs["persistent_volume_claims"] is not None
else []
),
),
),
)
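
For context, here is a minimal sketch of how a spec like the one above could be wrapped into a complete Job and submitted with the official kubernetes Python client. It is illustrative only: `make_job` and the demo spec are hypothetical stand-ins, not Metaflow's actual submission path, which runs through its own client wrapper.

from kubernetes import client, config


def make_job(job_spec, name, namespace):
    # Wrap a V1JobSpec (e.g. the result of create_job_spec()) in a V1Job.
    return client.V1Job(
        api_version="batch/v1",
        kind="Job",
        metadata=client.V1ObjectMeta(name=name, namespace=namespace),
        spec=job_spec,
    )


if __name__ == "__main__":
    config.load_kube_config()  # use load_incluster_config() inside a cluster
    # A trivial spec standing in for create_job_spec()'s return value.
    demo_spec = client.V1JobSpec(
        template=client.V1PodTemplateSpec(
            spec=client.V1PodSpec(
                restart_policy="Never",
                containers=[
                    client.V1Container(
                        name="main", image="python:3.11", command=["true"]
                    )
                ],
            )
        )
    )
    job = make_job(demo_spec, name="demo-job", namespace="default")
    client.BatchV1Api().create_namespaced_job(namespace="default", body=job)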