#!/usr/bin/env python

# Copyright 2024 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""schedule-daemon.py is a Topology-aware Kubernetes pod scheduler."""

import argparse
import collections
import itertools
import logging
import re
import time
from typing import Any
import kubernetes
import kubernetes.client
import kubernetes.client.models
import kubernetes.client.rest
from kubernetes.utils import quantity

# Configure logging.
# NOTE: this call executes at import time. It asks the root logger to emit
# INFO and above to two handlers: a file handler opened on the relative path
# 'my_app.log' and a stream handler.
logging.basicConfig(
    level=logging.INFO,  # Set the root logger level to INFO
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('my_app.log'),  # Log to a file
        logging.StreamHandler(),  # Log to the console
    ],
)

# Topology labels for GKE<1.31 (these require the labeler sidecar).
PRERELEASE_CLUSTER_LABEL = 'topology.gke.io/cluster'
PRERELEASE_RACK_LABEL = 'topology.gke.io/rack'
PRERELEASE_HOST_LABEL = 'topology.gke.io/host'
# Topology labels for GKE>=1.31.
CLUSTER_LABEL = 'cloud.google.com/gce-topology-block'
RACK_LABEL = 'cloud.google.com/gce-topology-subblock'
HOST_LABEL = 'cloud.google.com/gce-topology-host'

# Kubernetes labels used to identify jobs and their completion indices.
# These labels are typically set by Kubernetes controllers like JobSet or
# CronJob controllers. In this file the completion index is read to order the
# pods of a job before they are assigned to nodes.
JOB_COMPLETION_INDEX_LABEL_KEY = 'batch.kubernetes.io/job-completion-index'
JOB_NAME_LABEL_KEY = 'job-name'

# Most ML workloads are launched as batch Jobs or through the JobSet API. To
# also support Kubeflow operators (e.g. MPIJobs), these alternative labels are
# read as a fallback.
KUBEFLOW_REPLICA_INDEX_LABEL_KEY = 'training.kubeflow.org/replica-index'
KUBEFLOW_JOB_NAME_LABEL_KEY = 'training.kubeflow.org/job-name'

# Sentinel returned by the job-name extractor functions below when a pod does
# not carry the piece of metadata they look for.
UNKNOWN_JOB_NAME = 'jobless'  # in case when job name is not found


def extract_job_name_label(pod: kubernetes.client.models.V1Pod) -> str:
  """Extracts the job name label from a pod.

  Args:
    pod: The pod to extract the job name.

  Returns:
    The value of the job name label, or UNKNOWN_JOB_NAME if the pod has no
    labels at all or the label is not set.
  """
  # `metadata.labels` may be None (not an empty dict) for a pod without any
  # labels, so fall back to an empty mapping before the lookup.
  labels = pod.metadata.labels or {}
  return labels.get(JOB_NAME_LABEL_KEY, UNKNOWN_JOB_NAME)


def extract_kubeflow_job_name_label(pod: kubernetes.client.models.V1Pod) -> str:
  """Extracts the kubeflow job name label from a pod.

  Args:
    pod: The pod to extract the job name.

  Returns:
    The value of the kubeflow job name label, or UNKNOWN_JOB_NAME if the pod
    has no labels at all or the label is not set.
  """
  # `metadata.labels` may be None (not an empty dict) for a pod without any
  # labels, so fall back to an empty mapping before the lookup.
  labels = pod.metadata.labels or {}
  return labels.get(KUBEFLOW_JOB_NAME_LABEL_KEY, UNKNOWN_JOB_NAME)


def extract_owner_reference_uid(pod: kubernetes.client.models.V1Pod) -> str:
  """Uses the UID of the pod's first owner reference as its job identity.

  Args:
    pod: The pod whose owner references are inspected.

  Returns:
    The UID of the first owner reference, or UNKNOWN_JOB_NAME when the pod
    has no owner references.
  """
  owners = pod.metadata.owner_references
  return owners[0].uid if owners else UNKNOWN_JOB_NAME


def extract_helm_job_name_label(pod: kubernetes.client.models.V1Pod) -> str:
  """Reads the `name` label of a pod as its job name.

  Args:
    pod: The pod whose labels are inspected.

  Returns:
    The value of the `name` label, or UNKNOWN_JOB_NAME when the pod has no
    labels or the label is absent.
  """
  labels = pod.metadata.labels
  if not labels:
    return UNKNOWN_JOB_NAME
  return labels.get('name', UNKNOWN_JOB_NAME)


def pod_sorting_key(pod_info: dict[str, Any]) -> tuple[str, int]:
  """Returns key/rank to be used for sorting pods.

  Given that numbers are often suffixed for multi-node deployments,
  here we use a (prefix, number) tuple for the sorting key.
  This means "xxx-pod2" should appear before "xxx-pod10".

  The key is always a (str, int) tuple so that keys of different pods stay
  comparable with each other. Pods with an explicit index use an empty
  prefix, so they are ordered by index and sort before pods whose order is
  derived from the name.

  Args:
    pod_info: The pod info. Must contain 'name'; 'index' is optional.

  Returns:
    ('', index) if the pod info carries an index convertible to int,
    otherwise (name prefix, numeric name suffix), or (name, 0) if the name has
    no numeric suffix.
  """
  index_value = pod_info.get('index')
  if index_value is not None:
    try:
      return ('', int(index_value))
    except (ValueError, TypeError) as e:
      # Not fatal: fall through and derive the order from the pod name.
      logging.exception(
          'Error converting %s pod index to integer: %s', pod_info['name'], e
      )
  # if the suffix is a number, extract it from the name
  name = pod_info['name']
  if match := re.fullmatch(r'^(.*?)(\d+)$', name):
    prefix, suffix = match.groups()
    return (prefix, int(suffix))
  logging.warning(
      'Pod %s does not have a numeric suffix. Using 0 as index.', name
  )
  return (name, 0)  # No numeric suffix


def node_topology_distance(node1: dict[str, Any], node2: dict[str, Any]) -> int:
  """Calculates the distance between two nodes in the topology.

  The topology keys of both nodes are compared level by level, starting with
  the outermost level. The first level that differs determines the distance:
  1000000 for the first level, 10000 for the second, 100 for the third.

  Args:
    node1: The first node.
    node2: The second node.

  Returns:
    The distance between the two nodes, or 0 if all compared topology levels
    are equal. It is also 0 if topology labels are missing on either node.
  """
  # distance between cluster/rack/host tuples
  node1_topology_keys = node_topology_key(node1)
  node2_topology_keys = node_topology_key(node2)
  result = 1000000
  # zip() stops at the shorter tuple, so a node without topology labels (empty
  # key tuple) yields distance 0 instead of an IndexError.
  for key1, key2 in zip(node1_topology_keys, node2_topology_keys):
    if key1 != key2:
      return result
    result //= 100  # integer division keeps the declared int return type
  return 0


def node_topology_key(
    node: dict[str, Any],
) -> tuple[str, str, str] | tuple[()]:
  """Extract topology labels of a node.

  The current label set (CLUSTER_LABEL, RACK_LABEL, HOST_LABEL) is preferred;
  the pre-release label set is used only if the current one is incomplete.

  Args:
    node: The node info dict. 'node_labels' is optional: a node info without
      it is treated as a node without labels.

  Returns:
    A tuple of the node's topology labels, or an empty tuple if the node does
    not have a complete set of topology labels.
  """
  # Node infos of nodes without labels do not carry a 'node_labels' entry.
  node_labels = node.get('node_labels') or {}
  for labels in [
      (CLUSTER_LABEL, RACK_LABEL, HOST_LABEL),
      (PRERELEASE_CLUSTER_LABEL, PRERELEASE_RACK_LABEL, PRERELEASE_HOST_LABEL),
  ]:
    if all(label in node_labels for label in labels):
      return tuple(node_labels[label] for label in labels)
  logging.info('Node %s does not have topology labels', node.get('name'))
  return ()


def get_pod_used_resources(
    pod: kubernetes.client.models.V1Pod,
) -> tuple[int, int, int]:
  """Get the resources used by this pod.

  Usage is the sum of the resource requests of all containers of the pod that
  have a reported status and are not terminated.

  Args:
    pod: The pod.

  Returns:
    A tuple of the pod's used CPU, memory, and GPU. All zeros if the pod has
    no container statuses.
  """
  used_cpu, used_memory, used_gpu = 0, 0, 0
  if pod.status is None or pod.status.container_statuses is None:
    return used_cpu, used_memory, used_gpu
  # Match statuses to containers by name: the status list is not guaranteed to
  # be in the same order as spec.containers.
  statuses_by_name = {
      status.name: status for status in pod.status.container_statuses
  }
  for container in pod.spec.containers:
    container_status = statuses_by_name.get(container.name)
    if container_status is None:
      # No status reported for this container: not counted, consistent with
      # the early return above for pods without any container status.
      continue
    state = container_status.state
    if state is not None and state.terminated is not None:
      # terminated containers don't use resources
      continue
    # Both `resources` and `resources.requests` are optional in a pod spec.
    resources = container.resources
    requests = (resources.requests if resources is not None else None) or {}
    used_cpu += quantity.parse_quantity(requests.get('cpu', 0))
    used_memory += quantity.parse_quantity(requests.get('memory', 0))
    used_gpu += int(requests.get('nvidia.com/gpu', 0))
  return used_cpu, used_memory, used_gpu


def all_pods_have_same_tolerations(
    pods: list[kubernetes.client.models.V1Pod],
) -> bool:
  """Tells whether every pod in the list carries identical tolerations.

  An empty list is treated as trivially consistent.

  Args:
    pods: A list of V1Pod objects.

  Returns:
    True if the tolerations of every pod compare equal to those of the first
    pod (or the list is empty), False otherwise.
  """
  if not pods:
    return True

  # The first pod is the reference every other pod is compared against.
  reference, *remaining = (pod.spec.tolerations for pod in pods)
  for tolerations in remaining:
    if not tolerations == reference:
      return False
  return True


def _toleration_matches_taint(
    toleration: kubernetes.client.models.V1Toleration,
    taint: kubernetes.client.models.V1Taint,
) -> bool:
  """Returns True if the toleration tolerates the taint (Kubernetes rules)."""
  # An unset effect / key on the toleration matches every effect / key.
  if toleration.effect and toleration.effect != taint.effect:
    return False
  if toleration.key and toleration.key != taint.key:
    return False
  # An unset operator means 'Equal'.
  operator = toleration.operator or 'Equal'
  if operator == 'Exists':
    return True
  if operator == 'Equal':
    return (toleration.value or '') == (taint.value or '')
  return False


def find_schedulable_nodes(
    nodes: list[kubernetes.client.models.V1Node],
    other_running_pods: list[kubernetes.client.models.V1Pod],
    new_pods_to_schedule: dict[str, dict[str, Any]],
) -> dict[str, Any]:
  """Finds nodes that can be scheduled.

  A node is kept if every one of its taints is tolerated by the pods to
  schedule and it does not report a "Ready" condition with a status other
  than "True". For each kept node the free CPU/memory/GPU is computed as
  allocatable minus the requests of the running pods bound to it.

  Args:
    nodes: A list of V1Node objects.
    other_running_pods: A list of running pods. They occupy node resources.
    new_pods_to_schedule: Info dict of pods to schedule. The tolerations of
      the first pod are used for all of them; if the dict is empty, no taint
      is tolerated.

  Returns:
    A dict of node names to node infos for node where scheduling is possible.
    Every node info has the keys 'name', 'cpu', 'memory', 'gpu' and
    'node_labels' (an empty dict for nodes without labels).
  """
  nodes_info = {}
  # guarantee: the caller checked that all pods have same tolerations
  first_pod_info = next(iter(new_pods_to_schedule.values()), None)
  tolerations = []
  if first_pod_info is not None:
    tolerations = first_pod_info['spec'].tolerations or []

  # Index running pods by node once instead of scanning all pods per node.
  pods_by_node = collections.defaultdict(list)
  for pod in other_running_pods:
    if pod.spec.node_name:
      pods_by_node[pod.spec.node_name].append(pod)

  for node in nodes:
    node_name = node.metadata.name
    # skip nodes that have taints that are not covered by pod tolerations
    # (every taint effect is treated as blocking here)
    if any(
        not any(_toleration_matches_taint(tol, taint) for tol in tolerations)
        for taint in node.spec.taints or []
    ):
      logging.info(
          'Skipping node %s because it has taints (%s) which are not covered'
          ' by pod tolerations %s',
          node_name,
          node.spec.taints,
          tolerations,
      )
      continue

    # skip nodes that are not in Ready state
    if any(
        condition.type == 'Ready' and condition.status != 'True'
        for condition in node.status.conditions or []
    ):
      logging.info('Skipping node %s because it is NotReady', node_name)
      continue

    allocatable = node.status.allocatable
    used_cpu, used_memory, used_gpu = 0, 0, 0
    for pod in pods_by_node.get(node_name, []):
      cpu, mem, gpu = get_pod_used_resources(pod)
      used_cpu += cpu
      used_memory += mem
      used_gpu += gpu

    free_cpu = quantity.parse_quantity(allocatable['cpu']) - used_cpu
    free_memory = quantity.parse_quantity(allocatable['memory']) - used_memory
    free_gpu = int(allocatable.get('nvidia.com/gpu', 0)) - used_gpu

    node_info = {
        'name': node_name,
        'cpu': free_cpu,
        'memory': free_memory,
        'gpu': free_gpu,
        # Always present, so topology lookups work for label-less nodes too.
        'node_labels': node.metadata.labels or {},
    }
    nodes_info[node_name] = node_info

    logging.info(
        'Node: %s, CPU: %s, Memory: %s, GPU: %s, Topology: %s',
        node_name,
        free_cpu,
        free_memory,
        free_gpu,
        node_topology_key(node_info),
    )

  return nodes_info


def find_pod_gates(
    pods: list[kubernetes.client.models.V1Pod], prefix: str
) -> set[str]:
  """Collects the names of scheduling gates that start with the prefix.

  Args:
    pods: A list of V1Pod objects.
    prefix: The prefix a scheduling gate name must start with.

  Returns:
    The set of distinct matching gate names found on the given pods.
  """
  return {
      gate.name
      for pod in pods
      for gate in pod.spec.scheduling_gates or []
      if gate.name.startswith(prefix)
  }


def find_schedulable_pods(
    job_name: str,
    pods: list[kubernetes.client.models.V1Pod],
) -> dict[str, dict[str, Any]]:
  """Builds the scheduling info of the pods of one job.

  For each pod the resource requests of all its containers are summed up and
  the pod index is read from the job completion index label, falling back to
  the kubeflow replica index label.

  Args:
    job_name: The name of the job.
    pods: A list of V1Pod objects.

  Returns:
    A dict of pod names to pod infos for schedulable pods. A pod info has the
    keys 'name', 'namespace', 'index' (None if no index label is set), 'cpu',
    'memory', 'gpu', 'spec', 'metadata', 'job_name', 'creation_time' and,
    only if the pod defines one, 'node_selector'.
  """
  pods_to_schedule = {}
  for pod in pods:
    pod_name = pod.metadata.name
    # `metadata.labels` may be None for a pod without any labels.
    labels = pod.metadata.labels or {}
    pod_index = None
    if JOB_COMPLETION_INDEX_LABEL_KEY in labels:
      pod_index = labels[JOB_COMPLETION_INDEX_LABEL_KEY]
    elif KUBEFLOW_REPLICA_INDEX_LABEL_KEY in labels:
      pod_index = labels[KUBEFLOW_REPLICA_INDEX_LABEL_KEY]
    else:
      # it's not hard stop, we still can order pods by name
      logging.info(
          'No index in pod %s metadata. Ordering derived from name.',
          pod_name,
      )

    used_cpu, used_memory, used_gpu = 0, 0, 0
    for container in pod.spec.containers:
      # Both `resources` and `resources.requests` are optional in a pod spec.
      resources = container.resources
      requests = (resources.requests if resources is not None else None) or {}
      used_cpu += quantity.parse_quantity(requests.get('cpu', 0))
      used_memory += quantity.parse_quantity(requests.get('memory', 0))
      used_gpu += int(requests.get('nvidia.com/gpu', 0))

    pod_info = {
        'name': pod_name,
        'namespace': pod.metadata.namespace,
        'index': pod_index,
        'cpu': used_cpu,
        'memory': used_memory,
        'gpu': used_gpu,
        'spec': pod.spec,
        'metadata': pod.metadata,
        'job_name': job_name,
        'creation_time': pod.metadata.creation_timestamp,
    }
    if pod.spec.node_selector:
      pod_info['node_selector'] = pod.spec.node_selector
    pods_to_schedule[pod_name] = pod_info

    logging.info(
        'Found schedulable pod: %s/%s, CPU: %s, Memory: %s, GPU: %s Index: %s',
        pod.metadata.namespace,
        pod_name,
        used_cpu,
        used_memory,
        used_gpu,
        pod_index,
    )

  return pods_to_schedule


def can_schedule(node: dict[str, Any], pod: dict[str, Any]) -> bool:
  """Decides whether the node info satisfies the pod info.

  The pod's node selector (if any) must be fully matched by the node labels,
  and the node must offer at least the CPU, memory and GPU the pod asks for.

  Args:
    node: Node info with 'cpu', 'memory', 'gpu' and optional 'node_labels'.
    pod: Pod info with 'cpu', 'memory', 'gpu' and optional 'node_selector'.

  Returns:
    True if the pod fits on the node, False otherwise.
  """
  labels = node.get('node_labels', {})
  selector = pod.get('node_selector', {})

  # Every selector entry has to be present on the node with the same value.
  selector_mismatch = any(
      wanted_key not in labels or labels[wanted_key] != wanted_value
      for wanted_key, wanted_value in selector.items()
  )
  if selector_mismatch:
    return False

  # Resources are checked in the order cpu, memory, gpu; first shortage wins.
  return all(
      node[resource] >= pod[resource]
      for resource in ('cpu', 'memory', 'gpu')
  )


def schedule_pod_on_node(
    v1: kubernetes.client.CoreV1Api,
    pod_name: str,
    pod_namespace: str,
    node: dict[str, Any],
    gate_name: str,
) -> bool:
  """Schedules a pod on a given node using affinity for direct assignment.

  The pod is re-read from the API server. If it still carries the given
  scheduling gate, the gate is removed and a required node affinity on the
  `kubernetes.io/hostname` label of the target node is set, then the pod is
  replaced. If the gate is no longer present, nothing is changed.

  Args:
    v1: The kubernetes client.
    pod_name: The name of the pod to schedule.
    pod_namespace: The namespace of the pod to schedule.
    node: The node to schedule the pod on.
    gate_name: The name of the gate to remove from the pod.

  Returns:
    False if a Kubernetes API call failed, True otherwise (including the case
    where the pod no longer has the gate and was left untouched).
  """
  try:
    pod = v1.read_namespaced_pod(pod_name, pod_namespace)
    # `scheduling_gates` is None once a pod has no gates left.
    current_gates = pod.spec.scheduling_gates or []
    if not any(gate.name == gate_name for gate in current_gates):
      logging.info(
          'Pod %s/%s has no scheduling gate %s, nothing to update',
          pod_namespace,
          pod_name,
          gate_name,
      )
      return True

    # NOTE(review): this replaces any affinity already present on the pod
    # spec -- confirm that gated pods never carry their own affinity rules.
    pod.spec.affinity = {
        'nodeAffinity': {
            'requiredDuringSchedulingIgnoredDuringExecution': {
                'nodeSelectorTerms': [{
                    'matchExpressions': [{
                        'key': 'kubernetes.io/hostname',
                        'operator': 'In',
                        'values': [node['name']],
                    }]
                }]
            }
        }
    }
    pod.spec.scheduling_gates = [
        gate for gate in current_gates if gate.name != gate_name
    ]

    v1.replace_namespaced_pod(pod_name, pod_namespace, pod)

    logging.info(
        'Pod %s/%s scheduled on %s with topology %s',
        pod_namespace,
        pod_name,
        node['name'],
        node_topology_key(node),
    )
  except kubernetes.client.rest.ApiException as e:
    logging.exception(
        'Exception when removing pod %s scheduling gate: %s', pod_name, e
    )
    return False
  return True


def calculate_pods_assignment(
    sorted_nodes: list[dict[str, Any]], sorted_pods: list[dict[str, Any]]
) -> list[int]:
  """Gets the best pod assignment by minimizing the topology distance.

  `assignment[k]` is the index into `sorted_nodes` chosen for `sorted_pods[k]`.
  Candidate assignments are produced by repeatedly advancing these indices;
  every candidate in which each pod passes `can_schedule` on its node is
  scored with the sum of `node_topology_distance` between the nodes of
  consecutive pods, and the candidate with the smallest score is kept (the
  first one found wins ties).

  NOTE(review): every entry of `assignment` is only ever incremented and never
  reset, so this looks like a monotonic sweep over increasing node indices
  rather than an exhaustive search of all combinations -- confirm that this is
  the intended search strategy.

  NOTE(review): with an empty `sorted_pods` the inner loop is skipped and
  `assignment[-1]` is evaluated on an empty list, which raises IndexError --
  callers are presumably expected to pass at least one pod.

  Args:
    sorted_nodes: A list of sorted node infos.
    sorted_pods: A list of sorted pod infos.

  Returns:
    A list of node indices that the pods should be assigned to (one entry per
    pod, in the order of `sorted_pods`), or an empty list if no viable
    assignment was found.
  """
  # Start at [-n, ..., -1] for n pods: each entry is incremented before it is
  # used, and negative entries are never passed to can_schedule.
  assignment = [-i for i in reversed(range(1, len(sorted_pods) + 1))]
  best_assignment = []
  minimum_distance = float('inf')

  while True:
    all_ok = True
    # Walk from the last pod towards the first one.
    i = len(assignment) - 1
    while i >= 0 and all_ok:
      assignment[i] += 1
      if assignment[i] == len(sorted_nodes):
        break  # no more nodes to consider
      if assignment[i] >= 0 and can_schedule(
          sorted_nodes[assignment[i]], sorted_pods[i]
      ):
        # Pod i fits on its current node: continue with the previous pod.
        i -= 1
      elif i < len(assignment) - 1 and assignment[i] == assignment[i + 1] - 1:
        # Pod i does not fit and has reached the node right before the one
        # of pod i + 1: give up on this candidate.
        all_ok = False  # to ignore the half of triangle of all combinations
      # Otherwise: stay on pod i and try its next node.

    # The last pod ran past the last node: the sweep is finished, return what
    # was found so far.
    if assignment[-1] == len(sorted_nodes):
      break
    # calculate total distance of the detected viable assignment
    if all_ok:
      new_distance = 0
      for i in range(1, len(sorted_pods)):
        new_distance += node_topology_distance(
            sorted_nodes[assignment[i]], sorted_nodes[assignment[i - 1]]
        )
      # Strict comparison: an equally good later candidate does not replace
      # the one already stored.
      if new_distance < minimum_distance:
        best_assignment = assignment.copy()
        minimum_distance = new_distance

  return best_assignment


def list_pods(
    v1: kubernetes.client.CoreV1Api, state_filter: str = None
) -> list[kubernetes.client.models.V1Pod]:
  """Collects the pods of every namespace, optionally filtered by phase.

  Args:
    v1: The kubernetes client.
    state_filter: If set (non-empty), only pods whose `status.phase` equals
      this value are returned; otherwise all pods are returned.

  Returns:
    A list of V1Pod objects, namespace by namespace.
  """
  collected = []
  for namespace_item in v1.list_namespace().items:
    namespaced_pods = v1.list_namespaced_pod(namespace_item.metadata.name).items
    collected.extend(
        candidate
        for candidate in namespaced_pods
        if not state_filter or candidate.status.phase == state_filter
    )
  return collected


def _group_gated_pods_by_job(
    gated_pods: list[kubernetes.client.models.V1Pod],
) -> dict[str, list[kubernetes.client.models.V1Pod]]:
  """Groups gated pods into jobs, one dict entry per job to schedule."""
  pods_per_job: dict[str, list[kubernetes.client.models.V1Pod]] = (
      collections.defaultdict(list)
  )
  # names of pods that no extractor has assigned to a job (or rejected) yet
  unclaimed_pod_names: set[str] = {p.metadata.name for p in gated_pods}

  # The extractors are tried in order of preference; a pod belongs to the
  # first job name that any of them recognizes.
  for job_name_extractor in (
      extract_job_name_label,
      extract_kubeflow_job_name_label,
      extract_owner_reference_uid,
      extract_helm_job_name_label,
  ):
    # Group with a dict rather than itertools.groupby: groupby only merges
    # *consecutive* items, so an unsorted pod list would split one job into
    # several groups that overwrite each other.
    groups: dict[str, list[kubernetes.client.models.V1Pod]] = (
        collections.defaultdict(list)
    )
    for pod in gated_pods:
      if pod.metadata.name not in unclaimed_pod_names:
        continue  # already handled through a previous extractor
      job_name = job_name_extractor(pod)
      # ignore pods with not yet recognized job name
      # (there are 4 different ways to extract it, see outer loop)
      if job_name != UNKNOWN_JOB_NAME:
        groups[job_name].append(pod)

    for job_name, pod_group in groups.items():
      # Accepted or rejected, these pods are not considered again.
      unclaimed_pod_names -= {p.metadata.name for p in pod_group}
      # sanity check: all pods have creation_timestamp
      if not all(p.metadata.creation_timestamp for p in pod_group):
        logging.error(
            'No pod creation_timestamp in job %s. Job ignored.Pods: %s',
            job_name,
            ', '.join([p.metadata.name for p in pod_group]),
        )
        continue
      # sanity check: all pods have same tolerations
      if not all_pods_have_same_tolerations(pod_group):
        logging.error(
            'Pods in job %s have different tolerations. Job ignored.Pods: %s',
            job_name,
            ', '.join([p.metadata.name for p in pod_group]),
        )
        continue

      # all checks passed
      pods_per_job[job_name].extend(pod_group)
      logging.info(
          'Found %d pods for job %s, created: %s',
          len(pods_per_job[job_name]),
          job_name,
          pods_per_job[job_name][0].metadata.creation_timestamp,
      )

  # sanity check
  if unclaimed_pod_names:
    logging.warning(
        'Found %d pods without explicit job name, going to schedule all'
        ' together. Pods: %s',
        len(unclaimed_pod_names),
        ', '.join(sorted(unclaimed_pod_names)),
    )
    # only for robustness we try to schedule them separately
    pods_per_job['pods-without-explicit-job-name'] = [
        p for p in gated_pods if p.metadata.name in unclaimed_pod_names
    ]
  return dict(pods_per_job)


def schedule_pod_with_gate(
    v1: kubernetes.client.CoreV1Api,
    gate_name: str,
) -> None:
  """Find and schedule pods with a given gate.

  Pending pods carrying the gate are grouped into jobs. Jobs are processed in
  the order of the creation time of their first pod. For each job the
  schedulable nodes are determined, an assignment of pods to nodes is
  calculated and each pod is bound to its node. Nodes used for one job are
  not offered to the jobs processed after it. Any exception raised while
  scheduling a job is logged and the next job is tried.

  Args:
    v1: k8s client
    gate_name: name of the gate to schedule pods with (see k8s yaml, eg
      schedulingGates=gke.io/topology-aware-auto-*`)
  """
  # query the pods again after the sleep, just in case not all gated pods
  # are returned from previous query
  pending_pods: list[kubernetes.client.models.V1Pod] = list_pods(v1, 'Pending')
  pods_for_given_gate: list[kubernetes.client.models.V1Pod] = [
      p
      for p in pending_pods
      if any(g.name == gate_name for g in p.spec.scheduling_gates or [])
  ]

  pods_per_job = _group_gated_pods_by_job(pods_for_given_gate)

  logging.info('Start scheduling %d jobs', len(pods_per_job))
  if not pods_per_job:
    return  # nothing to schedule, no need to query nodes and running pods

  nodes: list[kubernetes.client.models.V1Node] = v1.list_node().items
  # Sort job names based on the creation time of the first pod in each job group
  sorted_job_names = sorted(
      pods_per_job,
      key=lambda name: pods_per_job[name][0].metadata.creation_timestamp,
  )
  running_pods: list[kubernetes.client.models.V1Pod] = list_pods(v1, 'Running')
  logging.info('Already running pods number: %d', len(running_pods))
  recently_scheduled_nodes_names = set()
  for job_name in sorted_job_names:
    job_pods: list[kubernetes.client.models.V1Pod] = pods_per_job[job_name]
    logging.info(
        'Attempting to schedule job: %s with %s pods ', job_name, len(job_pods)
    )
    try:
      schedulable_pods_infos: dict[str, Any] = find_schedulable_pods(
          job_name, job_pods
      )

      # filter out nodes that were just used for previously processed jobs;
      # pods bound there are not part of `running_pods` yet
      nodes = [
          node
          for node in nodes
          if node.metadata.name not in recently_scheduled_nodes_names
      ]
      nodes_infos: dict[str, Any] = find_schedulable_nodes(
          nodes, running_pods, schedulable_pods_infos
      )

      if len(schedulable_pods_infos) > len(nodes_infos):
        logging.error(
            'Not enough nodes available for job %s scheduling: %s nodes, %s'
            ' pods. Skipping job.',
            job_name,
            len(nodes_infos),
            len(schedulable_pods_infos),
        )
        continue

      logging.info(
          'Available nodes for job %s scheduling: %s',
          job_name,
          len(nodes_infos),
      )

      sorted_pods = sorted(schedulable_pods_infos.values(), key=pod_sorting_key)
      sorted_nodes = sorted(nodes_infos.values(), key=node_topology_key)
      best_assignment = calculate_pods_assignment(sorted_nodes, sorted_pods)

      if not best_assignment:
        logging.error(
            'No scheduling for job %s with gate %s was found. Skipping job.',
            job_name,
            gate_name,
        )
        continue
      logging.info(
          'Assignment found, scheduling %s with %s pods.',
          job_name,
          len(job_pods),
      )

      for i, pod in enumerate(sorted_pods):
        node = sorted_nodes[best_assignment[i]]
        if not schedule_pod_on_node(
            v1, pod['name'], pod['namespace'], node, gate_name
        ):
          logging.error(
              'Failed to schedule pod %s on node %s. Skipping job %s',
              pod['name'],
              node['name'],
              job_name,
          )
          # revisit: in case of failure clear up partially scheduled pods
          break
        recently_scheduled_nodes_names.add(node['name'])

    except Exception as e:  # pylint: disable=broad-except
      logging.exception(
          'Exception when scheduling %s job,gate %s: %s', job_name, gate_name, e
      )


def run_scheduling_loop() -> None:
  """Runs scheduling.

  This function parses the command line, connects to the cluster (in-cluster
  config first, local kube config as fallback) and then runs an infinite loop
  that periodically schedules pods with topology-aware scheduling gates.

  The loop ends, after logging the error, when a Kubernetes ApiException is
  raised by one of the API calls.
  """
  parser = argparse.ArgumentParser(prog='schedule-workload.py')

  parser.add_argument(
      '-g',
      '--gate',
      default='gke.io/topology-aware-auto-',
      help='prefix of the scheduling gate names to handle',
  )
  parser.add_argument(
      '-i',
      '--interval',
      # without an explicit type a value given on the command line would be a
      # str and break the arithmetic on the interval below
      type=float,
      default=1.0,
      help='interval (in seconds) between scheduling runs',
  )
  parser.add_argument(
      '--ignored-namespace',
      nargs='*',
      default=[],
      # NOTE(review): parsed but not used by the scheduling loop in this file.
      help='namespaces to ignore (currently not used)',
  )
  args = parser.parse_args()

  try:
    kubernetes.config.load_incluster_config()
  except kubernetes.config.ConfigException:
    kubernetes.config.load_kube_config()
  v1 = kubernetes.client.CoreV1Api()
  # wait needed during container restart to allow previously scheduled pods
  # to be visible on nodes and occupy resources for their correct estimates
  logging.info('[Cool off] 90sec')
  time.sleep(90.0)
  try:
    last_run_ts = time.time()
    while True:
      # keep at least `interval` seconds between two consecutive runs
      time_since_prev_run = time.time() - last_run_ts
      if time_since_prev_run < args.interval:
        logging.info('[Cool off] %ssec', args.interval - time_since_prev_run)
        time.sleep(args.interval - time_since_prev_run)
      last_run_ts = time.time()

      # Get pods to schedule
      pods: list[kubernetes.client.models.V1Pod] = list_pods(v1, 'Pending')

      gates = find_pod_gates(pods, args.gate)
      logging.info('Found %s pending pods and %s gates', len(pods), len(gates))

      if not gates:
        # No pods to be scheduled
        continue

      # sleep for 5 seconds, assuming that all pods within one group would be
      # all visible by then
      logging.info('[Cool off] 5sec')
      time.sleep(5.0)

      for g in gates:
        logging.info('Scheduling pods with gate %s', g)
        schedule_pod_with_gate(v1, g)
        logging.info('[Cool off] 60sec')
        time.sleep(60.0)  # cool off

  except kubernetes.client.rest.ApiException as e:
    logging.exception('Exception when listing Kubernetes nodes or pods: %s', e)


# Script entry point: start the scheduling loop only when this file is
# executed directly, not when it is imported as a module.
if __name__ == '__main__':
  run_scheduling_loop()
