tpu-provisioner/admission_controller/admission_controller.py (73 lines of code) (raw):
#!/usr/bin/env python3
import os
import json
import base64
import logging
import hashlib
from fastapi import FastAPI, Body
from jsonpatch import JsonPatch
from copy import deepcopy
app = FastAPI()

# Logging: timestamped root format; this module's logger emits at INFO and above.
logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s")
webhook_logger = logging.getLogger(__name__)
webhook_logger.setLevel(logging.INFO)

# Names of the environment variables that control the Job mutations.
LOCATION_HINT = "RESERVATION_LOCATION_HINT"
ALWAYS_HINT_TIME = "ALWAYS_HINT_TIME"  # NOTE(review): not referenced elsewhere in this module
FORCE_ON_DEMAND = "FORCE_ON_DEMAND"

# Node label keys written to / removed from the pod template's nodeSelector.
job_key_label = "job-key"
reservation_name_label = "cloud.google.com/reservation-name"
gke_spot_label = "cloud.google.com/gke-spot"
gke_location_hint_label = "cloud.google.com/gke-location-hint"
# API endpoint
@app.post("/mutate")
def mutate_request(request: dict = Body(...)):
    '''Handle a mutating-webhook call on POST /mutate.

    Pulls ``uid`` and ``object`` out of the incoming AdmissionReview's
    ``request`` section, logs which object is being patched, builds the
    AdmissionReview response via ``admission_review`` and logs it as JSON
    before returning it for FastAPI to serialize.
    '''
    admission_request: dict = request["request"]
    uid: str = admission_request["uid"]
    k8s_object: dict = admission_request["object"]

    # Access order (kind, then metadata.namespace, then metadata.name) matches the log line.
    kind = k8s_object["kind"]
    metadata = k8s_object["metadata"]
    webhook_logger.info(f'Patching {kind} {metadata["namespace"]}/{metadata["name"]}')

    review: dict = admission_review(uid, k8s_object)
    webhook_logger.info(f'Response: {json.dumps(review)}')
    return review
def admission_review(uid: str, object_in: dict) -> dict:
    '''Build the AdmissionReview response for a single AdmissionRequest.

    The response always allows the request, echoes back ``uid``, carries a
    status message naming the object (``Patched <kind>: <namespace>/<name>``)
    and attaches the base64-encoded JSONPatch produced by ``patch``.
    '''
    # Build the message before calling patch() so lookups happen in the original order.
    kind = object_in['kind']
    namespaced_name = f"{object_in['metadata']['namespace']}/{object_in['metadata']['name']}"
    status_message = f"Patched {kind}: {namespaced_name}"

    response_body = {
        "uid": uid,
        "allowed": True,
        "patchType": "JSONPatch",
        "status": {"message": status_message},
        "patch": patch(object_in),
    }
    return {
        "apiVersion": "admission.k8s.io/v1",
        "kind": "AdmissionReview",
        "response": response_body,
    }
def patch(object_in: dict) -> str:
    '''Return the base64-encoded JSON Patch for the given k8s object.

    The result is placed in the AdmissionReview ``response.patch`` field,
    which is sent with ``patchType: JSONPatch``.

    Args:
        object_in: The object from the AdmissionRequest; passed unchanged
            to ``make_patches``.

    Returns:
        ASCII base64 text of the JSON-serialized patch operations.
    '''
    # make_patches returns a JsonPatch, not a list of dicts (the previous
    # annotation was wrong). Serialize it explicitly with to_string() instead of
    # relying on str(): str() of a plain list would produce a Python repr with
    # single quotes, which is not valid JSON.
    patches: JsonPatch = make_patches(object_in)
    return base64.b64encode(patches.to_string().encode("utf-8")).decode("ascii")
def make_patches(object_in: dict) -> JsonPatch:
    '''Generate a JsonPatch for Job mutations driven by environment variables.

    Works on a deep copy of ``object_in`` and edits the pod template's
    ``spec.template.spec.nodeSelector``:

    * ``job-key`` is always set to ``job_key_value(name, namespace)``.
    * If ``FORCE_ON_DEMAND`` is exactly ``"true"``, the reservation-name and
      gke-spot selectors are removed when present.
    * If ``RESERVATION_LOCATION_HINT`` is set to a non-empty value, the
      gke-location-hint selector is set to that value.

    Args:
        object_in: The object from the AdmissionRequest. Must contain
            ``metadata.name``, ``metadata.namespace`` and ``spec.template.spec``.

    Returns:
        The JsonPatch diff from ``object_in`` to the mutated copy.
        ``object_in`` itself is not modified.
    '''
    job_name: str = object_in["metadata"]["name"]
    job_namespace: str = object_in["metadata"]["namespace"]
    modified_object: dict = deepcopy(object_in)
    pod_spec: dict = modified_object["spec"]["template"]["spec"]

    # nodeSelector may be absent or explicitly null; start from an empty map in both cases
    # (a null value previously crashed on item assignment).
    if pod_spec.get("nodeSelector") is None:
        pod_spec["nodeSelector"] = {}
    node_selector: dict = pod_spec["nodeSelector"]

    # Add job-key node selector unconditionally; hash once and reuse the value for the log line.
    job_key: str = job_key_value(job_name, job_namespace)
    node_selector[job_key_label] = job_key
    webhook_logger.info(f'Job: {job_name} Added nodeSelector: {job_key_label}: {job_key}')

    if os.environ.get(FORCE_ON_DEMAND) == "true":
        # Remove the reservation and spot selectors (in that order) when FORCE_ON_DEMAND is set.
        for label in (reservation_name_label, gke_spot_label):
            if label in node_selector:
                del node_selector[label]
                webhook_logger.info(f'Job: {job_name} Removed nodeSelector for node label: {label}')

    # Set location hint nodeSelector if RESERVATION_LOCATION_HINT is set.
    location_hint_value: str = os.environ.get(LOCATION_HINT, "")
    if location_hint_value != "":
        node_selector[gke_location_hint_label] = location_hint_value
        webhook_logger.info(f'Job: {job_name} Added nodeSelector: {gke_location_hint_label}: {location_hint_value}')

    return JsonPatch.from_diff(object_in, modified_object)
def job_key_value(job_name: str, job_namespace: str) -> str:
    '''Return the ``job-key`` label value for a Job.

    The value is the hex SHA1 digest of ``"<namespace>/<name>"``.
    '''
    namespaced_name = f'{job_namespace}/{job_name}'
    return sha1(namespaced_name)
def sha1(data: str) -> str:
    '''Return the lowercase hex SHA1 digest of ``data`` encoded as UTF-8.'''
    hasher = hashlib.sha1()
    hasher.update(data.encode("utf-8"))
    return hasher.hexdigest()