in xlml/utils/tpu.py [0:0]
def delete_queued_resource(qualified_name: airflow.XComArg):
  """Implements cascading delete for a Queued Resource.

  The teardown is built from four tasks that run in sequence:

  1. Request deletion of every TPU node listed in the queued resource.
  2. Poll until the queued resource reaches a state from which it can be
     deleted (or until it no longer exists).
  3. Request deletion of the queued resource itself.
  4. Poll the delete operation returned by step 3 until it is done.

  Every step treats a `NotFound` response as "already cleaned up" rather than
  as a failure.

  Args:
    qualified_name: XCom value holding the qualified name of the queued
      resource.
  """

  @task(trigger_rule='all_done')
  def delete_tpu_nodes_request(qualified_name: str):
    """Requests deletion of each TPU node belonging to the queued resource.

    Args:
      qualified_name: Qualified name of the queued resource.
    """
    # TODO(wcromar): Find a less repetitive way to manage the TPU client.
    creds, _ = google.auth.default()
    client = tpu_api.TpuClient(credentials=creds)

    try:
      qr = client.get_queued_resource(name=qualified_name)
    except google.api_core.exceptions.NotFound:
      # Nothing to clean up if the queued resource is gone.
      logging.info(f'{qualified_name} not found')
      return

    for node in qr.tpu.node_spec:
      try:
        op = client.delete_node(name=f'{node.parent}/nodes/{node.node_id}')
        logging.info(f'Delete node state: {op}')
      except google.api_core.exceptions.NotFound:
        logging.info(f'{node.node_id} is already deleted')

  @task.sensor(poke_interval=60, timeout=3600, mode='reschedule')
  def wait_for_tpu_deletion(qualified_name: str):
    """Pokes until the queued resource is in a deletable state.

    Args:
      qualified_name: Qualified name of the queued resource.

    Returns:
      True when the queued resource no longer exists or its state is one of
      SUSPENDED, WAITING_FOR_RESOURCES or ACCEPTED; False otherwise, so the
      sensor pokes again.
    """
    creds, _ = google.auth.default()
    client = tpu_api.TpuClient(credentials=creds)

    try:
      qr = client.get_queued_resource(name=qualified_name)
    except google.api_core.exceptions.NotFound:
      logging.info(
          f'{qualified_name} was removed by cleanup DAG or deleted unexpectedly'
      )
      return True

    # Queued Resources can only be deleted once they are SUSPENDED, even if all
    # underlying nodes have already been deleted.
    if qr.state.state in [
        tpu_api.QueuedResourceState.State.SUSPENDED,
        # TPU will be sitting in WAITING_FOR_RESOURCES if creation timed out.
        tpu_api.QueuedResourceState.State.WAITING_FOR_RESOURCES,
        tpu_api.QueuedResourceState.State.ACCEPTED,
    ]:
      logging.info(f'All TPU nodes deleted for {qualified_name}')
      return True

    logging.info(f'TPU Nodes: {qr.tpu.node_spec}')
    return False

  @task(trigger_rule='all_done')
  def delete_queued_resource_request(qualified_name: str) -> Optional[str]:
    """Requests deletion of the queued resource itself.

    Args:
      qualified_name: Qualified name of the queued resource.

    Returns:
      Name of the delete operation, or None if the queued resource was
      already gone.
    """
    creds, _ = google.auth.default()
    client = tpu_api.TpuClient(credentials=creds)

    try:
      op = client.delete_queued_resource(name=qualified_name)
      logging.info(f'delete op {op}')
    except google.api_core.exceptions.NotFound:
      logging.info(f'{qualified_name} is already deleted')
      return None

    return op.operation.name

  @task.sensor(poke_interval=60, timeout=3600, mode='reschedule')
  def wait_for_queued_resource_deletion(op_name: Optional[str]):
    """Pokes until the queued resource delete operation has finished.

    Args:
      op_name: Name of the delete operation, or None when no delete request
        was issued.

    Returns:
      True when there is no operation to wait for or the operation is done;
      False while it is still running.
    """
    if not op_name:
      logging.info('No delete operation given')
      return True

    creds, _ = google.auth.default()
    client = tpu_api.TpuClient(credentials=creds)

    op = client.get_operation(operations.GetOperationRequest(name=op_name))
    # A finished operation may have failed. Surface that in the logs instead
    # of silently reporting success. The sensor still completes, because
    # poking a finished operation again cannot change its outcome.
    if op.done and op.error.code:
      logging.error(
          f'Delete operation {op_name} finished with error '
          f'{op.error.code}: {op.error.message}'
      )
    return op.done

  # Instantiate the tasks in execution order, then chain them explicitly.
  delete_nodes = delete_tpu_nodes_request(qualified_name)
  nodes_deleted = wait_for_tpu_deletion(qualified_name)
  qr_op_name = delete_queued_resource_request(qualified_name)
  delete_nodes >> nodes_deleted >> qr_op_name

  # Passing the XCom output makes the final sensor depend on the delete
  # request and hands it the operation name to poll.
  wait_for_queued_resource_deletion(qr_op_name)