xlml/utils/gpu.py (381 lines of code) (raw):

# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities to create, delete, and SSH with GPUs.""" from __future__ import annotations from absl import logging import airflow from airflow.decorators import task, task_group import datetime import fabric from google.cloud import compute_v1 import io import paramiko import re import time from typing import Dict, Iterable import uuid from xlml.apis import gcp_config, test_config from xlml.utils import ssh def get_image_from_family(project: str, family: str) -> compute_v1.Image: """ Retrieve the newest image that is part of a given family in a project. Args: project: project ID or project number of the Cloud project to get image. family: name of the image family you want to get image from. Returns: An Image object. """ image_client = compute_v1.ImagesClient() # List of public operating system (OS) images: # https://cloud.google.com/compute/docs/images/os-details newest_image = image_client.get_from_family(project=project, family=family) return newest_image def disk_from_image( disk_type: str, boot: bool, source_image: str, disk_size_gb: int = 100, auto_delete: bool = True, ) -> compute_v1.AttachedDisk: """ Create an AttachedDisk object to be used in VM instance creation. Uses an image as the source for the new disk. Args: disk_type: the type of disk you want to create. This value uses the following format: "zones/{zone}/diskTypes/(pd-standard|pd-ssd|pd-balanced|pd-extreme)". For example: "zones/us-west3-b/diskTypes/pd-ssd" disk_size_gb: size of the new disk in gigabytes boot: boolean flag indicating whether this disk should be used as a boot disk of an instance source_image: source image to use when creating this disk. You must have read access to this disk. This can be one of the publicly available images or an image from one of your projects. This value uses the following format: "projects/{project_name}/global/images/{image_name}" auto_delete: boolean flag indicating whether this disk should be deleted with the VM that uses it Returns: AttachedDisk object configured to be created using the specified image. """ boot_disk = compute_v1.AttachedDisk() initialize_params = compute_v1.AttachedDiskInitializeParams() initialize_params.source_image = source_image initialize_params.disk_size_gb = disk_size_gb initialize_params.disk_type = disk_type boot_disk.initialize_params = initialize_params # Remember to set auto_delete to True if you want the disk to be # deleted when you delete your VM instance. boot_disk.auto_delete = auto_delete boot_disk.boot = boot return boot_disk def local_ssd_disk(zone: str) -> compute_v1.AttachedDisk: """ Create an AttachedDisk object to be used in VM instance creation. The created disk contains no data and requires formatting before it can be used. Args: zone: The zone in which the local SSD drive will be attached. Returns: AttachedDisk object configured as a local SSD disk. """ disk = compute_v1.AttachedDisk(interface="NVME") disk.type_ = compute_v1.AttachedDisk.Type.SCRATCH.name initialize_params = compute_v1.AttachedDiskInitializeParams() initialize_params.disk_type = f"zones/{zone}/diskTypes/local-ssd" disk.initialize_params = initialize_params disk.auto_delete = True return disk def create_metadata(key_val: Dict[str, str]) -> compute_v1.Metadata: metadata = compute_v1.Metadata() metadata.items = [{"key": key, "value": val} for key, val in key_val.items()] return metadata @task def generate_gpu_name() -> str: # note: GPU vm name need to match regex # '(?:[a-z](?:[-a-z0-9]{0,61}[a-z0-9])?)', while TPU vm allows '_'. return # f'{base_gpu_name}-{str(uuid.uuid4())}'.replace('_', '-') # If we use the above base_gpu_name in the return, some potion of the can be # longer than 61 as in the regex. return f"gpu-{str(uuid.uuid4())}" @task def get_existing_resource( instance_name: str, ssh_keys: ssh.SshKeys, gcp: gcp_config.GCPConfig, ) -> airflow.XComArg: """Reach a resource node that is already created. Args: instance_name: name of the existing instance. ssh_keys: airflow.XComArg, gcp: GCP project/zone configuration. Returns: The ip address of the GPU VM. """ instance_client = compute_v1.InstancesClient() instance_request = compute_v1.GetInstanceRequest( instance=instance_name, project=gcp.project_name, zone=gcp.zone, ) instance = instance_client.get(request=instance_request) logging.info( f"Resource retrieve status: {instance.status}, {instance.status_message}" ) ip_address = instance.network_interfaces[0].network_i_p metadata = instance.metadata items = metadata.items or [] ssh_key_exist = False for item in metadata.items: if item.key == "ssh-keys": ssh_key_exist = True item.value = ( item.value + "\n" + f"cloud-ml-auto-solutions:{ssh_keys.public}" ) break if not ssh_key_exist: items.append({ "key": "ssh-keys", "value": f"cloud-ml-auto-solutions:{ssh_keys.public}", }) metadata.items = items metadata_request = compute_v1.SetMetadataInstanceRequest( instance=instance_name, project=gcp.project_name, zone=gcp.zone, metadata_resource=metadata, ) operation = instance_client.set_metadata(request=metadata_request) if operation.error: logging.error( ( "Error during instance set metadata: [Code:" f" {operation.http_error_status_code}]:" f" {operation.http_error_message}" f" {operation.error}" ), ) raise operation.exception() or RuntimeError(operation.http_error_message) elif operation.warnings: logging.warning("Warnings during instance set metadata:\n") for warning in operation.warnings: logging.warning(f" - {warning.code}: {warning.message}") return ip_address @task(trigger_rule="all_done") def clean_up_ssh_keys( instance_name: str, ssh_keys: ssh.SshKeys, gcp: gcp_config.GCPConfig, ) -> airflow.XComArg: """Remove the generated one-time use ssh_keys from existing instance. Args: instance_name: name of the existing instance. ssh_keys: airflow.XComArg, gcp: GCP project/zone configuration. """ instance_client = compute_v1.InstancesClient() instance_request = compute_v1.GetInstanceRequest( instance=instance_name, project=gcp.project_name, zone=gcp.zone, ) instance = instance_client.get(request=instance_request) logging.info( f"Resource get status: {instance.status}, {instance.status_message}" ) metadata = instance.metadata for item in metadata.items: if item.key == "ssh-keys": item.value = item.value.replace( f"\ncloud-ml-auto-solutions:{ssh_keys.public}", "" ) break metadata_request = compute_v1.SetMetadataInstanceRequest( instance=instance_name, project=gcp.project_name, zone=gcp.zone, metadata_resource=metadata, ) operation = instance_client.set_metadata(request=metadata_request) if operation.error: logging.error( ( "Error during instance set metadata: [Code:" f" {operation.http_error_status_code}]:" f" {operation.http_error_message}" f" {operation.error}" ), ) raise operation.exception() or RuntimeError(operation.http_error_message) elif operation.warnings: logging.warning("Warnings during instance set metadata:\n") for warning in operation.warnings: logging.warning(f" - {warning.code}: {warning.message}") @task_group def create_resource( gpu_name: airflow.XComArg, image_project: str, image_family: str, accelerator: test_config.Gpu, gcp: gcp_config.GCPConfig, ssh_keys: airflow.XComArg, timeout: datetime.timedelta, install_nvidia_drivers: bool = False, reservation: bool = False, ) -> airflow.XComArg: """Request a resource and wait until the nodes are created. Args: gpu_name: XCom value for unique GPU name. image_project: project of the image. image_family: family of the image. accelerator: Description of GPU to create. gcp: GCP project/zone configuration. ssh_kpeys: XCom value for SSH keys to communicate with these GPUs. timeout: Amount of time to wait for GPUs to be created. install_nvidia_drivers: Whether to install Nvidia drivers. reservation: Whether to use an existing reservation Returns: The ip address of the GPU VM. """ project_id = gcp.project_name zone = gcp.zone @task def create_resource_request( instance_name: str, accelerator: test_config.Gpu, ssh_keys: ssh.SshKeys, instance_termination_action: str, external_access=True, spot: bool = False, delete_protection: bool = False, install_nvidia_drivers: bool = False, reservation: bool = False, ) -> airflow.XComArg: """ Send an instance creation request to the Compute Engine API and wait for it to complete. Args: instance_name: name of the new virtual machine (VM) instance. accelerator: Description of GPU to create. ssh_keys: XCom value for SSH keys to communicate with these GPUs. instance_termination_action: What action should be taken once a Spot VM is terminated. Possible values: "STOP", "DELETE" external_access: boolean flag indicating if the instance should have an external IPv4 address assigned. spot: boolean value indicating if the new instance should be a Spot VM or not. delete_protection: boolean value indicating if the new virtual machine should be protected against deletion or not. install_nvidia_drivers: boolean value indicating whether to install Nvidia drivers. reservation: boolean value indicating whether to use VM reservation Returns: Ip address of the instance object created. """ machine_type = accelerator.machine_type image = get_image_from_family(project=image_project, family=image_family) disk_type = f"zones/{gcp.zone}/diskTypes/pd-ssd" disks = [ disk_from_image( disk_type, True, image.self_link, accelerator.disk_size_gb ) ] if accelerator.attach_local_ssd: for _ in range(accelerator.count): disks.append(local_ssd_disk(gcp.zone)) metadata = create_metadata({ "install-nvidia-driver": str(install_nvidia_drivers), "proxy-mode": "project_editors", "ssh-keys": f"cloud-ml-auto-solutions:{ssh_keys.public}", }) accelerators = [ compute_v1.AcceleratorConfig( accelerator_count=accelerator.count, accelerator_type=( f"projects/{gcp.project_name}/zones/{gcp.zone}/" f"acceleratorTypes/{accelerator.accelerator_type}" ), ) ] service_account = compute_v1.ServiceAccount( scopes=["https://www.googleapis.com/auth/cloud-platform"] ) instance_client = compute_v1.InstancesClient() # Use the network interface provided in the network_link argument. network_interface = compute_v1.NetworkInterface() if accelerator.subnetwork: network_interface.network = accelerator.network if accelerator.subnetwork: network_interface.subnetwork = accelerator.subnetwork if external_access: access = compute_v1.AccessConfig() access.type_ = compute_v1.AccessConfig.Type.ONE_TO_ONE_NAT.name access.name = "External NAT" access.network_tier = access.NetworkTier.PREMIUM.name network_interface.access_configs = [access] # Collect information into the Instance object. instance = compute_v1.Instance() instance.network_interfaces = [network_interface] instance.name = instance_name instance.disks = disks if re.match(r"^zones/[a-z\d\-]+/machineTypes/[a-z\d\-]+$", machine_type): instance.machine_type = machine_type else: instance.machine_type = f"zones/{zone}/machineTypes/{machine_type}" instance.scheduling = compute_v1.Scheduling() if accelerators: instance.guest_accelerators = accelerators instance.scheduling.on_host_maintenance = ( compute_v1.Scheduling.OnHostMaintenance.TERMINATE.name ) if metadata: instance.metadata = metadata if service_account: instance.service_accounts = [service_account] if spot: # Set the Spot VM setting instance.scheduling.provisioning_model = ( compute_v1.Scheduling.ProvisioningModel.SPOT.name ) instance.scheduling.instance_termination_action = ( instance_termination_action ) if delete_protection: # Set the delete protection bit instance.deletion_protection = True if reservation: # Set reservation affinity if specified reservation_affinity = compute_v1.ReservationAffinity() reservation_affinity.consume_reservation_type = ( compute_v1.ReservationAffinity.ConsumeReservationType.ANY_RESERVATION.name ) instance.reservation_affinity = reservation_affinity # Prepare the request to insert an instance. request = compute_v1.InsertInstanceRequest() request.zone = zone request.project = project_id request.instance_resource = instance # Wait for the create operation to complete. logging.info(f"Creating the {instance_name} instance in {zone}...") operation = instance_client.insert(request=request) return operation.name @task.sensor( poke_interval=60, timeout=timeout.total_seconds(), mode="reschedule" ) def wait_for_resource_creation(operation_name: airflow.XComArg): # Retrives the delete opeartion to check the status. client = compute_v1.ZoneOperationsClient() request = compute_v1.GetZoneOperationRequest( operation=operation_name, project=project_id, zone=zone, ) operation = client.get(request=request) status = operation.status.name if status in ("RUNNING", "PENDING"): logging.info( f"Resource create status: {status}, {operation.status_message}" ) return False else: if operation.error: logging.error( ( "Error during resource creation: [Code:" f" {operation.http_error_status_code}]:" f" {operation.http_error_message}" f" {operation.error}" ), ) raise operation.exception() or RuntimeError( operation.http_error_message ) elif operation.warnings: logging.warning("Warnings during resource creation:\n") for warning in operation.warnings: logging.warning(f" - {warning.code}: {warning.message}") return True @task def get_ip_address(instance: str) -> airflow.XComArg: # It takes time to be able to use the ssh with the ip address # even though the creation request is complete. We intentionally # sleep for 60s to wait for the ip address to be accessible. time.sleep(60) instance_client = compute_v1.InstancesClient() instance = instance_client.get( project=project_id, zone=zone, instance=instance ) if len(instance.network_interfaces) > 1: logging.warning( f"GPU instance {gpu_name} has more than one network interface." ) return instance.network_interfaces[0].network_i_p operation = create_resource_request( instance_name=gpu_name, accelerator=accelerator, ssh_keys=ssh_keys, instance_termination_action="STOP", install_nvidia_drivers=install_nvidia_drivers, reservation=reservation, ) ip_address = get_ip_address(gpu_name) wait_for_resource_creation(operation) >> ip_address return ip_address @task def ssh_host( ip_address: str, cmds: Iterable[str], ssh_keys: ssh.SshKeys, env: Dict[str, str] = None, ) -> None: """SSH GPU and run commands in multi process. Args: ip_address: The ip address of the vm resource. cmds: The commands to run on a GPU. ssh_keys: The SSH key pair to use for authentication. env: environment variables to be pass to the ssh runner session using dict. """ pkey = paramiko.RSAKey.from_private_key(io.StringIO(ssh_keys.private)) logging.info(f"Connecting to IP addresses {ip_address}") ssh_group = fabric.ThreadingGroup( ip_address, user="cloud-ml-auto-solutions", connect_kwargs={ "auth_strategy": paramiko.auth_strategy.InMemoryPrivateKey( "cloud-ml-auto-solutions", pkey ) }, ) ssh_group.run(cmds, env=env) @task_group def delete_resource(instance_name: airflow.XComArg, project_id: str, zone: str): @task(trigger_rule="all_done") def delete_resource_request( instance_name: str, project_id: str, zone: str ) -> airflow.XComArg: client = compute_v1.InstancesClient() request = compute_v1.DeleteInstanceRequest( instance=instance_name, project=project_id, zone=zone, ) operation = client.delete(request=request) return operation.name @task.sensor(poke_interval=60, timeout=1800, mode="reschedule") def wait_for_resource_deletion(operation_name: airflow.XComArg): # Retrives the delete opeartion to check the status. client = compute_v1.ZoneOperationsClient() request = compute_v1.GetZoneOperationRequest( operation=operation_name, project=project_id, zone=zone, ) operation = client.get(request=request) status = operation.status.name if status in ("RUNNING", "PENDING"): logging.info( f"Resource deletion status: {status}, {operation.status_message}" ) return False else: if operation.error: logging.error( ( "Error during resource deletion: [Code:" f" {operation.http_error_status_code}]:" f" {operation.http_error_message}" ), ) logging.error(f"Operation ID: {operation.name}") raise operation.exception() or RuntimeError( operation.http_error_message ) elif operation.warnings: logging.warning("Warnings during resource deletion:\n") for warning in operation.warnings: logging.warning(f" - {warning.code}: {warning.message}") return True op = delete_resource_request(instance_name, project_id, zone) wait_for_resource_deletion(op)