gke-topology-scheduler/schedule-daemon.py
#!/usr/bin/env python
# Copyright 2024 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""schedule-daemon.py is a Topology-aware Kubernetes pod scheduler."""
import argparse
import collections
import itertools
import logging
import re
import time
from typing import Any
import kubernetes
import kubernetes.client
import kubernetes.client.models
import kubernetes.client.rest
from kubernetes.utils import quantity
# Configure logging
logging.basicConfig(
level=logging.INFO, # Set the root logger level to INFO
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('my_app.log'), # Log to a file
logging.StreamHandler(), # Log to the console
],
)
# labels for GKE < 1.31; these require the labeler sidecar
PRERELEASE_CLUSTER_LABEL = 'topology.gke.io/cluster'
PRERELEASE_RACK_LABEL = 'topology.gke.io/rack'
PRERELEASE_HOST_LABEL = 'topology.gke.io/host'
# labels for GKE>=1.31
CLUSTER_LABEL = 'cloud.google.com/gce-topology-block'
RACK_LABEL = 'cloud.google.com/gce-topology-subblock'
HOST_LABEL = 'cloud.google.com/gce-topology-host'
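# Illustrative node labels (hypothetical values):
#   cloud.google.com/gce-topology-block: block-1
#   cloud.google.com/gce-topology-subblock: subblock-2
#   cloud.google.com/gce-topology-host: host-3
# Nodes that share more of these values are treated as topologically closer.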
# Kubernetes labels used to identify jobs and their completion indices.
# These labels are typically set by Kubernetes controllers like JobSet or
# CronJob controllers. The job completion index is used to track the progress
# of a job and ensure that pods are scheduled in the correct order.
JOB_COMPLETION_INDEX_LABEL_KEY = 'batch.kubernetes.io/job-completion-index'
JOB_NAME_LABEL_KEY = 'job-name'
# Most ML workloads are launched via the batch Job or JobSet APIs. To support
# Kubeflow operators (e.g. MPIJobs) we also read these alternative labels.
KUBEFLOW_REPLICA_INDEX_LABEL_KEY = 'training.kubeflow.org/replica-index'
KUBEFLOW_JOB_NAME_LABEL_KEY = 'training.kubeflow.org/job-name'
# Collection of methods to extract the job name from pod metadata.
UNKNOWN_JOB_NAME = 'jobless'  # used when the job name cannot be determined
def extract_job_name_label(pod: kubernetes.client.models.V1Pod) -> str:
"""Extracts the job name label from a pod.
Args:
    pod: The pod to extract the job name from.
Returns:
The job name label, or UNKNOWN_JOB_NAME if the label is not found.
"""
return pod.metadata.labels.get(JOB_NAME_LABEL_KEY, UNKNOWN_JOB_NAME)
def extract_kubeflow_job_name_label(pod: kubernetes.client.models.V1Pod) -> str:
"""Extracts the kubeflow job name label from a pod.
Args:
    pod: The pod to extract the job name from.
Returns:
The kubeflow job name label, or UNKNOWN_JOB_NAME if the label is not found.
"""
return pod.metadata.labels.get(KUBEFLOW_JOB_NAME_LABEL_KEY, UNKNOWN_JOB_NAME)
def extract_owner_reference_uid(pod: kubernetes.client.models.V1Pod) -> str:
"""Extracts the owner reference UID from a pod.
Args:
    pod: The pod to extract the owner reference UID from.
Returns:
The owner reference UID, or UNKNOWN_JOB_NAME if the owner reference is not
found.
"""
if pod.metadata.owner_references:
return pod.metadata.owner_references[0].uid
return UNKNOWN_JOB_NAME
def extract_helm_job_name_label(pod: kubernetes.client.models.V1Pod) -> str:
"""Extracts the helm job name label from a pod.
Args:
    pod: The pod to extract the job name from.
Returns:
The helm job name label, or UNKNOWN_JOB_NAME if the label is not found.
"""
if pod.metadata.labels:
return pod.metadata.labels.get('name', UNKNOWN_JOB_NAME)
return UNKNOWN_JOB_NAME
def pod_sorting_key(pod_info: dict[str, Any]) -> tuple[str, int]:
"""Returns key/rank to be used for sorting pods.
Given that numbers is often suffixed for multi-node deployments,
here we use a (prefix, number) tuple for the sorting key.
This means "xxx-pod2" should appear before "xxx-pod10"
Args:
pod_info: The pod info.
Returns:
A tuple of the pod's prefix and index.
"""
  try:
    index_value = pod_info.get('index')
    if index_value is not None:
      # return a (prefix, number)-shaped tuple here as well so that sort keys
      # stay mutually comparable even when some pods lack an index label
      return ('', int(index_value))
except (ValueError, TypeError) as e:
logging.exception(
'Error converting %s pod index to integer: %s', pod_info['name'], e
)
# if the suffix is a number, extract it from the name
name = pod_info['name']
if match := re.fullmatch(r'^(.*?)(\d+)$', name):
prefix, suffix = match.groups()
return (prefix, int(suffix))
else:
logging.warning(
'Pod %s does not have a numeric suffix. Using 0 as index.', name
)
return (name, 0) # No numeric suffix
def node_topology_distance(node1: dict[str, Any], node2: dict[str, Any]) -> int:
"""Calculates the distance between two nodes in the topology.
Args:
node1: The first node.
node2: The second node.
Returns:
    The distance between the two nodes, or 0 if the nodes share the same
    cluster, rack, and host. It is also 0 if topology labels are missing.
"""
  # distance between cluster/rack/host tuples
  node1_topology_keys = node_topology_key(node1)
  node2_topology_keys = node_topology_key(node2)
  # treat nodes with missing topology labels as co-located (distance 0)
  if not node1_topology_keys or not node2_topology_keys:
    return 0
  result = 1000000
  for i, node1_topology_key in enumerate(node1_topology_keys):
    if node1_topology_key != node2_topology_keys[i]:
      return result
    result //= 100
  return 0
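# Illustrative distances (hypothetical keys): topology keys ('c1', 'r1', 'h1')
# and ('c1', 'r1', 'h2') differ only at the host level, giving a distance of
# 100; a rack-level mismatch gives 10000 and a cluster-level mismatch 1000000.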
def node_topology_key(
node: dict[str, Any],
) -> tuple[str, str, str] | tuple[()]:
"""Extract topology labels of a node.
Args:
node: The node.
Returns:
A tuple of the node's topology labels, or an empty tuple if the node does
not have topology labels.
"""
node_labels = node['node_labels']
for labels in [
(CLUSTER_LABEL, RACK_LABEL, HOST_LABEL),
(PRERELEASE_CLUSTER_LABEL, PRERELEASE_RACK_LABEL, PRERELEASE_HOST_LABEL),
]:
if all(label in node_labels for label in labels):
return tuple(node_labels[label] for label in labels)
logging.info('Node %s does not have topology labels', node['name'])
return ()
def get_pod_used_resources(
pod: kubernetes.client.models.V1Pod,
) -> tuple[int, int, int]:
"""Get the resources used by this pod.
Args:
pod: The pod.
Returns:
A tuple of the pod's used CPU, memory, and GPU.
"""
used_cpu, used_memory, used_gpu = 0, 0, 0
if pod.status is None or pod.status.container_statuses is None:
return used_cpu, used_memory, used_gpu
for container, container_status in zip(
pod.spec.containers, pod.status.container_statuses
):
if container_status.state.terminated is not None:
      # terminated containers don't use resources
continue
requests = container.resources.requests or {}
used_cpu += quantity.parse_quantity(requests.get('cpu', 0))
used_memory += quantity.parse_quantity(requests.get('memory', 0))
used_gpu += int(requests.get('nvidia.com/gpu', 0))
return used_cpu, used_memory, used_gpu
def all_pods_have_same_tolerations(
pods: list[kubernetes.client.models.V1Pod],
) -> bool:
"""Checks if all pods in the list have the same tolerations.
  Edge case: an empty list is considered to have the same tolerations.
Args:
pods: A list of V1Pod objects.
Returns:
True if all pods have the same tolerations, False otherwise.
"""
if not pods:
return True
first_pod_tolerations = pods[0].spec.tolerations
return all(pod.spec.tolerations == first_pod_tolerations for pod in pods[1:])
def find_schedulable_nodes(
nodes: list[kubernetes.client.models.V1Node],
other_running_pods: list[kubernetes.client.models.V1Pod],
new_pods_to_schedule: dict[str, dict[str, Any]],
) -> dict[str, Any]:
"""Finds nodes that can be scheduled.
Args:
nodes: A list of V1Node objects.
other_running_pods: A list of running pods. They occupy node resources.
new_pods_to_schedule: Info dict of pods to schedule.
Returns:
    A dict of node names to node infos for nodes where scheduling is possible.
"""
nodes_info = {}
  # guarantee: the caller verified that all pods have the same tolerations
tolerated_taint_dict = {}
tolerated_taints = next(iter(new_pods_to_schedule.values()))[
'spec'
].tolerations
if tolerated_taints:
tolerated_taint_dict = {t.key: t for t in tolerated_taints}
for node in nodes:
node_name = node.metadata.name
# skip nodes that have taints that are not covered by pod tolerations
if any(
t.key not in tolerated_taint_dict
or (
tolerated_taint_dict[t.key].operator == 'Equal'
and tolerated_taint_dict[t.key].value != t.value
)
for t in node.spec.taints or []
):
logging.info(
'Skipping node %s because it has taints (%s) which are not covered'
' by pod tolerations %s',
node_name,
node.spec.taints,
tolerated_taint_dict,
)
continue
    # skip nodes that are not in Ready state
    if any(
        condition.type == 'Ready' and condition.status != 'True'
        for condition in node.status.conditions
    ):
      logging.info('Skipping node %s because it is NotReady', node_name)
      continue
allocatable = node.status.allocatable
used_cpu, used_memory, used_gpu = 0, 0, 0
for pod in other_running_pods:
if pod.spec.node_name and pod.spec.node_name == node_name:
cpu, mem, gpu = get_pod_used_resources(pod)
used_cpu += cpu
used_memory += mem
used_gpu += gpu
free_cpu = quantity.parse_quantity(allocatable['cpu']) - used_cpu
free_memory = quantity.parse_quantity(allocatable['memory']) - used_memory
free_gpu = int(allocatable.get('nvidia.com/gpu', 0)) - used_gpu
node_info = {
'name': node_name,
'cpu': free_cpu,
'memory': free_memory,
'gpu': free_gpu,
}
if node.metadata.labels:
node_info['node_labels'] = node.metadata.labels
nodes_info[node_name] = node_info
logging.info(
'Node: %s, CPU: %s, Memory: %s, GPU: %s, Topology: %s',
node_name,
free_cpu,
free_memory,
free_gpu,
node_topology_key(node_info),
)
return nodes_info
def find_pod_gates(
pods: list[kubernetes.client.models.V1Pod], prefix: str
) -> set[str]:
"""Finds pods with scheduling gates that starts with the prefix.
Args:
pods: A list of V1Pod objects.
prefix: The prefix of the scheduling gate.
Returns:
A set of scheduling gate names.
"""
s = set()
for pod in pods:
if pod.spec.scheduling_gates:
for g in pod.spec.scheduling_gates:
if g.name.startswith(prefix):
s.add(g.name)
return s
def find_schedulable_pods(
job_name: str,
pods: list[kubernetes.client.models.V1Pod],
) -> dict[str, dict[str, Any]]:
"""Finds pods that can be scheduled for a given gate.
Args:
job_name: The name of the job.
pods: A list of V1Pod objects.
Returns:
A dict of pod names to pod infos for schedulable pods.
"""
pods_to_schedule = {}
for pod in pods:
pod_name = pod.metadata.name
pod_index = None
if JOB_COMPLETION_INDEX_LABEL_KEY in pod.metadata.labels:
pod_index = pod.metadata.labels[JOB_COMPLETION_INDEX_LABEL_KEY]
elif KUBEFLOW_REPLICA_INDEX_LABEL_KEY in pod.metadata.labels:
pod_index = pod.metadata.labels[KUBEFLOW_REPLICA_INDEX_LABEL_KEY]
else:
      # not a hard stop: we can still order pods by name
logging.info(
'No index in pod %s metadata. Ordering derived from name.',
pod_name,
)
used_cpu, used_memory, used_gpu = 0, 0, 0
for container in pod.spec.containers:
requests = container.resources.requests or {}
used_cpu += quantity.parse_quantity(requests.get('cpu', 0))
used_memory += quantity.parse_quantity(requests.get('memory', 0))
used_gpu += int(requests.get('nvidia.com/gpu', 0))
pod_info = {
'name': pod_name,
'namespace': pod.metadata.namespace,
'index': pod_index,
'cpu': used_cpu,
'memory': used_memory,
'gpu': used_gpu,
'spec': pod.spec,
'metadata': pod.metadata,
'job_name': job_name,
'creation_time': pod.metadata.creation_timestamp,
}
if pod.spec.node_selector:
pod_info['node_selector'] = pod.spec.node_selector
pods_to_schedule[pod_name] = pod_info
logging.info(
'Found schedulable pod: %s/%s, CPU: %s, Memory: %s, GPU: %s Index: %s',
pod.metadata.namespace,
pod_name,
used_cpu,
used_memory,
used_gpu,
pod_index,
)
return pods_to_schedule
def can_schedule(node: dict[str, Any], pod: dict[str, Any]) -> bool:
"""Checks if a given pod can be scheduled on a given node.
  The decision is based on node resource availability and pod requirements. The
  node_selector, if the pod has one, is also checked.
Args:
node: The node to check if the pod can be scheduled on.
pod: The pod to check if it can be scheduled on the node.
Returns:
True if the pod can be scheduled on the node, False otherwise.
"""
node_selector = pod.get('node_selector', {})
node_labels = node.get('node_labels', {})
for key, value in node_selector.items():
if key not in node_labels or node_labels[key] != value:
return False
return (
node['cpu'] >= pod['cpu']
and node['memory'] >= pod['memory']
and node['gpu'] >= pod['gpu']
)
def schedule_pod_on_node(
v1: kubernetes.client.CoreV1Api,
pod_name: str,
pod_namespace: str,
node: dict[str, Any],
gate_name: str,
) -> bool:
"""Schedules a pod on a given node using affinity for direct assignment.
Args:
v1: The kubernetes client.
pod_name: The name of the pod to schedule.
pod_namespace: The namespace of the pod to schedule.
node: The node to schedule the pod on.
gate_name: The name of the gate to remove from the pod.
Returns:
True if the pod was scheduled on the node, False otherwise.
"""
try:
pod = v1.read_namespaced_pod(pod_name, pod_namespace)
if any(gate.name == gate_name for gate in pod.spec.scheduling_gates):
new_gates = [
gate for gate in pod.spec.scheduling_gates if gate.name != gate_name
]
pod.spec.affinity = {
'nodeAffinity': {
'requiredDuringSchedulingIgnoredDuringExecution': {
'nodeSelectorTerms': [{
'matchExpressions': [{
'key': 'kubernetes.io/hostname',
'operator': 'In',
'values': [node['name']],
}]
}]
}
}
}
pod.spec.scheduling_gates = new_gates
v1.replace_namespaced_pod(pod_name, pod_namespace, pod)
logging.info(
'Pod %s/%s scheduled on %s with topology %s', pod_namespace, pod_name, node['name'], node_topology_key(node)
)
except kubernetes.client.rest.ApiException as e:
logging.exception(
'Exception when removing pod %s scheduling gate: %s', pod_name, e
)
return False
return True
def calculate_pods_assignment(
sorted_nodes: list[dict[str, Any]], sorted_pods: list[dict[str, Any]]
) -> list[int]:
"""Gets the best pod assignment by minimizing the topology distance.
Args:
sorted_nodes: A list of sorted node infos.
sorted_pods: A list of sorted pod infos.
Returns:
A list of node indices that the pods should be assigned to.
"""
assignment = [-i for i in reversed(range(1, len(sorted_pods) + 1))]
best_assignment = []
minimum_distance = float('inf')
while True:
all_ok = True
i = len(assignment) - 1
while i >= 0 and all_ok:
assignment[i] += 1
if assignment[i] == len(sorted_nodes):
break # no more nodes to consider
if assignment[i] >= 0 and can_schedule(
sorted_nodes[assignment[i]], sorted_pods[i]
):
i -= 1
elif i < len(assignment) - 1 and assignment[i] == assignment[i + 1] - 1:
all_ok = False # to ignore the half of triangle of all combinations
    # we checked all combinations; return what was found so far
if assignment[-1] == len(sorted_nodes):
break
# calculate total distance of the detected viable assignment
if all_ok:
new_distance = 0
for i in range(1, len(sorted_pods)):
new_distance += node_topology_distance(
sorted_nodes[assignment[i]], sorted_nodes[assignment[i - 1]]
)
if new_distance < minimum_distance:
best_assignment = assignment.copy()
minimum_distance = new_distance
return best_assignment
def list_pods(
    v1: kubernetes.client.CoreV1Api, state_filter: str | None = None
) -> list[kubernetes.client.models.V1Pod]:
"""Lists pods in all namespaces.
Args:
v1: The kubernetes client.
state_filter: The pod state to filter by.
Returns:
A list of V1Pod objects.
"""
pods = []
for n in v1.list_namespace().items:
namespace = n.metadata.name
for pod in v1.list_namespaced_pod(namespace).items:
if not state_filter or pod.status.phase == state_filter:
pods.append(pod)
return pods
def schedule_pod_with_gate(
v1: kubernetes.client.CoreV1Api,
gate_name: str,
) -> None:
"""Find and schedule pods with a given gate.
Args:
v1: k8s client
    gate_name: name of the gate to schedule pods with (see the workload's k8s
      yaml, e.g. `schedulingGates=gke.io/topology-aware-auto-*`)
"""
  # query the pods again after the sleep, in case not all gated pods
  # were returned by the previous query
pending_pods: list[kubernetes.client.models.V1Pod] = list_pods(v1, 'Pending')
pods_for_given_gate: list[kubernetes.client.models.V1Pod] = []
for p in pending_pods:
if p.spec.scheduling_gates:
if gate_name in {g.name for g in p.spec.scheduling_gates}:
pods_for_given_gate.append(p)
pods_per_job: dict[str, list[kubernetes.client.models.V1Pod]] = (
collections.defaultdict(list)
)
jobless_pod_names: set[str] = set(
[p.metadata.name for p in pods_for_given_gate]
)
for job_name_extractor in [
extract_job_name_label,
extract_kubeflow_job_name_label,
extract_owner_reference_uid,
extract_helm_job_name_label,
]:
    # itertools.groupby only groups consecutive items, so sort by the
    # extracted job name first to keep each job's pods in a single group
    for job_name, job_pod_group in itertools.groupby(
        sorted(pods_for_given_gate, key=job_name_extractor),
        job_name_extractor,
    ):
      # ignore pod groups whose job name was not recognized yet
      # (the outer loop tries several different extractors)
if job_name == UNKNOWN_JOB_NAME:
continue
pod_group = list(job_pod_group)
# sanity check: all pods have creation_timestamp
if not all(p.metadata.creation_timestamp for p in pod_group):
logging.error(
            'No pod creation_timestamp in job %s. Job ignored. Pods: %s',
job_name,
', '.join([p.metadata.name for p in pod_group]),
)
# exclude such bad pods from consideration
jobless_pod_names -= {p.metadata.name for p in pod_group}
continue
# sanity check: all pods have same tolerations
if not all_pods_have_same_tolerations(pod_group):
logging.error(
            'Pods in job %s have different tolerations. Job ignored. Pods: %s',
job_name,
', '.join([p.metadata.name for p in pod_group]),
)
# exclude such bad pods from consideration
jobless_pod_names -= {p.metadata.name for p in pod_group}
continue
pod_names_in_group = set([p.metadata.name for p in pod_group])
      # sanity check: if all job pods have been scheduled, we can ignore the job
if all(p not in jobless_pod_names for p in pod_names_in_group):
logging.info(
'All %d pods of job %s are already scheduled. Job ignored.',
len(pod_group),
job_name,
)
continue
# all checks passed
jobless_pod_names -= pod_names_in_group
pods_per_job[job_name] = list(pod_group)
logging.info(
'Found %d pods for job %s, created: %s',
len(pods_per_job[job_name]),
job_name,
pods_per_job[job_name][0].metadata.creation_timestamp,
)
# sanity check
if jobless_pod_names:
logging.warning(
'Found %d pods without explicit job name, going to schedule all'
' together. Pods: %s',
len(jobless_pod_names),
', '.join(list(jobless_pod_names)),
)
    # for robustness, schedule them as a separate group under a synthetic name
pods_per_job['pods-without-explicit-job-name'] = [
p for p in pods_for_given_gate if p.metadata.name in jobless_pod_names
]
logging.info('Start scheduling %d jobs', len(pods_per_job))
nodes: list[kubernetes.client.models.V1Node] = v1.list_node().items
# Sort job names based on the creation time of the first pod in each job group
sorted_job_names = sorted(
pods_per_job.keys(),
key=lambda job_name: pods_per_job[job_name][
0
].metadata.creation_timestamp,
)
running_pods: list[kubernetes.client.models.V1Pod] = list_pods(v1, 'Running')
logging.info('Already running pods number: %d', len(running_pods))
recently_scheduled_nodes_names = set()
for job_name in sorted_job_names:
job_pods: list[kubernetes.client.models.V1Pod] = pods_per_job.get(
job_name, []
)
logging.info(
'Attempting to schedule job: %s with %s pods ', job_name, len(job_pods)
)
try:
schedulable_pods_infos: dict[str, Any] = find_schedulable_pods(
job_name, job_pods
)
      # filter out nodes already assigned pods earlier in this scheduling pass
nodes = [
node
for node in nodes
if node.metadata.name not in recently_scheduled_nodes_names
]
nodes_infos: dict[str, Any] = find_schedulable_nodes(
nodes, running_pods, schedulable_pods_infos
)
if len(schedulable_pods_infos) > len(nodes_infos):
logging.error(
'Not enough nodes available for job %s scheduling: %s nodes, %s'
' pods. Skipping job.',
job_name,
len(nodes_infos),
len(schedulable_pods_infos),
)
continue
logging.info(
'Available nodes for job %s scheduling: %s',
job_name,
len(nodes_infos),
)
sorted_pods = sorted(schedulable_pods_infos.values(), key=pod_sorting_key)
sorted_nodes = sorted(nodes_infos.values(), key=node_topology_key)
best_assignment = calculate_pods_assignment(sorted_nodes, sorted_pods)
if not best_assignment:
logging.error(
'No scheduling for job %s with gate %s was found. Skipping job.',
job_name,
gate_name,
)
continue
logging.info(
'Assignment found, scheduling %s with %s pods.',
job_name,
len(job_pods),
)
for i, pod in enumerate(sorted_pods):
node = sorted_nodes[best_assignment[i]]
if not schedule_pod_on_node(
v1, pod['name'], pod['namespace'], node, gate_name
):
logging.error(
'Failed to schedule pod %s on node %s. Skipping job %s',
pod['name'],
node['name'],
job_name,
)
break
        # revisit: in case of failure, clean up partially scheduled pods
recently_scheduled_nodes_names.add(node['name'])
except Exception as e: # pylint: disable=broad-except
logging.exception(
'Exception when scheduling %s job,gate %s: %s', job_name, gate_name, e
)
def run_scheduling_loop() -> None:
"""Runs scheduling.
  This function runs an infinite loop that periodically schedules pods with
  topology-aware scheduling gates.
"""
parser = argparse.ArgumentParser(prog='schedule-workload.py')
parser.add_argument(
'-g', '--gate', default='gke.io/topology-aware-auto-'
) # prefix of the schedule gate
  parser.add_argument(
      '-i', '--interval', type=float, default=1.0
  )  # interval (in seconds) between scheduling runs
  parser.add_argument(
      '--ignored-namespace', nargs='*', default=[]
  )  # namespaces to ignore when searching for pods
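  # Example invocation (illustrative):
  #   python schedule-daemon.py --gate=gke.io/topology-aware-auto- --interval=5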
args = parser.parse_args()
try:
kubernetes.config.load_incluster_config()
except kubernetes.config.ConfigException:
kubernetes.config.load_kube_config()
v1 = kubernetes.client.CoreV1Api()
  # wait needed after a container restart so that previously scheduled pods
  # become visible on nodes and their resource usage is estimated correctly
logging.info('[Cool off] 90sec')
time.sleep(90.0)
try:
last_run_ts = time.time()
while True:
time_since_prev_run = time.time() - last_run_ts
if time_since_prev_run < args.interval:
logging.info('[Cool off] %ssec', args.interval - time_since_prev_run)
time.sleep(args.interval - time_since_prev_run)
last_run_ts = time.time()
# Get pods to schedule
pods: list[kubernetes.client.models.V1Pod] = list_pods(v1, 'Pending')
gates = find_pod_gates(pods, args.gate)
logging.info('Found %s pending pods and %s gates', len(pods), len(gates))
if not gates:
# No pods to be scheduled
continue
      # sleep for 5 seconds, assuming that all pods within one group
      # will be visible by then
logging.info('[Cool off] 5sec')
time.sleep(5.0)
for g in gates:
logging.info('Scheduling pods with gate %s', g)
schedule_pod_with_gate(v1, g)
logging.info('[Cool off] 60sec')
time.sleep(60.0) # cool off
except kubernetes.client.rest.ApiException as e:
logging.exception('Exception when listing Kubernetes nodes or pods: %s', e)
if __name__ == '__main__':
run_scheduling_loop()