xlml/utils/gke.py (146 lines of code) (raw):

"""Utilities for GKE."""

import base64
import concurrent.futures
import copy
import datetime
import logging
import tempfile
import time
from typing import Any, Dict, Optional

from airflow.decorators import task, task_group
import google.auth
import google.auth.transport.requests
from google.cloud import container_v1
import kubernetes
from xlml.apis import gcp_config


class PodsNotReadyError(Exception):
  """Raised when a job's pods cannot all be found after re-authenticating.

  `stream_logs` raises this so that Airflow retries the task.
  """


def get_authenticated_client(
    project_name: str, region: str, cluster_name: str
) -> kubernetes.client.ApiClient:
  """Build a Kubernetes API client authenticated against a GKE cluster.

  Args:
    project_name: GCP project that owns the cluster.
    region: Location segment of the cluster resource path. Callers in this
      module pass a zone (`gcp.zone`), so this may be a zone or a region.
    cluster_name: Name of the GKE cluster.

  Returns:
    An `ApiClient` pointed at the cluster endpoint, authenticated with a bearer
    token obtained from `google.auth.default()` credentials.
  """
  container_client = container_v1.ClusterManagerClient()
  cluster_path = (
      f'projects/{project_name}/locations/{region}/clusters/{cluster_name}'
  )
  response = container_client.get_cluster(name=cluster_path)

  # Refresh explicitly so that `creds.token` is populated before it is copied
  # into the client configuration below.
  creds, _ = google.auth.default()
  auth_req = google.auth.transport.requests.Request()
  creds.refresh(auth_req)

  configuration = kubernetes.client.Configuration()
  configuration.host = f'https://{response.endpoint}'
  ca_cert_content = base64.b64decode(
      response.master_auth.cluster_ca_certificate
  )
  # The configuration takes a *path* to the CA certificate, so the decoded
  # certificate is written to a file that must outlive this function
  # (hence delete=False).
  with tempfile.NamedTemporaryFile(delete=False) as ca_cert:
    ca_cert.write(ca_cert_content)
    configuration.ssl_ca_cert = ca_cert.name
  configuration.api_key_prefix['authorization'] = 'Bearer'
  configuration.api_key['authorization'] = creds.token

  return kubernetes.client.ApiClient(configuration)


@task_group
def run_job(
    body: Dict[str, Any],
    gcp: gcp_config.GCPConfig,
    cluster_name: str,
    job_create_timeout: datetime.timedelta,
    gcs_location: str = '',
):
  """Run a batch job directly on a GKE cluster.

  Creates the Job, waits until all of its pods exist, then streams every pod's
  logs and fails if any pod ends with a non-zero or unknown exit code.

  Args:
    body: Dict that defines a Kubernetes `Job`. The `GCS_OUTPUT` environment
      variable is added to its first container. The dict itself is not
      modified.
    gcp: GCP config with the project name and zone of the GKE cluster.
    cluster_name: Name of the GCP cluster.
    job_create_timeout: Amount of time to wait for all pods to become active.
    gcs_location: Value exported to the job's first container as the
      `GCS_OUTPUT` environment variable.
  """

  def _expected_pod_count() -> int:
    # Kubernetes defaults `spec.parallelism` to 1 when it is omitted.
    return body['spec'].get('parallelism', 1)

  @task
  def deploy_job(gcs_location):
    # Work on a copy so that the caller's `body` is never mutated. Otherwise a
    # retry or re-run within the same process would append GCS_OUTPUT again.
    job_body = copy.deepcopy(body)
    container = job_body['spec']['template']['spec']['containers'][0]
    container.setdefault('env', []).append(
        {'name': 'GCS_OUTPUT', 'value': gcs_location}
    )

    client = get_authenticated_client(gcp.project_name, gcp.zone, cluster_name)
    jobs_client = kubernetes.client.BatchV1Api(client)
    resp = jobs_client.create_namespaced_job(
        namespace='default', body=job_body
    )
    logging.info(f'response: {resp}')
    return resp.metadata.name

  @task.sensor(
      poke_interval=60,
      timeout=job_create_timeout.total_seconds(),
      mode='reschedule',
  )
  def wait_all_pods_ready(name: str):
    """Sensor that succeeds once the job has its expected number of pods."""
    client = get_authenticated_client(gcp.project_name, gcp.zone, cluster_name)

    batch_api = kubernetes.client.BatchV1Api(client)
    job = batch_api.read_namespaced_job(namespace='default', name=name)

    # TODO(wcromar): Handle other conditions (e.g. unschedulability)
    logging.info(f'Job status: {job.status}')
    if job.status.failed:
      raise RuntimeError(f'Job has {job.status.failed} failed pods.')

    core_api = kubernetes.client.CoreV1Api(client)
    pod_label_selector = f'batch.kubernetes.io/job-name={name}'
    pods = core_api.list_namespaced_pod(
        namespace='default', label_selector=pod_label_selector
    )
    if len(pods.items) != _expected_pod_count():
      logging.info('Waiting for all pods to be created...')
      return False

    return True

  @task(retries=6)
  def stream_logs(name: str):
    """Stream logs from every pod of job `name` and check their exit codes."""

    def _watch_pod(pod_name, namespace) -> Optional[int]:
      """Follow one pod's logs until the log stream ends.

      Returns:
        The first container's exit code if it has terminated, otherwise None.
      """
      logs_watcher = kubernetes.watch.Watch()

      logging.info(f'Waiting for pod {pod_name} to start...')
      pod_watcher = kubernetes.watch.Watch()
      for event in pod_watcher.stream(
          core_api.list_namespaced_pod,
          namespace,
          field_selector=f'metadata.name={pod_name}',
      ):
        status = event['object'].status
        logging.info(
            f'Pod {event["object"].metadata.name} status: {status.phase}'
        )
        if status.phase != 'Pending':
          break

      logging.info(f'Streaming pod logs for {pod_name}...')
      for line in logs_watcher.stream(
          core_api.read_namespaced_pod_log,
          pod_name,
          namespace,
          _request_timeout=3600,
      ):
        logging.info(f'[{pod_name}] {line}')

      logging.warning(f'Lost logs stream for {pod_name}.')

      # Read the pod from its own namespace, not a hard-coded one.
      pod = core_api.read_namespaced_pod(namespace=namespace, name=pod_name)
      if pod.status.container_statuses:
        container_status = pod.status.container_statuses[0]
        if container_status.state.terminated:
          exit_code = container_status.state.terminated.exit_code
          if exit_code:
            logging.error(f'Pod {pod_name} had non-zero exit code {exit_code}')

          return exit_code

      logging.warning(f'Unknown status for pod {pod_name}')
      return None

    # Re-authenticate on every attempt. This task is retried when streaming
    # fails, which can happen when the job runs long enough for the
    # credentials to expire.
    client = get_authenticated_client(gcp.project_name, gcp.zone, cluster_name)
    core_api = kubernetes.client.CoreV1Api(client)

    # TODO(piz): Use time.sleep may not be a good solution here. However, I
    # expect resources are all ready in wait_all_pods_ready stage. This just in
    # case authentication takes time. Check with Will for better solutions.
    # The sleep happens before the pods are listed so that the count check
    # below reflects the cluster state after the wait.
    time.sleep(30)

    pod_label_selector = f'batch.kubernetes.io/job-name={name}'
    pods = core_api.list_namespaced_pod(
        namespace='default', label_selector=pod_label_selector
    )
    if len(pods.items) != _expected_pod_count():
      logging.info('Waiting for all pods to be re-connected...')
      raise PodsNotReadyError('pods are not ready after refreshing credential.')

    with concurrent.futures.ThreadPoolExecutor() as executor:
      futures = [
          executor.submit(
              _watch_pod, pod.metadata.name, pod.metadata.namespace
          )
          for pod in pods.items
      ]

      # Wait for pods to complete, and exit with the first non-zero exit code.
      for f in concurrent.futures.as_completed(futures):
        try:
          # TODO(piz/wcromar): it looks like there is a delay between
          # as_completed and update of f.result(). exit_code can be None even
          # task is complete.
          exit_code = f.result()
        except kubernetes.client.ApiException as e:
          logging.error('Kubernetes error. Retrying...', exc_info=e)
          exit_code = None

        # Retry if status is unknown
        if exit_code is None:
          raise RuntimeError('unknown exit code')
        if exit_code:
          raise RuntimeError('Non-zero exit code')

  name = deploy_job(gcs_location)
  wait_all_pods_ready(name) >> stream_logs(name)


def zone_to_region(zone: str) -> str:
  """Convert a GCP zone such as `us-central1-a` to its region `us-central1`.

  Only the first two dash-separated components of `zone` are kept.
  """
  zone_terms = zone.split('-')
  return zone_terms[0] + '-' + zone_terms[1]