def create_job_spec()

in metaflow/plugins/kubernetes/kubernetes_job.py
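
Builds and returns a Kubernetes V1JobSpec for a single Metaflow task: a one-completion batch job whose pod runs a single container, with resources, environment variables, secrets, and volumes (tmpfs, /dev/shm, and persistent volume claims) derived from self._kwargs.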


    def create_job_spec(self):
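        # Kubernetes Python client (kubernetes.client), via Metaflow's wrapper.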
        client = self._client.get()

        # tmpfs variables
        use_tmpfs = self._kwargs["use_tmpfs"]
        tmpfs_size = self._kwargs["tmpfs_size"]
        # tmpfs is enabled either explicitly (use_tmpfs) or implicitly by a size.
        tmpfs_enabled = bool(use_tmpfs or tmpfs_size)
        # Shared memory (/dev/shm) size in MiB, if requested.
        shared_memory = (
            int(self._kwargs["shared_memory"])
            if self._kwargs["shared_memory"]
            else None
        )
        qos_requests, qos_limits = qos_requests_and_limits(
            self._kwargs["qos"],
            self._kwargs["cpu"],
            self._kwargs["memory"],
            self._kwargs["disk"],
        )

        security_context = self._kwargs.get("security_context", {})
        _security_context = {}
        if security_context:
            _security_context = {
                "security_context": client.V1SecurityContext(**security_context)
            }

        return client.V1JobSpec(
            # Retries are handled by Metaflow when it is responsible for
            # executing the flow; the responsibility moves to Kubernetes
            # when Argo Workflows drives the execution.
            backoff_limit=self._kwargs.get("retries", 0),
            completions=self._kwargs.get("completions", 1),
            # Remove job after a week. TODO: Make this configurable.
            ttl_seconds_after_finished=7 * 24 * 60 * 60,
            template=client.V1PodTemplateSpec(
                metadata=client.V1ObjectMeta(
                    annotations=self._kwargs.get("annotations", {}),
                    labels=self._kwargs.get("labels", {}),
                    namespace=self._kwargs["namespace"],
                ),
                spec=client.V1PodSpec(
                    # Timeout is set on the pod and not the job (important!)
                    active_deadline_seconds=self._kwargs["timeout_in_seconds"],
                    # TODO (savin): Enable affinities for GPU scheduling.
                    # affinity=?,
                    containers=[
                        client.V1Container(
                            command=self._kwargs["command"],
                            termination_message_policy="FallbackToLogsOnError",
                            ports=(
                                []
                                if self._kwargs["port"] is None
                                else [
                                    client.V1ContainerPort(
                                        container_port=int(self._kwargs["port"])
                                    )
                                ]
                            ),
                            env=[
                                client.V1EnvVar(name=k, value=str(v))
                                for k, v in self._kwargs.get(
                                    "environment_variables", {}
                                ).items()
                            ]
                            # And some downward API magic. Add (key, value)
                            # pairs below to make pod metadata available
                            # within Kubernetes container.
                            + [
                                client.V1EnvVar(
                                    name=k,
                                    value_from=client.V1EnvVarSource(
                                        field_ref=client.V1ObjectFieldSelector(
                                            field_path=str(v)
                                        )
                                    ),
                                )
                                for k, v in {
                                    "METAFLOW_KUBERNETES_NAMESPACE": "metadata.namespace",
                                    "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
                                    "METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
                                    "METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
                                    "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
                                    "METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
                                }.items()
                            ]
                            + [
                                client.V1EnvVar(name=k, value=str(v))
                                for k, v in inject_tracing_vars({}).items()
                            ],
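                            # Expose each Kubernetes secret to the container
                            # as environment variables.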
                            env_from=[
                                client.V1EnvFromSource(
                                    secret_ref=client.V1SecretEnvSource(
                                        name=str(k),
                                        # optional=True
                                    )
                                )
                                for k in list(self._kwargs.get("secrets", []))
                                + KUBERNETES_SECRETS.split(",")
                                if k
                            ],
                            image=self._kwargs["image"],
                            image_pull_policy=self._kwargs["image_pull_policy"],
                            name=self._kwargs["step_name"].replace("_", "-"),
                            resources=client.V1ResourceRequirements(
                                requests=qos_requests,
                                limits={
                                    **qos_limits,
                                    # Don't set GPU limits if gpu isn't specified.
                                    **(
                                        {
                                            (
                                                "%s.com/gpu"
                                                % self._kwargs["gpu_vendor"]
                                            ).lower(): str(self._kwargs["gpu"])
                                        }
                                        if self._kwargs["gpu"] is not None
                                        else {}
                                    ),
                                },
                            ),
                            volume_mounts=(
                                [
                                    client.V1VolumeMount(
                                        mount_path=self._kwargs.get("tmpfs_path"),
                                        name="tmpfs-ephemeral-volume",
                                    )
                                ]
                                if tmpfs_enabled
                                else []
                            )
                            + (
                                [
                                    client.V1VolumeMount(
                                        mount_path="/dev/shm", name="dhsm"
                                    )
                                ]
                                if shared_memory
                                else []
                            )
                            + (
                                [
                                    client.V1VolumeMount(mount_path=path, name=claim)
                                    for claim, path in self._kwargs[
                                        "persistent_volume_claims"
                                    ].items()
                                ]
                                if self._kwargs["persistent_volume_claims"] is not None
                                else []
                            ),
                            **_security_context,
                        )
                    ],
                    node_selector=self._kwargs.get("node_selector"),
                    # TODO (savin): Support image_pull_secrets
                    # image_pull_secrets=?,
                    # TODO (savin): Support preemption policies
                    # preemption_policy=?,
                    #
                    # A Container in a Pod may fail for a number of
                    # reasons, such as because the process in it exited
                    # with a non-zero exit code, or the Container was
                    # killed due to OOM etc. If this happens, fail the pod
                    # and let Metaflow handle the retries.
                    restart_policy="Never",
                    service_account_name=self._kwargs["service_account"],
                    # Terminate the container immediately on SIGTERM
                    termination_grace_period_seconds=0,
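                    # Apply any node tolerations requested for this step.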
                    tolerations=[
                        client.V1Toleration(**toleration)
                        for toleration in self._kwargs.get("tolerations") or []
                    ],
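                    # Volumes backing the mounts declared above.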
                    volumes=(
                        [
                            client.V1Volume(
                                name="tmpfs-ephemeral-volume",
                                empty_dir=client.V1EmptyDirVolumeSource(
                                    medium="Memory",
                                    # tmpfs_size is in MiB; append the unit, since
                                    # Kubernetes reads a bare number as bytes.
                                    size_limit="{}Mi".format(tmpfs_size),
                                ),
                            )
                        ]
                        if tmpfs_enabled
                        else []
                    )
                    + (
                        [
                            client.V1Volume(
                                name="dhsm",
                                empty_dir=client.V1EmptyDirVolumeSource(
                                    medium="Memory",
                                    size_limit="{}Mi".format(shared_memory),
                                ),
                            )
                        ]
                        if shared_memory
                        else []
                    )
                    + (
                        [
                            client.V1Volume(
                                name=claim,
                                persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
                                    claim_name=claim
                                ),
                            )
                            for claim in self._kwargs["persistent_volume_claims"].keys()
                        ]
                        if self._kwargs["persistent_volume_claims"] is not None
                        else []
                    ),
                ),
            ),
        )
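
For reference, below is a minimal, self-contained sketch of how a V1JobSpec like the one returned above could be wrapped in a V1Job and submitted with the official Kubernetes Python client. The job name, namespace, image, and command are illustrative assumptions; this is not Metaflow's actual submission path.

    # Hypothetical usage sketch; names below are assumptions for illustration.
    from kubernetes import client, config

    config.load_kube_config()  # or config.load_incluster_config() inside a pod

    job = client.V1Job(
        api_version="batch/v1",
        kind="Job",
        metadata=client.V1ObjectMeta(name="example-job", namespace="default"),
        spec=client.V1JobSpec(
            backoff_limit=0,  # let the caller, not Kubernetes, handle retries
            completions=1,
            ttl_seconds_after_finished=7 * 24 * 60 * 60,  # clean up after a week
            template=client.V1PodTemplateSpec(
                spec=client.V1PodSpec(
                    containers=[
                        client.V1Container(
                            name="example",
                            image="python:3.11",
                            command=["python", "-c", "print('hello')"],
                        )
                    ],
                    restart_policy="Never",
                )
            ),
        ),
    )

    client.BatchV1Api().create_namespaced_job(namespace="default", body=job)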