cli/gke_check.py (384 lines of code) (raw):

# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """A GKE implementation of the healthscan check interface.""" import math import subprocess import sys import time from typing import Any import click from kubernetes import client from kubernetes import config import check import launch_helm class GkeCheck(check.Check): """A standard implementation of a healthscan check.""" def sigint_handler( self, signum: Any, frame: Any, ) -> None: """Handler for SIGINT signal. Args: signum: The signal number. frame: The current stack frame. """ print(f'Received {signum} signal on frame {frame}. Exiting...') # Perform any necessary cleanup actions here # For example: close file handlers, release resources, etc. click.echo( click.style( '\nCLEANING UP...', fg='red', bold=True, ) ) self.clean_up() # Stops proceeding anything the CLI is doing sys.exit(0) def __init__( self, name: str, description: str, machine_type: str, supported_machine_types: frozenset[str], nodes: list[str], results_labels: list[str] | None, launch_label: str | None, launch_label_value: str = 'true', run_only_on_available_nodes: bool = False, kubectl_core_api: client.CoreV1Api | None = None, timeout_sec: int = 15 * 60, dry_run: bool = False, container_name: str | None = None, ): """Initialize a check to run on a GKE cluster. Args: name: The name of the check. description: The description of the check. machine_type: The machine type of the cluster to run the check on. supported_machine_types: The machine types supported by the check. nodes: The nodes to run the check on. results_labels: The labels to use for the results. launch_label: The label to use for the launch. launch_label_value: The value to use for the launch label. run_only_on_available_nodes: Whether to run the check only on available nodes. kubectl_core_api: The kubectl core api to use for the check. timeout_sec: The timeout in seconds for the check. dry_run: Whether to run the check in dry run mode. container_name: The name of the main container used by the check. """ super().__init__( name=name, description=description, machine_type=machine_type, supported_machine_types=supported_machine_types, dry_run=dry_run, ) self.results_labels = results_labels self.nodes = nodes self.launch_label = launch_label self.launch_label_value = launch_label_value self.run_only_on_available_nodes = run_only_on_available_nodes self.timeout_sec = timeout_sec self.check_logs = None self.check_container_name = container_name self.hr_release_name: str = f'chs-hr-{self.name}-cli' # Generate a unique base name for the HC Helm release # Default to HC release w/ no special base name # Possible example: chs-hc-gpu-cli-12345678-1723456789 self.hc_release_name_base: str = f'chs-hc-{self.name}-cli' if not dry_run: # Used to interface with the GKE cluster if kubectl_core_api: self._v1 = kubectl_core_api else: config.load_kube_config() self._v1 = client.CoreV1Api() def _get_occupied_nodes(self) -> set[str]: """Gets all requested nodes that are currently occupied. Returns: A list of occupied nodes. Optionally, if nodes is provided, only the occupied nodes of those provided will be returned. """ occupied_nodes = set() try: pods = self._v1.list_pod_for_all_namespaces( watch=False, field_selector='status.phase=Running' ) for pod in pods.items: for container in pod.spec.containers: if container.resources.requests: requested_gpus = set( resource_name for resource_name, _ in container.resources.requests.items() if resource_name == 'nvidia.com/gpu' ) if requested_gpus: occupied_nodes.add(pod.spec.node_name) except client.rest.ApiException as e: click.echo( click.style( f'Failed to list nodes in cluster: {e}', fg='red', bold=True ) ) return ( occupied_nodes if not self.nodes else set(self.nodes).intersection(occupied_nodes) ) def _get_nodes_with_machine_type(self) -> list[str]: """Returns the names of all nodes with the given machine type.""" return [ node.metadata.name for node in self._v1.list_node( label_selector=( f'node.kubernetes.io/instance-type={self.machine_type}' ) ).items ] def _has_machine_type_on_cluster(self) -> bool: """Returns all nodes with the given machine type.""" return bool(len(self._get_nodes_with_machine_type())) def set_up(self): """Set up for the check on a GKE cluster.""" if self.dry_run: click.echo( click.style( 'Dry run mode enabled. Skipping set_up.', fg='red', bold=True, ) ) return if not self._has_machine_type_on_cluster(): click.echo( click.style( f'Active cluster does not have machine type {self.machine_type}.', fg='red', bold=True, ) ) raise click.Abort() occupied_nodes = self._get_occupied_nodes() if occupied_nodes and not self.run_only_on_available_nodes: click.echo( click.style( f'The following nodes are occupied: {occupied_nodes}. Please free' ' up these nodes before running healthscan.\n' ' Alternatively, you can run again with' ' --run_only_on_available_nodes to skip these nodes.', fg='red', bold=True, ) ) raise click.Abort() elif self.run_only_on_available_nodes and not self.nodes: click.echo( click.style( 'WARNING: Running only on available nodes is not recommended.\n' 'The following nodes are occupied and will be skipped: ' f'{occupied_nodes}', fg='red', bold=True, ) ) self.nodes = [ node for node in self._get_nodes_with_machine_type() if node not in occupied_nodes ] elif self.run_only_on_available_nodes: click.echo( click.style( 'WARNING: Running only on available nodes is not recommended.\n' 'The following nodes are occupied and will be skipped: ' f'{occupied_nodes}', fg='red', bold=True, ) ) self.nodes = [node for node in self.nodes if node not in occupied_nodes] launch_helm.setup_k8s_cluster( launch_label=self.launch_label, launch_label_value=self.launch_label_value, results_labels=self.results_labels, nodes=self.nodes, ) def _get_helm_releases( self, release_name_base: str | None, ) -> list[str]: """Get all Helm releases with the given release name base. Args: release_name_base: The base name of the Helm release to filter by. Returns: A list of Helm releases with the given release name base. """ # Note this will use the default helm ls limit of 256 helm_ls_command = [ 'helm', 'ls', '-a', '--no-headers', '--filter', release_name_base, ] try: helm_ls_output = subprocess.run( helm_ls_command, text=True, check=True, capture_output=True, ) # Keep only the release name (at the beginning of each line) helm_releases = [ release_name.split('\t')[0].strip() # Separated w/ tabs for release_name in helm_ls_output.stdout.strip().split('\n') if release_name # Catch the case where the release name is empty ] # Can happen when a non-zero exit code is returned except subprocess.CalledProcessError as e: click.echo( click.style( text=f'Failed to get Helm releases:\n{e}', fg='red', bold=True, ), ) helm_releases = [] # Catch if helm is not installed except FileNotFoundError as e: click.echo( click.style( text=( 'Failed to get Helm releases (`helm` likely not installed):\n' f'{e}' ), fg='red', bold=True, ), ) helm_releases = [] return helm_releases def clean_up(self) -> None: """Clean up after the check on a GKE cluster.""" if self.dry_run: click.echo( click.style( 'Dry run mode enabled. Skipping clean_up.', fg='red', bold=True, ) ) return # Attempt to clean up all HC Helm releases not already uninstalled helm_releases = self._get_helm_releases(self.hc_release_name_base) # Iterate over each release and uninstall it for release_name in helm_releases: helm_uninstall_command = [ 'helm', 'uninstall', release_name, ] click.echo(f'Uninstalling "{release_name}"') uninstall_result = subprocess.run( helm_uninstall_command, text=True, check=False, capture_output=True, ) # Check if the overall operation was successful if uninstall_result.returncode == 0: click.echo(f'Release "{release_name}" uninstalled successfully.') else: click.echo(f'Release "{release_name}" failed to uninstall.') click.echo(f'Uninstall result: {uninstall_result.stdout.strip()}') # Other processes to clean up like HR Helm release, labels, etc. launch_helm.cleanup_k8s_cluster( hr_release_name=self.hr_release_name, launch_label=self.launch_label, nodes=self.nodes, ) return def _get_values_file(self) -> str: """Get the values file for the check.""" base_path = 'deploy/helm/health_runner/' match self.machine_type: case 'a3-highgpu-8g': return base_path + 'a3high.yaml' case 'a3-megagpu-8g': # Use the default values for A3 Mega return base_path + 'values.yaml' case 'a3-ultragpu-8g': return base_path + 'a3ultra.yaml' case 'a4-highgpu-8g': return base_path + 'a4.yaml' case _: raise ValueError(f'Unsupported machine type: {self.machine_type}') def _get_helm_env_vars(self): # Only set N_NODES if nodes are specified (otherwise uses all nodes) additional_helm_env_vars: dict[str, str] | None = None # If nodes are not specified, then the health runner will use all nodes if self.nodes: n_nodes = len(self.nodes) additional_helm_env_vars: dict[str, str] = { f'health_checks.{self.name}_healthcheck.env.N_NODES': str(n_nodes), } return additional_helm_env_vars def _check(self, sleep_sec: int = 300, dry_run: bool = False) -> str: """Run the check on a GKE cluster.""" return launch_helm.deploy_health_runner( hr_release_name=self.hr_release_name, hc_type=self.name, wait=math.floor(sleep_sec / 60), values_file=self._get_values_file(), hc_release_name_base=self.hc_release_name_base, additional_helm_env_vars=self._get_helm_env_vars(), dry_run=dry_run, ) def _get_pod_phase( self, pod_name: str, namespace: str = 'default', ) -> str: """Get the phase of the pod.""" pod_phase = self._v1.read_namespaced_pod( name=pod_name, namespace=namespace, ).status.phase return pod_phase def _progress_bar_item_show( self, pod_name: str | None, ) -> str | None: """Get the progress bar item to show.""" if pod_name is None: return None else: pod_status = ( f'Health Runner Status: {self._get_pod_phase(pod_name=pod_name)}' ) return pod_status def get_check_pod(self) -> str | None: """Get the name of the canonical pod for the check.""" return None def update_check_logs(self, pod_name: str | None = None): """Updates the check logs for the given pod name. Args: pod_name: The name of the pod to get the logs for. """ if pod_name is None: self.check_logs = f'Fetching logs is not defined for check: {self.name}' return try: self.check_logs = self._v1.read_namespaced_pod_log( name=pod_name, namespace='default', container=self.check_container_name, ) except client.rest.ApiException: # This error is expected if the pod is not found, and should fail quietly. pass def run( self, timeout_sec: int | None = None, startup_sec: int = 30, ) -> str | None: """Run the check. Args: timeout_sec: The timeout in seconds for the check. startup_sec: The time in seconds to wait for the health runner to start. Returns: The name of the health runner pod. """ click.echo(f'Performing {self.name} check...') if not timeout_sec: timeout_sec = self.timeout_sec if self.dry_run: click.echo( click.style( f'Running {self.name} check in dry run mode...', fg='red', bold=True, ) ) dry_run_command = self._check(sleep_sec=timeout_sec, dry_run=self.dry_run) click.echo(f'Skipping running command: {dry_run_command}') return health_runner_pod_name = self._check( sleep_sec=timeout_sec, ) start_time = time.time() # CLI has extra startup time to allow health runner to complete & clean up click.echo('Waiting for Health Runner to start...') update_hr_startup_interval_sec = 10 running_statuses = ( 'Running', 'Succeeded', ) while self._get_pod_phase(health_runner_pod_name) not in running_statuses: # Health Runner has some time to start up but if not started by startup # time, then give a warning if time.time() - start_time >= startup_sec: click.echo( click.style( text=( f'Health Runner not started after {startup_sec} seconds.\n' 'Health Runner may not cleanly exit even if it starts now.' ), fg='red', bold=True, ) ) break else: time.sleep(update_hr_startup_interval_sec) # Resets the time for progress bar since given HR startup time above start_time = time.time() update_interval_sec = 5 check_pod = self.get_check_pod() with click.progressbar( label=f'{self.name} Health Runner', length=timeout_sec, item_show_func=self._progress_bar_item_show, ) as progress_bar: while ( self._get_pod_phase(health_runner_pod_name) not in ['Succeeded', 'Failed', 'Unknown'] and time.time() - start_time < timeout_sec ): progress_bar.update( n_steps=update_interval_sec, current_item=health_runner_pod_name, ) if not check_pod: check_pod = self.get_check_pod() self.update_check_logs(pod_name=check_pod) time.sleep(update_interval_sec) progress_bar.update( update_interval_sec, current_item=health_runner_pod_name, ) return health_runner_pod_name