xlml/utils/xpk.py

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities to run workloads with xpk (https://github.com/AI-Hypercomputer/xpk)."""

import os
import tempfile
import uuid

from absl import logging
from airflow.decorators import task
from airflow.exceptions import AirflowFailException
from airflow.hooks.subprocess import SubprocessHook
from kubernetes import client as k8s_client
from xlml.apis import metric_config
from xlml.utils import gke
from dags.common.vm_resource import GpuVersion


# b/411426745 - Pin the branch to v0.4.1 until the dependency issue is resolved.
MAIN_BRANCH = "v0.4.1"

# Cloud Logging query for the workload's container logs; duration = past 7 days.
LOGGING_URL_FORMAT = (
    "https://pantheon.corp.google.com/logs/query;"
    + "query=resource.type%3D%22k8s_container%22%0A"
    + "resource.labels.project_id%3D%22{project}%22%0A"
    + "resource.labels.location%3D%22{region}%22%0A"
    + "resource.labels.cluster_name%3D%22{cluster}%22%0A"
    + "resource.labels.namespace_name%3D%22default%22%0A"
    + "labels.k8s-pod%2Fjobset_sigs_k8s_io%2F"
    + "jobset-name%3D%22{workload_id}%22%20severity%3E%3DDEFAULT;"
    + "storageScope=project;duration=P7D?e=13803378&"
    + "mods=allow_workbench_image_override&project={project}"
)


def get_xpk_setup_cmd(tmpdir, branch: str = MAIN_BRANCH):
  """Return the shell commands that clone xpk and install its dependencies."""
  clone_branch = (
      f"git clone --branch {branch} https://github.com/AI-Hypercomputer/xpk"
      f" {tmpdir}/xpk"
  )

  cmds = [
      "set -xue",
      clone_branch,
      "pip install ruamel.yaml docker",
  ]
  return cmds


def is_valid_gpu_version(accelerator_type: str) -> bool:
  """Check whether the accelerator type is a known GPU version."""
  return accelerator_type in [member.value for member in GpuVersion]


@task
def generate_workload_id(benchmark_id: str) -> str:
  """Generate a valid workload ID."""
  import re

  short_id = str(uuid.uuid4())[:8]
  # Strip characters other than letters, digits, and hyphens, then truncate so
  # the prefix plus the 8-character suffix stays within 40 characters.
  short_benchmark = re.sub(r"[^a-zA-Z0-9-]+", "", benchmark_id)[:32]
  return f"{short_benchmark}{short_id}"
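
# Illustrative example (hypothetical input and UUID fragment, not part of the
# original module): generate_workload_id("bert-mlperf-4.0") strips the invalid
# "." to get "bert-mlperf-40", truncates the prefix to at most 32 characters,
# and appends an 8-character UUID fragment, yielding e.g.
# "bert-mlperf-401a2b3c4d".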

@task
def run_workload(
    task_id: str,
    cluster_project: str,
    zone: str,
    cluster_name: str,
    benchmark_id: str,
    workload_id: str,
    gcs_path: str,
    docker_image: str,
    accelerator_type: str,
    run_cmds: str,
    num_slices: int = 1,
    use_vertex_tensorboard: bool = False,
    use_pathways: bool = False,
    ramdisk_directory: str = "",  # Directory for enabling emergency checkpointing
    mtc_enabled: bool = False,  # Enables MTC phase-2 drivers
    xpk_branch: str = MAIN_BRANCH,
):
  """Run a workload through the xpk tool."""
  with tempfile.TemporaryDirectory() as tmpdir:
    if accelerator_type in [
        GpuVersion.XPK_H100.value,
        GpuVersion.XPK_H100_MEGA.value,
    ]:
      multi_keyword = "num-nodes"
    else:
      multi_keyword = "num-slices"

    create_field = "create-pathways" if use_pathways else "create"
    type_field = "tpu-type" if use_pathways else "device-type"

    workload_create_cmd = (
        f"python {tmpdir}/xpk/xpk.py workload {create_field}"
        f" --cluster={cluster_name} --workload={workload_id}"
        f" --command='{run_cmds}' --{type_field}={accelerator_type}"
        f" --{multi_keyword}={num_slices} --docker-image={docker_image}"
        f" --project={cluster_project} --zone={zone}"
        f" --env {metric_config.SshEnvVars.GCS_OUTPUT.name}={gcs_path}"
        " --restart-on-user-code-failure"
    )
    if ramdisk_directory:
      workload_create_cmd += f" --ramdisk-directory={ramdisk_directory}"
    if mtc_enabled:
      workload_create_cmd += " --mtc-enabled"

    # If running on a recognized GPU type while the xpk branch is still the
    # default, switch the branch to "v0.4.1".
    if is_valid_gpu_version(accelerator_type) and xpk_branch == MAIN_BRANCH:
      xpk_branch = "v0.4.1"

    cmds = get_xpk_setup_cmd(tmpdir, xpk_branch)
    if accelerator_type == GpuVersion.XPK_H100_MEGA.value:
      workload_create_cmd += " --scheduler=gke.io/topology-aware-auto"
    if use_vertex_tensorboard:
      workload_create_cmd += " --use-vertex-tensorboard"
      vertex_ai_dependency = (
          "pip install -U google-cloud-aiplatform cloud-accelerator-diagnostics"
      )
      cmds.append(vertex_ai_dependency)
    cmds.append(workload_create_cmd)

    hook = SubprocessHook()
    result = hook.run_command(
        ["bash", "-c", ";".join(cmds)],
        env={**os.environ, "KUBECONFIG": os.path.join(tmpdir, "xpk.conf")},
    )
    assert (
        result.exit_code == 0
    ), f"XPK command failed with code {result.exit_code}"


def _get_core_api_client(
    project_id: str, region: str, cluster_name: str
) -> k8s_client.CoreV1Api:
  """Create a core API client for the given cluster."""
  client = gke.get_authenticated_client(project_id, region, cluster_name)

  # Initialize the client.
  core_api = k8s_client.CoreV1Api(client)
  logging.info("Successfully initialized k8s client from cluster response.")
  return core_api


def _list_workload_pods(
    core_api: k8s_client.CoreV1Api, workload_id: str
) -> k8s_client.V1PodList:
  """List all pods for the given workload."""
  logging.info(f"Getting pods for workload_id: {workload_id}")
  pods = core_api.list_namespaced_pod(
      label_selector=f"jobset.sigs.k8s.io/jobset-name={workload_id}",
      namespace="default",
  )
  return pods


def _get_batch_api_client(
    project_id: str, region: str, cluster_name: str
) -> k8s_client.BatchV1Api:
  """Create a batch API client for the given cluster."""
  client = gke.get_authenticated_client(project_id, region, cluster_name)

  # Initialize the client.
  batch_api = k8s_client.BatchV1Api(client)
  logging.info(
      "Successfully initialized k8s batch API client from cluster response."
  )
  return batch_api
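
# Illustrative only (hypothetical argument values; assumes
# metric_config.SshEnvVars.GCS_OUTPUT.name resolves to "GCS_OUTPUT"): a
# non-Pathways TPU run composes an xpk invocation roughly like
#
#   python <tmpdir>/xpk/xpk.py workload create \
#       --cluster=my-cluster --workload=bert-mlperf-401a2b3c4d \
#       --command='bash run.sh' --device-type=v5litepod-256 \
#       --num-slices=2 --docker-image=gcr.io/my-project/my-image \
#       --project=my-project --zone=us-central2-b \
#       --env GCS_OUTPUT=gs://my-bucket/output --restart-on-user-code-failure
#
# which is executed in a bash subprocess after the setup commands returned by
# get_xpk_setup_cmd.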

def _get_workload_job(
    batch_api: k8s_client.BatchV1Api, workload_id: str
) -> k8s_client.V1Job:
  """Get the job for a given workload."""
  logging.info(f"Getting job for workload_id: {workload_id}")
  jobs = batch_api.list_namespaced_job(
      label_selector=f"jobset.sigs.k8s.io/jobset-name={workload_id}",
      namespace="default",
  )
  if len(jobs.items) == 0:
    logging.info(f"Found no jobs for workload_id: {workload_id}")
    return None
  if len(jobs.items) > 1:
    logging.info(f"Got more than one job for workload_id: {workload_id}")
    for i, job in enumerate(jobs.items):
      logging.info(f"Job {i=}")
      logging.info(f"{job}")
  return jobs.items[0]


@task.sensor(poke_interval=60, timeout=600, mode="reschedule")
def wait_for_workload_start(
    workload_id: str, project_id: str, region: str, cluster_name: str
) -> bool:
  """Check if the workload has started."""
  core_api = _get_core_api_client(project_id, region, cluster_name)
  pods = _list_workload_pods(core_api, workload_id)
  logging.info(f"Found {len(pods.items)} pods for workload {workload_id}")
  return len(pods.items) > 0


@task.sensor(poke_interval=60, timeout=600, mode="reschedule")
def wait_for_workload_completion(
    workload_id: str, project_id: str, region: str, cluster_name: str
) -> bool:
  """Check the workload status."""
  core_api = _get_core_api_client(project_id, region, cluster_name)
  pods = _list_workload_pods(core_api, workload_id)
  if not pods.items:
    logging.info(f"No pods found for workload selector: {workload_id}.")

    # Pathways jobs delete all pods on failure, so we must also check whether
    # the job itself has completed.
    batch_api = _get_batch_api_client(project_id, region, cluster_name)
    job = _get_workload_job(batch_api, workload_id)
    if job is None:
      logging.info(
          f"No pods or jobs were found for workload selector: {workload_id}"
      )
      return False

    # job.status.conditions may be None before any condition is recorded.
    conditions = job.status.conditions or []
    if any(condition.type == "Failed" for condition in conditions):
      # Don't keep retrying if the job has failed.
      raise AirflowFailException('Job has condition type: "Failed"')

    if any(condition.type == "Complete" for condition in conditions):
      logging.info(
          "No pods found but job is complete for workload selector:"
          f" {workload_id}"
      )
      return True

    return False

  if any(pod.status.phase in ["Pending", "Running"] for pod in pods.items):
    logging.info("At least one pod has yet to complete.")
    return False

  try:
    for pod in pods.items:
      if pod.status.phase == "Failed":
        # Don't keep retrying if the pod has failed.
        raise AirflowFailException(f"Bad pod phase: {pod.status.phase}")
      elif pod.status.phase in ["Unknown"]:
        raise RuntimeError(f"Bad pod phase: {pod.status.phase}")
  finally:
    # TODO(jonbolin): log printing for GPUs, which have multiple containers
    if len(pod.spec.containers) == 1:
      # Print the logs of the last pod checked - either the first failed pod
      # or the last successful one.
      logs = core_api.read_namespaced_pod_log(
          name=pod.metadata.name, namespace=pod.metadata.namespace
      )
      logging.info(f"Logs for pod {pod.metadata.name}:")
      for line in logs.split("\n"):
        logging.info(line)
    url = LOGGING_URL_FORMAT.format(
        project=project_id,
        region=region,
        cluster=cluster_name,
        workload_id=workload_id,
    )
    logging.info(f"Link to logs: {url}")

  logging.info("All pods have phase 'Succeeded'.")
  return True
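
# Sensor semantics, for reference (standard Airflow behavior, not specific to
# this module): returning False reschedules the next poke (every 60 seconds,
# up to the 600-second timeout), returning True marks the task successful, and
# raising AirflowFailException fails the task immediately without further
# pokes; mode="reschedule" releases the worker slot between pokes.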

@task(trigger_rule="all_done")
def clean_up_workload(
    workload_id: str,
    project_id: str,
    zone: str,
    cluster_name: str,
    xpk_branch: str = MAIN_BRANCH,
) -> bool:
  """Delete workload."""
  with tempfile.TemporaryDirectory() as tmpdir:
    workload_delete_cmd = (
        f"python {tmpdir}/xpk/xpk.py workload delete"
        f" --cluster={cluster_name} --workload={workload_id}"
        f" --project={project_id} --zone={zone}"
    )
    cmds = get_xpk_setup_cmd(tmpdir, xpk_branch)
    cmds.append(workload_delete_cmd)

    hook = SubprocessHook()
    result = hook.run_command(
        ["bash", "-c", ";".join(cmds)],
        env={**os.environ, "KUBECONFIG": os.path.join(tmpdir, "xpk.conf")},
    )
    assert (
        result.exit_code == 0
    ), f"XPK clean-up failed with code {result.exit_code}"
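
# A minimal usage sketch (hypothetical DAG wiring and argument values, not
# part of this module), assuming the standard Airflow TaskFlow API:
#
#   from airflow import DAG
#
#   with DAG(dag_id="xpk_example", schedule=None) as dag:
#     workload_id = generate_workload_id("my-benchmark")
#     run = run_workload(
#         task_id="run_workload",
#         cluster_project="my-project",
#         zone="us-central2-b",
#         cluster_name="my-cluster",
#         benchmark_id="my-benchmark",
#         workload_id=workload_id,
#         gcs_path="gs://my-bucket/output",
#         docker_image="gcr.io/my-project/my-image",
#         accelerator_type="v5litepod-256",
#         run_cmds="bash run.sh",
#     )
#     start = wait_for_workload_start(
#         workload_id, "my-project", "us-central2", "my-cluster"
#     )
#     done = wait_for_workload_completion(
#         workload_id, "my-project", "us-central2", "my-cluster"
#     )
#     clean = clean_up_workload(
#         workload_id, "my-project", "us-central2-b", "my-cluster"
#     )
#     run >> start >> done >> clean
#
# clean_up_workload uses trigger_rule="all_done", so the delete runs whether
# or not the upstream tasks succeed.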