# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities to create, delete, and SSH with TPUs."""
import datetime
import io
import itertools
import os
from typing import Dict, Iterable, Optional, Tuple, Union
import uuid
from absl import logging
import airflow
from airflow.decorators import task, task_group
from airflow.utils.task_group import TaskGroup
from airflow.operators.python import get_current_context
from airflow.models import Variable
from xlml.apis import gcp_config, test_config
from xlml.utils import ssh, startup_script
import fabric
import google.api_core.exceptions
import google.auth
import google.cloud.tpu_v2alpha1 as tpu_api
import google.longrunning.operations_pb2 as operations
import paramiko
from google.protobuf.duration_pb2 import Duration
TTL = 'ttl'


@task
def generate_tpu_name(
    base_tpu_name: str,
    set_env_var: bool,
) -> str:
  """Generates a unique TPU name by appending a UUID to the base name."""
  tpu_name = f'{base_tpu_name}-{str(uuid.uuid4())}'
  if set_env_var:
    # Store the generated name in an Airflow Variable so other tasks and DAGs
    # can look it up by the base name.
    Variable.set(base_tpu_name, tpu_name)
  return tpu_name
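
# Example (an illustrative sketch; the base name below is an assumption):
#
#   name = generate_tpu_name('v4-8-maxtext', set_env_var=False)
#
# yields an XComArg that resolves at runtime to something like
# 'v4-8-maxtext-8c9f1a2e-...', unique per DAG run.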


def create_queued_resource(
    tpu_name: airflow.XComArg,
    gcp: gcp_config.GCPConfig,
    ssh_keys: airflow.XComArg,
    timeout: datetime.timedelta,
    task_test_config: Union[
        test_config.TpuVmTest, test_config.JSonnetTpuVmTest
    ],
    use_startup_script: bool = False,
) -> Tuple[TaskGroup, airflow.XComArg]:
  """Request a QueuedResource and wait until the nodes are created.

  Args:
    tpu_name: XCom value for the unique TPU name.
    gcp: GCP project/zone configuration.
    ssh_keys: XCom value for SSH keys to communicate with these TPUs.
    timeout: Amount of time to wait for TPUs to be created.
    task_test_config: Test config of the task, including the description of
      the TPU accelerator to create.
    use_startup_script: Whether to run setup and model commands from a
      startup script instead of over SSH.

  Returns:
    A TaskGroup for the entire create operation and an XCom value for the
    qualified queued_resource name.
  """

  @task
  def create_queued_resource_request(
      tpu_name: str, ssh_keys: ssh.SshKeys
  ) -> str:
    creds, _ = google.auth.default()
    client = tpu_api.TpuClient(credentials=creds)

    parent = f'projects/{gcp.project_name}/locations/{gcp.zone}'

    # Determine node_id and multi_node_params based on num_slices.
    if task_test_config.num_slices == 1:
      node_id = tpu_name
      multi_node_params = None
    else:
      # Multislice: the API derives node names from the prefix, so no single
      # node_id is set.
      node_id = None
      multi_node_params = (
          tpu_api.types.QueuedResource.Tpu.NodeSpec.MultiNodeParams(
              node_count=task_test_config.num_slices, node_id_prefix=tpu_name
          )
      )

    startup_script_command = ''
    if use_startup_script:
      main_command = '\n'.join(
          task_test_config.set_up_cmds + task_test_config.run_model_cmds
      )
      startup_script_command = startup_script.generate_startup_script(
          main_command
      )

    metadata = {
        'ssh-keys': f'ml-auto-solutions:{ssh_keys.public}',
        'startup-script': startup_script_command,
    }

    create_tpu_timeout_in_sec = int(timeout.total_seconds())
    if task_test_config.timeout:
      run_model_timeout_in_sec = int(task_test_config.timeout.total_seconds())
    else:
      run_model_timeout_in_sec = 7200  # Assume a default timeout of 2 hours.

    # Time to live (TTL) is the sum of:
    # 1) the TPU provision timeout,
    # 2) the TPU run-model timeout, and
    # 3) a 1-hour buffer (provision, post_process, etc.).
    ttl = create_tpu_timeout_in_sec + run_model_timeout_in_sec + 3600
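    # Worked example (the numbers are assumptions, not defaults of this repo):
    # a 30-minute provision timeout with no per-test timeout gives
    #   ttl = 1800 + 7200 + 3600 = 12600 seconds (3.5 hours).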
    labels = {
        TTL: str(ttl),
    }

    accelerator = task_test_config.accelerator
    queued_resource = tpu_api.QueuedResource(
        # TODO(ranran): enable configuration via `AcceleratorConfig`
        tpu=tpu_api.QueuedResource.Tpu(
            node_spec=[
                tpu_api.QueuedResource.Tpu.NodeSpec(
                    node_id=node_id,
                    multi_node_params=multi_node_params,
                    parent=parent,
                    node=tpu_api.Node(
                        accelerator_type=accelerator.name,
                        description='noteardown',
                        runtime_version=accelerator.runtime_version,
                        network_config=tpu_api.NetworkConfig(
                            network=accelerator.network,
                            subnetwork=accelerator.subnetwork,
                            enable_external_ips=True,
                        ),
                        metadata=metadata,
                        labels=labels,
                        scheduling_config=tpu_api.SchedulingConfig(
                            preemptible=accelerator.preemptible,
                            reserved=accelerator.reserved,
                        ),
                    ),
                )
            ],
        ),
        guaranteed=tpu_api.QueuedResource.Guaranteed(
            reserved=accelerator.reserved,
        ),
        queueing_policy=tpu_api.QueuedResource.QueueingPolicy(
            valid_until_duration=Duration(seconds=int(timeout.total_seconds())),
        ),
    )

    qr_operation = client.create_queued_resource(
        parent=parent,
        queued_resource_id=tpu_name,
        queued_resource=queued_resource,
    )
    response = qr_operation.result()
    logging.info(f'Create QR response: {response}')
    # TODO(wcromar): do anything about failures
    return response.name
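
  # The returned `response.name` is the fully qualified resource name, of the
  # form 'projects/<project>/locations/<zone>/queuedResources/<tpu_name>';
  # downstream tasks pass it back to the TPU API as-is.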

  @task.sensor(
      poke_interval=60, timeout=timeout.total_seconds(), mode='reschedule'
  )
  def wait_for_ready_queued_resource(qualified_name: str):
    creds, _ = google.auth.default()
    client = tpu_api.TpuClient(credentials=creds)

    qr = client.get_queued_resource(name=qualified_name)
    state = qr.state.state
    logging.info(f'Queued resource state: {state.name}')
    if state == tpu_api.QueuedResourceState.State.ACTIVE:
      return True
    elif state in [
        tpu_api.QueuedResourceState.State.CREATING,
        tpu_api.QueuedResourceState.State.WAITING_FOR_RESOURCES,
        tpu_api.QueuedResourceState.State.ACCEPTED,
        tpu_api.QueuedResourceState.State.PROVISIONING,
    ]:
      return False
    else:
      raise RuntimeError(f'Bad queued resource state {state.name}')
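
  # Note: in 'reschedule' mode the sensor releases its worker slot between
  # pokes, so a long provisioning wait does not tie up an Airflow worker.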

  def check_if_startup_script_end(
      queued_resource: airflow.XComArg, ssh_keys: airflow.XComArg
  ):
    check_script = startup_script.monitor_startup_script()
    return ssh_tpu.override(
        task_id='check_if_startup_script_end',
        execution_timeout=task_test_config.timeout,
        owner=task_test_config.task_owner,
    )(
        queued_resource,
        check_script,
        ssh_keys,
        False,
    )

  with TaskGroup(group_id='create_queued_resource') as tg:
    qualified_name = create_queued_resource_request(tpu_name, ssh_keys)
    if use_startup_script:
      wait_for_ready_queued_resource(
          qualified_name
      ) >> check_if_startup_script_end(qualified_name, ssh_keys)
    else:
      wait_for_ready_queued_resource(qualified_name)

  return tg, qualified_name
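
# Example wiring (a sketch; `my_gcp`, `my_keys`, and `my_config` are
# hypothetical objects built elsewhere in a DAG file):
#
#   tpu_name = generate_tpu_name('v4-8-maxtext', set_env_var=False)
#   create_tg, qr_name = create_queued_resource(
#       tpu_name, my_gcp, my_keys, datetime.timedelta(minutes=60), my_config
#   )
#
# `create_tg` is the TaskGroup to sequence against; `qr_name` is the XCom
# value downstream tasks use to address the queued resource.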


@task_group
def delete_queued_resource(qualified_name: airflow.XComArg):
  """Implements cascading delete for a Queued Resource.

  Args:
    qualified_name: XCom value holding the qualified name of the queued
      resource.
  """

  @task(trigger_rule='all_done')
  def delete_tpu_nodes_request(qualified_name: str):
    # TODO(wcromar): Find a less repetitive way to manage the TPU client.
    creds, _ = google.auth.default()
    client = tpu_api.TpuClient(credentials=creds)

    try:
      qr = client.get_queued_resource(name=qualified_name)
    except google.api_core.exceptions.NotFound:
      logging.info(f'{qualified_name} not found')
      return

    # Delete the underlying nodes first; the Queued Resource itself can only
    # be deleted once it is SUSPENDED (see wait_for_tpu_deletion below).
    for node in qr.tpu.node_spec:
      try:
        op = client.delete_node(name=f'{node.parent}/nodes/{node.node_id}')
        logging.info(f'Delete node state: {op}')
      except google.api_core.exceptions.NotFound:
        logging.info(f'{node.node_id} is already deleted')

  @task.sensor(poke_interval=60, timeout=3600, mode='reschedule')
  def wait_for_tpu_deletion(qualified_name: str):
    creds, _ = google.auth.default()
    client = tpu_api.TpuClient(credentials=creds)

    try:
      qr = client.get_queued_resource(name=qualified_name)
    except google.api_core.exceptions.NotFound:
      logging.info(
          f'{qualified_name} was removed by cleanup DAG or deleted unexpectedly'
      )
      return True

    # Queued Resources can only be deleted once they are SUSPENDED, even if
    # all underlying nodes have already been deleted.
    if qr.state.state in [
        tpu_api.QueuedResourceState.State.SUSPENDED,
        # TPU will be sitting in WAITING_FOR_RESOURCES if creation timed out.
        tpu_api.QueuedResourceState.State.WAITING_FOR_RESOURCES,
        tpu_api.QueuedResourceState.State.ACCEPTED,
    ]:
      logging.info(f'All TPU nodes deleted for {qualified_name}')
      return True

    logging.info(f'TPU Nodes: {qr.tpu.node_spec}')
    return False

  @task(trigger_rule='all_done')
  def delete_queued_resource_request(qualified_name: str) -> Optional[str]:
    creds, _ = google.auth.default()
    client = tpu_api.TpuClient(credentials=creds)

    try:
      op = client.delete_queued_resource(name=qualified_name)
      logging.info(f'delete op {op}')
    except google.api_core.exceptions.NotFound:
      logging.info(f'{qualified_name} is already deleted')
      return None

    return op.operation.name

  @task.sensor(poke_interval=60, timeout=3600, mode='reschedule')
  def wait_for_queued_resource_deletion(op_name: Optional[str]):
    if not op_name:
      logging.info('No delete operation given')
      return True

    creds, _ = google.auth.default()
    client = tpu_api.TpuClient(credentials=creds)
    op = client.get_operation(operations.GetOperationRequest(name=op_name))
    return op.done

  delete_tpu_nodes = delete_tpu_nodes_request(
      qualified_name
  ) >> wait_for_tpu_deletion(qualified_name)

  qr_op_name = delete_tpu_nodes >> delete_queued_resource_request(
      qualified_name
  )
  wait_for_queued_resource_deletion(qr_op_name)
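
# Example pairing (a sketch; `run_model` is a hypothetical task, and
# `create_tg`/`qr_name` come from create_queued_resource above):
#
#   create_tg >> run_model >> delete_queued_resource(qr_name)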


def kill_process_by_pid() -> str:
  """Returns a shell script that kills any processes holding the TPU devices.

  Dollar signs and double quotes are backslash-escaped so that the script
  survives being embedded in `sudo echo "..."` by `ssh_tpu` below.
  """
  return f"""accelerator_type=\${{1}}
if [[ \${{accelerator_type}} =~ ^v5.* ]]
then
device_name=vfio/*
else
device_name=accel*
fi
echo \\"Terminating all processes utilizing the TPU (if any).\\"
sudo lsof -t /dev/\${{device_name}} | xargs -r kill -9
"""


@task
def ssh_tpu(
    qualified_name: str,
    cmds: Iterable[str],
    ssh_keys: ssh.SshKeys,
    all_workers: bool,
    env: Optional[Dict[str, str]] = None,
) -> None:
  """SSH into the TPU workers and run commands in parallel.

  Args:
    qualified_name: The qualified name of a queued resource.
    cmds: The commands to run on a TPU.
    ssh_keys: The SSH key pair to use for authentication.
    all_workers: Whether to run the commands on all workers, or on worker 0
      only.
    env: Environment variables to pass to the SSH runner session.
  """
  creds, _ = google.auth.default()
  client = tpu_api.TpuClient(credentials=creds)

  queued_resource = client.get_queued_resource(name=qualified_name)
  nodes = [
      client.get_node(name=os.path.join(node.parent, 'nodes', node.node_id))
      for node in queued_resource.tpu.node_spec
  ]

  if all_workers:
    endpoints = itertools.chain.from_iterable(
        node.network_endpoints for node in nodes
    )
  else:
    endpoints = [nodes[0].network_endpoints[0]]

  use_external_ips = os.getenv('XLMLTEST_SSH_EXTERNAL_IPS', '0') == '1'
  if use_external_ips:
    ip_addresses = [
        endpoint.access_config.external_ip for endpoint in endpoints
    ]
  else:
    ip_addresses = [endpoint.ip_address for endpoint in endpoints]
  logging.info(f'Connecting to IP addresses of workers: {ip_addresses}')

  pkey = paramiko.RSAKey.from_private_key(io.StringIO(ssh_keys.private))
  ssh_group = fabric.ThreadingGroup(
      *ip_addresses,
      connect_kwargs={
          'auth_strategy': paramiko.auth_strategy.InMemoryPrivateKey(
              'ml-auto-solutions', pkey
          ),
          # See https://stackoverflow.com/a/59453832
          'banner_timeout': 200,
      },
      # Proxy required on Cloudtops to connect to external IPs
      gateway='corp-ssh-helper %h %p' if use_external_ips else None,
  )

  context = get_current_context()
  if context['task_instance'].try_number > 1:
    # Kill any TPU processes by PID to avoid a `TPU in use` error on retry.
    tmp_file = '/tmp/kill_process.sh'
    accelerator_type = nodes[0].accelerator_type
    script = kill_process_by_pid()
    kill_process_cmds = (
        f'set -xue; sudo echo "{script}" > {tmp_file}',
        f'bash {tmp_file} {accelerator_type}',
    )
    ssh_group.run(';'.join(kill_process_cmds))

  # Run the provided commands on every connected worker.
  ssh_group.run(cmds, env=env)
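
# Example call (a sketch; the command and env values are assumptions): run a
# command on all workers of the queued resource created above:
#
#   ssh_tpu.override(task_id='run_model')(
#       qr_name,
#       'python3 run.py',
#       my_keys,
#       all_workers=True,
#       env={'TPU_NAME': 'local'},
#   )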


@task
def clean_up_idle_queued_resources(
    project_name: str, zones: Iterable[str]
) -> None:
  """Clean up queued resources in FAILED or SUSPENDED states.

  Args:
    project_name: The project containing the resources.
    zones: Zones to clean up for the project; each entry's `.value` holds the
      zone name (e.g. 'us-central2-b').
  """
  creds, _ = google.auth.default()
  client = tpu_api.TpuClient(credentials=creds)
  logging.info(f'Cleaning up resources in project {project_name}.')

  for zone in zones:
    logging.info(f'Checking in zone {zone.value}.')
    parent = f'projects/{project_name}/locations/{zone.value}'
    request = tpu_api.types.ListQueuedResourcesRequest(parent=parent)
    responses = client.list_queued_resources(request)
    for qr in responses:
      state = qr.state.state
      if state in (
          tpu_api.QueuedResourceState.State.FAILED,
          tpu_api.QueuedResourceState.State.SUSPENDED,
      ):
        logging.info(f'Deleting {qr.name} in {state.name} status.')
        client.delete_queued_resource(name=qr.name)
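
# Typically invoked from a periodic cleanup DAG, e.g. (with a hypothetical
# `Zone` enum whose members carry the zone string in `.value`):
#
#   clean_up_idle_queued_resources('my-project', [Zone.US_CENTRAL2_B])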


@task
def clean_up_idle_nodes(project_name: str, zones: Iterable[str]) -> None:
  """Clean up TPU nodes whose time to live (TTL) has expired.

  Args:
    project_name: The project containing the resources.
    zones: Zones to clean up for the project; each entry's `.value` holds the
      zone name.
  """
  creds, _ = google.auth.default()
  client = tpu_api.TpuClient(credentials=creds)
  logging.info(f'Cleaning up nodes in project {project_name}.')

  for zone in zones:
    logging.info(f'Checking in zone {zone.value}.')
    parent = f'projects/{project_name}/locations/{zone.value}'
    request = tpu_api.types.ListNodesRequest(parent=parent)
    responses = client.list_nodes(request)
    for node in responses:
      ttl = int(node.labels[TTL]) if TTL in node.labels else None
      if ttl:
        create_time = node.create_time
        current_time = datetime.datetime.now(datetime.timezone.utc)
        logging.info(
            f'Checking node {node.name}: create_time is {create_time},'
            f' and current_time is {current_time}.'
        )
        active_time = current_time - create_time
        delta = active_time.total_seconds() - ttl
        if delta > 0:
          datetime_delta = str(datetime.timedelta(seconds=delta))
          logging.info(
              f'Deleting node {node.name} due to exceeding its time to'
              f' live (TTL) by {datetime_delta}.'
          )
          client.delete_node(name=node.name)
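
# Worked example (values are assumptions): a node labeled ttl='12600' that was
# created at 08:00 UTC expires at 11:30 UTC (12600 s = 3.5 h), so the first
# cleanup run after that point deletes it.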