in ray-on-gke/tpu/kuberay-tpu-webhook/main.go [597:743]
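// mutatePod intercepts Pod CREATE admission requests from KubeRay and, when the Pod
// requests TPUs, returns a JSON Patch that injects the replica index label, a stable
// hostname, pod affinity, and the TPU_WORKER_ID/TPU_NAME/TPU_WORKER_HOSTNAMES env vars.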
func (t *TPUWebhookServer) mutatePod(admissionReview *admissionv1.AdmissionReview) (*admissionv1.AdmissionResponse, error) {
pod, err := extractPod(admissionReview)
if err != nil {
return nil, err
}
var patches []patch
admissionResponse := &admissionv1.AdmissionResponse{
UID: admissionReview.Request.UID,
Allowed: true,
}
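	// the response defaults to Allowed with no patch, so Pods that don't need mutation are admitted unchanged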
containers := pod.Spec.Containers
if containers == nil {
return nil, errors.New("Container path not specified")
}
if !containerRequestingTPUs(containers...) {
// if no TPUs are requested, simply admit the Pod
return admissionResponse, nil
}
	// the Ray operator only sets the GenerateName field - the random Pod name suffix is not added until after the admission request
	// use the mapping of {cluster name, group name, replicaIndex} -> worker IDs to extract the next TPU_WORKER_ID
clusterName := pod.Labels["ray.io/cluster"]
if clusterName == "" {
return nil, errors.New("Ray Pod created by KubeRay missing RayCluster label")
}
groupName := pod.Labels["ray.io/group"]
if groupName == "" {
return nil, errors.New("Ray Pod created by KubeRay missing Group label")
}
namespace := pod.Namespace
topology := pod.Spec.NodeSelector["cloud.google.com/gke-tpu-topology"]
if topology == "" {
return nil, errors.New("Ray Pod created by KubeRay missing TPU topology nodeSelector")
}
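	// the gke-tpu-topology value (e.g. "2x2x2" for a multi-host v4 slice) determines how many TPU hosts form each Pod slice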
// assign worker to the next unique ID in the Pod Slice and update map
chipsPerHost := getNumTPUChipsRequested(containers...)
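	// chipsPerHost is the number of google.com/tpu chips requested by this Pod's containers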
	numOfHosts, _ := getNumTPUHostsFromTopology(clusterName, groupName, namespace, topology, chipsPerHost) // topology is validated above, so the error can be safely ignored
	// wait for the PodInformer cache to update from previous admission requests, or time out
	if waitTimeout(&t.wg, time.Second) {
		klog.V(1).Info("mutatePod: PodInformer AddFunc called for prior admission request")
	} else {
		klog.V(1).Info("mutatePod: timed out waiting for PodInformer AddFunc")
	}
	// add 1 to the WaitGroup to represent the admitted Pod not yet visible in the informer cache
	defer t.wg.Add(1)
	t.waiting++
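	// the deferred wg.Add(1) makes the next admission request block in waitTimeout above until
	// the PodInformer observes this Pod, so concurrent requests can't claim the same TPU_WORKER_ID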
	// query the k8s API to build sliceToWorkerIDs, used to calculate the replicaIndex and the next TPU_WORKER_ID
sliceToWorkerIDs, err := t.getSliceToWorkerIDs(clusterName, groupName, namespace, numOfHosts)
if err != nil {
return nil, err
}
replicaIndex := getReplicaIndex(sliceToWorkerIDs, clusterName, groupName, namespace)
podSlice := slice{clusterName, groupName, namespace, replicaIndex, numOfHosts}
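	// podSlice uniquely identifies the TPU Pod slice this worker is assigned to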
tpuWorkerID, err := getNextWorkerID(sliceToWorkerIDs, podSlice, namespace, replicaIndex) // defaults to 0 for single-host
if err != nil {
return nil, err
}
// set the unique identifier for the last admitted Pod by this TPUWebhookServer
t.lastAdmitted = fmt.Sprintf("%s-%s-%d-%d", namespace, clusterName, replicaIndex, tpuWorkerID)
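	// the PodInformer AddFunc compares incoming Pods against this identifier to release the WaitGroup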
// inject replica index label
injectReplicaLabel(clusterName, namespace, replicaIndex, groupName, &patches)
if numOfHosts > 1 {
// inject hostname into pod spec for DNS records
		hostname := fmt.Sprintf("%s-%d-%d", groupName, replicaIndex, tpuWorkerID)
klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "hostname", hostname)
hostnamePatch := patch{"op": "add"}
hostnamePatch["path"] = "/spec/hostname"
hostnamePatch["value"] = hostname
patches = append(patches, hostnamePatch)
		// inject pod affinity/anti-affinity so workers in the same replica are co-scheduled on the same TPU slice
injectPodAffinity(pod, replicaIndex, groupName, &patches)
}
// inject all environment variables into the container requesting TPUs
	for i, container := range containers {
if containerRequestingTPUs(container) {
path := fmt.Sprintf("/spec/containers/%d/env", i)
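			// e.g. "/spec/containers/0/env" - every env patch below targets this container's env array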
if numOfHosts > 1 {
// inject TPU_WORKER_HOSTNAMES
hostnames, err := genDNSHostnames(numOfHosts, groupName, clusterName, namespace, replicaIndex)
if err != nil {
return nil, err
}
klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "TPU_WORKER_HOSTNAMES", hostnames)
klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "subdomain", generateHeadlessServiceName(clusterName))
injectHostnames(clusterName, hostnames, path, container, &patches)
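				// TPU_WORKER_HOSTNAMES is a comma-separated list of per-worker DNS names,
				// e.g. "group-0-0.<headless-svc>,group-0-1.<headless-svc>" (illustrative values)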
}
// inject TPU_WORKER_ID
if getEnvironmentVariable("TPU_WORKER_ID", container) == "" {
klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "TPU_WORKER_ID", tpuWorkerID, "Replica Index", replicaIndex)
workerID := corev1.EnvVar{
Name: "TPU_WORKER_ID",
Value: fmt.Sprint(tpuWorkerID),
}
idPatch := patch{"op": "add"}
				// create a new env array if container.Env is empty; otherwise append with the JSON Patch end-of-array token "-"
if len(container.Env) == 0 {
idPatch["path"] = path
idPatch["value"] = []corev1.EnvVar{workerID}
} else {
idPatch["path"] = fmt.Sprintf("%s/-", path)
idPatch["value"] = workerID
}
patches = append(patches, idPatch)
}
// inject TPU_NAME
if getEnvironmentVariable("TPU_NAME", container) == "" {
tpuNameValue := fmt.Sprintf("%s-%d", groupName, replicaIndex)
klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "TPU_NAME", tpuNameValue, "Replica Index", replicaIndex)
tpuName := corev1.EnvVar{
Name: "TPU_NAME",
Value: tpuNameValue,
}
namePatch := patch{"op": "add"}
				// create a new env array if container.Env is empty; otherwise append with the JSON Patch end-of-array token "-"
if len(container.Env) == 0 {
namePatch["path"] = path
namePatch["value"] = []corev1.EnvVar{tpuName}
} else {
namePatch["path"] = fmt.Sprintf("%s/-", path)
namePatch["value"] = tpuName
}
patches = append(patches, namePatch)
}
}
}
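	// marshal the accumulated operations into a single RFC 6902 JSON Patch, e.g. (illustrative values):
	//   [{"op":"add","path":"/spec/hostname","value":"workergroup-0-1"},
	//    {"op":"add","path":"/spec/containers/0/env/-","value":{"name":"TPU_WORKER_ID","value":"1"}}]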
patchBytes, err := json.Marshal(patches)
if err != nil {
return nil, err
}
admissionResponse.Patch = patchBytes
	patchType := admissionv1.PatchTypeJSONPatch
	admissionResponse.PatchType = &patchType
return admissionResponse, nil
}