def delete_queued_resource()

in xlml/utils/tpu.py [0:0]


def delete_queued_resource(qualified_name: airflow.XComArg):
  """Implements cascading delete for a Queued Resource.

  Args:
    qualified_name: XCom value holding the qualified name of the queued
      resource.
  """
  # Four tasks are wired together at the bottom of this function:
  #   1. request deletion of every node listed in the queued resource,
  #   2. poll until the queued resource is missing or in a deletable state,
  #   3. request deletion of the queued resource itself,
  #   4. poll the resulting long-running operation until it reports done.

  def _new_tpu_client():
    # Builds a fresh TPU client from the credentials returned by
    # google.auth.default(). Each task calls this for itself.
    credentials, _ = google.auth.default()
    return tpu_api.TpuClient(credentials=credentials)

  # Step 1: issue a delete for each node in the queued resource's node spec.
  # Runs with trigger_rule='all_done', i.e. regardless of upstream outcome.
  @task(trigger_rule='all_done')
  def delete_tpu_nodes_request(qualified_name: str):
    tpu_client = _new_tpu_client()

    try:
      qr = tpu_client.get_queued_resource(name=qualified_name)
    except google.api_core.exceptions.NotFound:
      # No queued resource means there is no node list to walk.
      logging.info(f'{qualified_name} not found')
      return

    for node in qr.tpu.node_spec:
      node_name = f'{node.parent}/nodes/{node.node_id}'
      try:
        op = tpu_client.delete_node(name=node_name)
      except google.api_core.exceptions.NotFound:
        # A missing node is logged and skipped so the remaining nodes are
        # still processed.
        logging.info(f'{node.node_id} is already deleted')
      else:
        logging.info(f'Delete node state: {op}')

  # Step 2: sensor that pokes every 60s (up to 3600s) and succeeds once the
  # queued resource is gone or has reached one of the states listed below.
  @task.sensor(poke_interval=60, timeout=3600, mode='reschedule')
  def wait_for_tpu_deletion(qualified_name: str):
    tpu_client = _new_tpu_client()

    try:
      qr = tpu_client.get_queued_resource(name=qualified_name)
    except google.api_core.exceptions.NotFound:
      logging.info(
          f'{qualified_name} was removed by cleanup DAG or deleted unexpectedly'
      )
      return True

    # Queued Resources can only be deleted once they are SUSPENDED, even if all
    # underlying nodes have already been deleted.
    states = tpu_api.QueuedResourceState.State
    deletable_states = (
        states.SUSPENDED,
        # TPU will be sitting in WAITING_FOR_RESOURCES if creation timed out.
        states.WAITING_FOR_RESOURCES,
        states.ACCEPTED,
    )
    if qr.state.state not in deletable_states:
      # Not ready yet: log what is still attached and poke again later.
      logging.info(f'TPU Nodes: {qr.tpu.node_spec}')
      return False

    logging.info(f'All TPU nodes deleted for {qualified_name}')
    return True

  # Step 3: request deletion of the queued resource and hand the operation
  # name to the next task; None means there was nothing left to delete.
  @task(trigger_rule='all_done')
  def delete_queued_resource_request(qualified_name: str) -> Optional[str]:
    tpu_client = _new_tpu_client()

    try:
      op = tpu_client.delete_queued_resource(name=qualified_name)
    except google.api_core.exceptions.NotFound:
      logging.info(f'{qualified_name} is already deleted')
      return None

    logging.info(f'delete op {op}')
    return op.operation.name

  # Step 4: sensor that succeeds once the delete operation reports done, or
  # immediately when no operation name was passed along.
  # NOTE(review): only `done` is inspected, so an operation that finished with
  # an error is treated the same as a successful one - confirm this is
  # intended.
  @task.sensor(poke_interval=60, timeout=3600, mode='reschedule')
  def wait_for_queued_resource_deletion(op_name: Optional[str]):
    if op_name:
      tpu_client = _new_tpu_client()
      request = operations.GetOperationRequest(name=op_name)
      op = tpu_client.get_operation(request)
      return op.done

    logging.info('No delete operation given')
    return True

  # Wiring: node deletion -> wait for deletable state -> delete the queued
  # resource -> wait for the delete operation.
  nodes_requested = delete_tpu_nodes_request(qualified_name)
  delete_tpu_nodes = wait_for_tpu_deletion(qualified_name)
  nodes_requested >> delete_tpu_nodes

  qr_op_name = delete_queued_resource_request(qualified_name)
  delete_tpu_nodes >> qr_op_name

  wait_for_queued_resource_deletion(qr_op_name)