xlml/utils/tpu.py (333 lines of code) (raw):

# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities to create, delete, and SSH with TPUs.""" import datetime import io import itertools import os from typing import Dict, Iterable, Optional, Tuple, Union import uuid from absl import logging import airflow from airflow.decorators import task, task_group from airflow.utils.task_group import TaskGroup from airflow.operators.python import get_current_context from airflow.models import Variable from xlml.apis import gcp_config, test_config from xlml.utils import ssh, startup_script import fabric import google.api_core.exceptions import google.auth import google.cloud.tpu_v2alpha1 as tpu_api import google.longrunning.operations_pb2 as operations import paramiko from google.protobuf.duration_pb2 import Duration TTL = 'ttl' @task def generate_tpu_name( base_tpu_name: str, set_env_var: bool, ) -> str: tpu_name = f'{base_tpu_name}-{str(uuid.uuid4())}' if set_env_var: Variable.set(base_tpu_name, tpu_name) return tpu_name def create_queued_resource( tpu_name: airflow.XComArg, gcp: gcp_config.GCPConfig, ssh_keys: airflow.XComArg, timeout: datetime.timedelta, task_test_config: Union[ test_config.TpuVmTest, test_config.JSonnetTpuVmTest ], use_startup_script: bool = False, ) -> Tuple[TaskGroup, airflow.XComArg]: """Request a QueuedResource and wait until the nodes are created. Args: tpu_name: XCom value for unique TPU name. accelerator: Description of TPU to create. gcp: GCP project/zone configuration. ssh_keys: XCom value for SSH keys to communicate with these TPUs. timeout: Amount of time to wait for TPUs to be created. task_test_config: Test config of the task. use_startup_script: Indicator to use startup script. Returns: A TaskGroup for the entire create operation and an XCom value for the qualified queued_resource name. """ @task def create_queued_resource_request( tpu_name: str, ssh_keys: ssh.SshKeys ) -> str: creds, _ = google.auth.default() client = tpu_api.TpuClient(credentials=creds) parent = f'projects/{gcp.project_name}/locations/{gcp.zone}' # Determine node_id and multiNodeParams based on num_slices if task_test_config.num_slices == 1: node_id = tpu_name multi_node_params = None else: node_id = None multi_node_params = ( tpu_api.types.QueuedResource.Tpu.NodeSpec.MultiNodeParams( node_count=task_test_config.num_slices, node_id_prefix=tpu_name ) ) startup_script_command = '' if use_startup_script: main_command = '\n'.join( task_test_config.set_up_cmds + task_test_config.run_model_cmds ) startup_script_command = startup_script.generate_startup_script( main_command ) metadata = { 'ssh-keys': f'ml-auto-solutions:{ssh_keys.public}', 'startup-script': startup_script_command, } create_tpu_timeout_in_sec = int(timeout.total_seconds()) if task_test_config.timeout: run_model_timeout_in_sec = int(task_test_config.timeout.total_seconds()) else: run_model_timeout_in_sec = 7200 # Assume a default timeout of 2 hours # Time to live (ttl) is combination of: # 1) tpu provision timeout # 2) tpu run model timeout # 3) 1 hour buffer timeout (provision, post_process, etc) ttl = create_tpu_timeout_in_sec + run_model_timeout_in_sec + 3600 labels = { TTL: str(ttl), } accelerator = task_test_config.accelerator queued_resource = tpu_api.QueuedResource( # TODO(ranran): enable configuration via `AcceleratorConfig` tpu=tpu_api.QueuedResource.Tpu( node_spec=[ tpu_api.QueuedResource.Tpu.NodeSpec( node_id=node_id, multi_node_params=multi_node_params, parent=parent, node=tpu_api.Node( accelerator_type=accelerator.name, description='noteardown', runtime_version=accelerator.runtime_version, network_config=tpu_api.NetworkConfig( network=accelerator.network, subnetwork=accelerator.subnetwork, enable_external_ips=True, ), metadata=metadata, labels=labels, scheduling_config=tpu_api.SchedulingConfig( preemptible=accelerator.preemptible, reserved=accelerator.reserved, ), ), ) ], ), guaranteed=tpu_api.QueuedResource.Guaranteed( reserved=accelerator.reserved, ), queueing_policy=tpu_api.QueuedResource.QueueingPolicy( valid_until_duration=Duration(seconds=int(timeout.total_seconds())), ), ) qr_operation = client.create_queued_resource( parent=parent, queued_resource_id=tpu_name, queued_resource=queued_resource, ) response = qr_operation.result() logging.info(f'Create QR response: {response}') # TODO(wcromar): do anything about failures return response.name @task.sensor( poke_interval=60, timeout=timeout.total_seconds(), mode='reschedule' ) def wait_for_ready_queued_resource(qualified_name: str): creds, _ = google.auth.default() client = tpu_api.TpuClient(credentials=creds) qr = client.get_queued_resource(name=qualified_name) state = qr.state.state logging.info(f'Queued resource state: {state.name}') if qr.state.state == tpu_api.QueuedResourceState.State.ACTIVE: return True elif qr.state.state in [ tpu_api.QueuedResourceState.State.CREATING, tpu_api.QueuedResourceState.State.WAITING_FOR_RESOURCES, tpu_api.QueuedResourceState.State.ACCEPTED, tpu_api.QueuedResourceState.State.PROVISIONING, ]: return False else: raise RuntimeError(f'Bad queued resource state {state.name}') def check_if_startup_script_end( queued_resource: airflow.XComArg, ssh_keys: airflow.XComArg ): check_script = startup_script.monitor_startup_script() return ssh_tpu.override( task_id='check_if_startup_script_end', execution_timeout=task_test_config.timeout, owner=task_test_config.task_owner, )( queued_resource, check_script, ssh_keys, False, ) with TaskGroup(group_id='create_queued_resource') as tg: qualified_name = create_queued_resource_request(tpu_name, ssh_keys) if use_startup_script: wait_for_ready_queued_resource( qualified_name ) >> check_if_startup_script_end(qualified_name, ssh_keys) else: wait_for_ready_queued_resource(qualified_name) return tg, qualified_name @task_group def delete_queued_resource(qualified_name: airflow.XComArg): """Implements cascading delete for a Queued Resource. Args: qualified_name: XCom value holding the qualified name of the queued resource. """ @task(trigger_rule='all_done') def delete_tpu_nodes_request(qualified_name: str): # TODO(wcromar): Find a less repetitive way to manage the TPU client. creds, _ = google.auth.default() client = tpu_api.TpuClient(credentials=creds) try: qr = client.get_queued_resource(name=qualified_name) except google.api_core.exceptions.NotFound: logging.info(f'{qualified_name} not found') return for node in qr.tpu.node_spec: try: op = client.delete_node(name=f'{node.parent}/nodes/{node.node_id}') logging.info(f'Delete node state: {op}') except google.api_core.exceptions.NotFound: logging.info(f'{node.node_id} is already deleted') @task.sensor(poke_interval=60, timeout=3600, mode='reschedule') def wait_for_tpu_deletion(qualified_name: str): creds, _ = google.auth.default() client = tpu_api.TpuClient(credentials=creds) try: qr = client.get_queued_resource(name=qualified_name) except google.api_core.exceptions.NotFound: logging.info( f'{qualified_name} was removed by cleanup DAG or deleted unexpectedly' ) return True # Queued Resources can only be deleted once they are SUSPENDED, even if all # underlying nodes have already been deleted. if qr.state.state in [ tpu_api.QueuedResourceState.State.SUSPENDED, # TPU will be sitting in WAITING_FOR_RESOURCES if creation timed out. tpu_api.QueuedResourceState.State.WAITING_FOR_RESOURCES, tpu_api.QueuedResourceState.State.ACCEPTED, ]: logging.info(f'All TPU nodes deleted for {qualified_name}') return True logging.info(f'TPU Nodes: {qr.tpu.node_spec}') return False @task(trigger_rule='all_done') def delete_queued_resource_request(qualified_name: str) -> Optional[str]: creds, _ = google.auth.default() client = tpu_api.TpuClient(credentials=creds) try: op = client.delete_queued_resource(name=qualified_name) logging.info(f'delete op {op}') except google.api_core.exceptions.NotFound: logging.info(f'{qualified_name} is already deleted') return None return op.operation.name @task.sensor(poke_interval=60, timeout=3600, mode='reschedule') def wait_for_queued_resource_deletion(op_name: Optional[str]): if not op_name: logging.info('No delete operation given') return True creds, _ = google.auth.default() client = tpu_api.TpuClient(credentials=creds) op = client.get_operation(operations.GetOperationRequest(name=op_name)) return op.done delete_tpu_nodes = delete_tpu_nodes_request( qualified_name ) >> wait_for_tpu_deletion(qualified_name) qr_op_name = delete_tpu_nodes >> delete_queued_resource_request( qualified_name ) wait_for_queued_resource_deletion(qr_op_name) def kill_process_by_pid() -> str: return f"""accelerator_type=\${{1}} if [[ \${{accelerator_type}} =~ ^v5.* ]] then device_name=vfio/* else device_name=accel* fi echo \\"Terminating all processes utilizing the TPU (if any).\\" sudo lsof -t /dev/\${{device_name}} | xargs -r kill -9 """ @task def ssh_tpu( qualified_name: str, cmds: Iterable[str], ssh_keys: ssh.SshKeys, all_workers: bool, env: Dict[str, str] = None, ) -> None: """SSH TPU and run commands in multi process. Args: qualified_name: The qualified name of a queued resource. cmds: The commands to run on a TPU. ssh_keys: The SSH key pair to use for authentication. all_workers: The flag to define if run commands on all workers or worker 0 only. env: environment variables to be pass to the ssh runner session using dict. """ creds, _ = google.auth.default() client = tpu_api.TpuClient(credentials=creds) queued_resource = client.get_queued_resource(name=qualified_name) nodes = [ client.get_node(name=os.path.join(node.parent, 'nodes', node.node_id)) for node in queued_resource.tpu.node_spec ] if all_workers: endpoints = itertools.chain.from_iterable( node.network_endpoints for node in nodes ) else: endpoints = [nodes[0].network_endpoints[0]] use_external_ips = os.getenv('XLMLTEST_SSH_EXTERNAL_IPS', '0') == '1' if use_external_ips: ip_addresses = [ endpoint.access_config.external_ip for endpoint in endpoints ] else: ip_addresses = [endpoint.ip_address for endpoint in endpoints] logging.info(f'Connecting to IP addresses of workers: {ip_addresses}') pkey = paramiko.RSAKey.from_private_key(io.StringIO(ssh_keys.private)) ssh_group = fabric.ThreadingGroup( *ip_addresses, connect_kwargs={ 'auth_strategy': paramiko.auth_strategy.InMemoryPrivateKey( 'ml-auto-solutions', pkey ), # See https://stackoverflow.com/a/59453832 'banner_timeout': 200, }, # Proxy required on Cloudtops to connect to external IPs gateway='corp-ssh-helper %h %p' if use_external_ips else None, ) context = get_current_context() if context['task_instance'].try_number > 1: # kill TPU process by pid (if any) to avoid `TPU in use` error in retry tmp_file = '/tmp/kill_process.sh' accelerator_type = nodes[0].accelerator_type script = kill_process_by_pid() kill_process_cmds = ( f'set -xue; sudo echo "{script}" > {tmp_file}', f'bash {tmp_file} {accelerator_type}', ) ssh_group.run(';'.join(kill_process_cmds)) # run provided commands ssh_group.run(cmds, env=env) @task def clean_up_idle_queued_resources( project_name: str, zones: Iterable[str] ) -> None: """Clean up queued resources in FAILED or SUSPENDED states. Args: project_name: The project of resources. zones: Available zones to clean up for the project. """ creds, _ = google.auth.default() client = tpu_api.TpuClient(credentials=creds) logging.info(f'Cleaning up resources in project {project_name}.') for zone in zones: logging.info(f'Checking in zone {zone.value}.') parent = f'projects/{project_name}/locations/{zone.value}' request = tpu_api.types.ListQueuedResourcesRequest(parent=parent) responses = client.list_queued_resources(request) for qr in responses: state = qr.state.state if ( state == tpu_api.QueuedResourceState.State.FAILED or state == tpu_api.QueuedResourceState.State.SUSPENDED ): logging.info(f'Deleting {qr.name} in {state.name} status.') client.delete_queued_resource(name=qr.name) @task def clean_up_idle_nodes(project_name: str, zones: Iterable[str]) -> None: """Clean up TPU nodes that are expired. Args: project_name: The project of resources. zones: Available zones to clean up for the project. """ creds, _ = google.auth.default() client = tpu_api.TpuClient(credentials=creds) logging.info(f'Cleaning up nodes in project {project_name}.') for zone in zones: logging.info(f'Checking in zone {zone.value}.') parent = f'projects/{project_name}/locations/{zone.value}' request = tpu_api.types.ListNodesRequest(parent=parent) responses = client.list_nodes(request) for node in responses: ttl = int(node.labels[TTL]) if TTL in node.labels else None if ttl: create_time = node.create_time current_time = datetime.datetime.now(datetime.timezone.utc) logging.info( ( f'Checking node {node.name}: create_time is {create_time},' f' and current_time is {current_time}.' ) ) active_time = current_time - create_time delta = active_time.total_seconds() - ttl if delta > 0: datetime_delta = str(datetime.timedelta(seconds=delta)) logging.info( ( f'Deleting node {node.name} due to exceeding its time to' f' live (TTL) by {datetime_delta}.' ) ) client.delete_node(name=node.name)