"""Utilities for GKE."""

import base64
import concurrent.futures
import datetime
import logging
import tempfile
import time
from typing import Any, Dict, Optional

from airflow.decorators import task, task_group
import google.auth
import google.auth.transport.requests
from google.cloud import container_v1
import kubernetes

from xlml.apis import gcp_config

"""Utilities for GKE."""


class PodsNotReadyError(Exception):
  """Exception raised when pods are not ready within the expected timeout."""

  def __init__(self, message):
    super().__init__(message)


def get_authenticated_client(
    project_name: str, region: str, cluster_name: str
) -> kubernetes.client.ApiClient:
  """Returns a Kubernetes API client authenticated against a GKE cluster.

  Fetches the cluster endpoint and CA certificate from the GKE control plane
  and authenticates with a bearer token from Application Default Credentials.
  """
  container_client = container_v1.ClusterManagerClient()
  cluster_path = (
      f'projects/{project_name}/locations/{region}/clusters/{cluster_name}'
  )
  response = container_client.get_cluster(name=cluster_path)
  creds, _ = google.auth.default()
  auth_req = google.auth.transport.requests.Request()
  creds.refresh(auth_req)
  configuration = kubernetes.client.Configuration()
  configuration.host = f'https://{response.endpoint}'

  ca_cert_content = base64.b64decode(
      response.master_auth.cluster_ca_certificate
  )
  with tempfile.NamedTemporaryFile(delete=False) as ca_cert:
    ca_cert.write(ca_cert_content)
    configuration.ssl_ca_cert = ca_cert.name
  configuration.api_key_prefix['authorization'] = 'Bearer'
  configuration.api_key['authorization'] = creds.token

  return kubernetes.client.ApiClient(configuration)

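
# A minimal, hypothetical usage sketch for get_authenticated_client. The
# project, zone, and cluster names below are placeholders, not values used
# anywhere in this module.
#
#   client = get_authenticated_client(
#       'my-project', 'us-central1-b', 'my-cluster'
#   )
#   core_api = kubernetes.client.CoreV1Api(client)
#   for pod in core_api.list_namespaced_pod(namespace='default').items:
#     logging.info(pod.metadata.name)
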

@task_group
def run_job(
    body: Dict[str, Any],
    gcp: gcp_config.GCPConfig,
    cluster_name: str,
    job_create_timeout: datetime.timedelta,
    gcs_location: str = '',
):
  """Run a batch job directly on a GKE cluster.

  Args:
    body: Dict that defines a Kubernetes `Job`.
    gcp: GCP config with the project name and zone of the GKE cluster.
    cluster_name: Name of the GCP cluster.
    job_create_timeout: Amount of time to wait for all pods to become active.
  """

  @task
  def deploy_job(gcs_location):
    body['spec']['template']['spec']['containers'][0]['env'].append(
        {'name': 'GCS_OUTPUT', 'value': gcs_location}
    )
    client = get_authenticated_client(gcp.project_name, gcp.zone, cluster_name)

    jobs_client = kubernetes.client.BatchV1Api(client)

    resp = jobs_client.create_namespaced_job(namespace='default', body=body)

    logging.info(f'response: {resp}')

    return resp.metadata.name

  @task.sensor(
      poke_interval=60,
      timeout=job_create_timeout.total_seconds(),
      mode='reschedule',
  )
  def wait_all_pods_ready(name: str):
    client = get_authenticated_client(gcp.project_name, gcp.zone, cluster_name)

    batch_api = kubernetes.client.BatchV1Api(client)
    job = batch_api.read_namespaced_job(namespace='default', name=name)

    # TODO(wcromar): Handle other conditions (e.g. unschedulability)
    logging.info(f'Job status: {job.status}')
    if job.status.failed:
      raise RuntimeError(f'Job has {job.status.failed} failed pods.')

    core_api = kubernetes.client.CoreV1Api(client)
    pod_label_selector = f'batch.kubernetes.io/job-name={name}'
    pods = core_api.list_namespaced_pod(
        namespace='default', label_selector=pod_label_selector
    )

    if len(pods.items) != body['spec']['parallelism']:
      logging.info('Waiting for all pods to be created...')
      return False

    return True

  @task(retries=6)
  def stream_logs(name: str):
    def _watch_pod(name, namespace) -> Optional[int]:
      logs_watcher = kubernetes.watch.Watch()

      logging.info(f'Waiting for pod {name} to start...')
      pod_watcher = kubernetes.watch.Watch()
      for event in pod_watcher.stream(
          core_api.list_namespaced_pod,
          namespace,
          field_selector=f'metadata.name={name}',
      ):
        status = event['object'].status
        logging.info(
            f'Pod {event["object"].metadata.name} status: {status.phase}'
        )
        if status.phase != 'Pending':
          break

      logging.info(f'Streaming pod logs for {name}...')
      for line in logs_watcher.stream(
          core_api.read_namespaced_pod_log,
          name,
          namespace,
          _request_timeout=3600,
      ):
        logging.info(f'[{name}] {line}')

      logging.warning(f'Lost logs stream for {name}.')

      pod = core_api.read_namespaced_pod(namespace='default', name=name)
      if pod.status.container_statuses:
        container_status = pod.status.container_statuses[0]
        if container_status.state.terminated:
          exit_code = container_status.state.terminated.exit_code
          if exit_code:
            logging.error(f'Pod {name} had non-zero exit code {exit_code}')

          return exit_code

      logging.warning(f'Unknown status for pod {name}')
      return None

    # We need to re-authenticate if stream_logs fails. This can happen when
    # the job runs for too long and the credentials expire.
    client = get_authenticated_client(gcp.project_name, gcp.zone, cluster_name)

    batch_api = kubernetes.client.BatchV1Api(client)
    core_api = kubernetes.client.CoreV1Api(client)
    pod_label_selector = f'batch.kubernetes.io/job-name={name}'
    pods = core_api.list_namespaced_pod(
        namespace='default', label_selector=pod_label_selector
    )
    # TODO(piz): Using time.sleep may not be a good solution here. However, all
    # resources are expected to be ready by the wait_all_pods_ready stage; this
    # is just in case authentication takes time. Check with Will for better
    # solutions.
    time.sleep(30)
    if len(pods.items) != body['spec']['parallelism']:
      logging.info('Waiting for all pods to be re-connected...')
      raise PodsNotReadyError(
          'Pods are not ready after refreshing credentials.'
      )

    with concurrent.futures.ThreadPoolExecutor() as executor:
      futures = []
      for pod in pods.items:
        f = executor.submit(
            _watch_pod, pod.metadata.name, pod.metadata.namespace
        )
        futures.append(f)

      # Wait for pods to complete, and exit with the first non-zero exit code.
      for f in concurrent.futures.as_completed(futures):
        try:
          # TODO(piz/wcromar): it looks like there is a delay between
          # as_completed and the update of f.result(); exit_code can be None
          # even when the task is complete.
          exit_code = f.result()
        except kubernetes.client.ApiException as e:
          logging.error('Kubernetes error. Retrying...', exc_info=e)
          exit_code = None

        # Retry if the exit status is unknown.
        if exit_code is None:
          raise RuntimeError('Unknown exit code.')
        if exit_code:
          raise RuntimeError(f'Non-zero exit code: {exit_code}')

  name = deploy_job(gcs_location)
  wait_all_pods_ready(name) >> stream_logs(name)

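
# A hypothetical sketch of wiring `run_job` into a DAG. The DAG id, Job body,
# cluster name, timeout, and GCS path below are illustrative placeholders, not
# values defined by this module.
#
#   from airflow import models
#
#   with models.DAG(dag_id='gke_job_example', schedule=None):
#     run_job(
#         body=job_body,  # a Kubernetes Job manifest dict (see run_job above)
#         gcp=gcp_config.GCPConfig(...),  # project and zone of the cluster
#         cluster_name='my-cluster',
#         job_create_timeout=datetime.timedelta(minutes=10),
#         gcs_location='gs://my-bucket/output',
#     )
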

def zone_to_region(zone: str) -> str:
  """Returns the GCP region for a zone, e.g. 'us-central1-b' -> 'us-central1'."""
  zone_terms = zone.split('-')
  return zone_terms[0] + '-' + zone_terms[1]
