func (t *TPUWebhookServer) mutatePod()

in ray-on-gke/tpu/kuberay-tpu-webhook/main.go [597:743]


func (t *TPUWebhookServer) mutatePod(admissionReview *admissionv1.AdmissionReview) (*admissionv1.AdmissionResponse, error) {
	pod, err := extractPod(admissionReview)
	if err != nil {
		return nil, err
	}

	var patches []patch
	admissionResponse := &admissionv1.AdmissionResponse{
		UID:     admissionReview.Request.UID,
		Allowed: true,
	}
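	// default response: allow the Pod unmodified; for TPU Pods a JSON patch is attached below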

	containers := pod.Spec.Containers
	if containers == nil {
		return nil, errors.New("Container path not specified")
	}
	if !containerRequestingTPUs(containers...) {
		// if no TPUs are requested, simply admit the Pod
		return admissionResponse, nil
	}

	// KubeRay only sets the Pod's GenerateName; the random name suffix is not appended until after this
	// admission request, so use the {cluster name, group name, replicaIndex} -> workers mapping to extract the next TPU_WORKER_ID
	clusterName := pod.Labels["ray.io/cluster"]
	if clusterName == "" {
		return nil, errors.New("Ray Pod created by KubeRay missing RayCluster label")
	}
	groupName := pod.Labels["ray.io/group"]
	if groupName == "" {
		return nil, errors.New("Ray Pod created by KubeRay missing Group label")
	}
	namespace := pod.Namespace
	topology := pod.Spec.NodeSelector["cloud.google.com/gke-tpu-topology"]
	if topology == "" {
		return nil, errors.New("Ray Pod created by KubeRay missing TPU topology nodeSelector")
	}
	// assign worker to the next unique ID in the Pod Slice and update map
	chipsPerHost := getNumTPUChipsRequested(containers...)
	numOfHosts, _ := getNumTPUHostsFromTopology(clusterName, groupName, namespace, topology, chipsPerHost) // ignore error here because topology may not be set yet
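	// e.g. (illustrative values) a "2x2x4" topology has 16 chips; at 4 chips per host, numOfHosts = 4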

	// Wait for the PodInformer cache to update from previous requests, or time out after one second
	if waitTimeout(&t.wg, time.Second*1) {
		klog.V(1).InfoS("mutatePod", "event", "PodInformer AddFunc called for prior admission request")
	} else {
		klog.V(1).InfoS("mutatePod", "event", "Timed out waiting for PodInformer AddFunc")
	}
	// Add 1 to the WaitGroup on return, representing this Pod as pending until the informer cache observes it
	defer t.wg.Add(1)
	t.waiting++
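	// The deferred Add(1) means the next admission request's waitTimeout blocks (up to its 1s limit) until
	// the PodInformer's AddFunc observes the Pod admitted here, so sliceToWorkerIDs reflects prior admissions.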

	// query k8s client to populate sliceToWorkerIDs to then calculate the next TPU_WORKER_ID and replicaIndex
	sliceToWorkerIDs, err := t.getSliceToWorkerIDs(clusterName, groupName, namespace, numOfHosts)
	if err != nil {
		return nil, err
	}
	replicaIndex := getReplicaIndex(sliceToWorkerIDs, clusterName, groupName, namespace)
	podSlice := slice{clusterName, groupName, namespace, replicaIndex, numOfHosts}
	tpuWorkerID, err := getNextWorkerID(sliceToWorkerIDs, podSlice, namespace, replicaIndex) // defaults to 0 for single-host
	if err != nil {
		return nil, err
	}
	// set the unique identifier for the last admitted Pod by this TPUWebhookServer
	t.lastAdmitted = fmt.Sprintf("%s-%s-%d-%d", namespace, clusterName, replicaIndex, tpuWorkerID)
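	// e.g. "default-my-cluster-0-3" (hypothetical names): namespace "default", cluster "my-cluster", replica 0, worker 3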

	// inject replica index label
	injectReplicaLabel(clusterName, namespace, replicaIndex, groupName, &patches)

	if numOfHosts > 1 {
		// inject hostname into pod spec for DNS records
		hostname := fmt.Sprintf("%s-%d-%d", groupName, replicaIndex, tpuWorkerID)
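		// e.g. "workergroup-0-1" (hypothetical group name): worker 1 in slice replica 0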
		klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "hostname", hostname)
		hostnamePatch := patch{"op": "add"}
		hostnamePatch["path"] = "/spec/hostname"
		hostnamePatch["value"] = hostname
		patches = append(patches, hostnamePatch)

		// inject pod affinity/anti-affinity for scheduling
		injectPodAffinity(pod, replicaIndex, groupName, &patches)
	}

	// inject all environment variables into the container requesting TPUs
	for i := 0; i < len(containers); i++ {
		container := containers[i]
		if containerRequestingTPUs(container) {
			path := fmt.Sprintf("/spec/containers/%d/env", i)
			if numOfHosts > 1 {
				// inject TPU_WORKER_HOSTNAMES
				hostnames, err := genDNSHostnames(numOfHosts, groupName, clusterName, namespace, replicaIndex)
				if err != nil {
					return nil, err
				}
				klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "TPU_WORKER_HOSTNAMES", hostnames)
				klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "subdomain", generateHeadlessServiceName(clusterName))
				injectHostnames(clusterName, hostnames, path, container, &patches)
			}
			// inject TPU_WORKER_ID
			if getEnvironmentVariable("TPU_WORKER_ID", container) == "" {
				klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "TPU_WORKER_ID", tpuWorkerID, "Replica Index", replicaIndex)
				workerID := corev1.EnvVar{
					Name:  "TPU_WORKER_ID",
					Value: fmt.Sprint(tpuWorkerID),
				}
				idPatch := patch{"op": "add"}
				// if container.Env is empty, create the env array; otherwise append using the JSON Patch "-" end-of-array index
				if len(container.Env) == 0 {
					idPatch["path"] = path
					idPatch["value"] = []corev1.EnvVar{workerID}
				} else {
					idPatch["path"] = fmt.Sprintf("%s/-", path)
					idPatch["value"] = workerID
				}
				patches = append(patches, idPatch)
			}
			// inject TPU_NAME
			if getEnvironmentVariable("TPU_NAME", container) == "" {
				tpuNameValue := fmt.Sprintf("%s-%d", groupName, replicaIndex)
				klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "TPU_NAME", tpuNameValue, "Replica Index", replicaIndex)
				tpuName := corev1.EnvVar{
					Name:  "TPU_NAME",
					Value: tpuNameValue,
				}
				namePatch := patch{"op": "add"}
				// if container.Env is empty, create the env array; otherwise append using the JSON Patch "-" end-of-array index
				if len(container.Env) == 0 {
					namePatch["path"] = path
					namePatch["value"] = []corev1.EnvVar{tpuName}
				} else {
					namePatch["path"] = fmt.Sprintf("%s/-", path)
					namePatch["value"] = tpuName
				}
				patches = append(patches, namePatch)
			}
		}
	}

	patchBytes, err := json.Marshal(patches)
	if err != nil {
		return nil, err
	}

	admissionResponse.Patch = patchBytes
	patchType := admissionv1.PatchTypeJSONPatch
	admissionResponse.PatchType = &patchType
	return admissionResponse, nil
}
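
As a reference for how these patches take effect, here is a minimal, self-contained sketch. It assumes patch is a map[string]interface{} alias (as the patch{"op": "add"} literals above suggest) and uses a hypothetical "workergroup" group name; it shows how the accumulated operations marshal into the RFC 6902 JSON Patch document returned in admissionResponse.Patch, including the "-" end-of-array index used to append env vars.

package main

import (
	"encoding/json"
	"fmt"

	corev1 "k8s.io/api/core/v1"
)

// patch mirrors the webhook's assumed alias: one JSON Patch operation per map.
type patch map[string]interface{}

func main() {
	workerID := corev1.EnvVar{Name: "TPU_WORKER_ID", Value: "0"}

	patches := []patch{
		// set the Pod hostname, as mutatePod does for multi-host slices
		{"op": "add", "path": "/spec/hostname", "value": "workergroup-0-0"},
		// append TPU_WORKER_ID to an existing env array; "-" is JSON Patch's append index
		{"op": "add", "path": "/spec/containers/0/env/-", "value": workerID},
	}

	patchBytes, err := json.Marshal(patches)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(patchBytes))
	// [{"op":"add","path":"/spec/hostname","value":"workergroup-0-0"},
	//  {"op":"add","path":"/spec/containers/0/env/-","value":{"name":"TPU_WORKER_ID","value":"0"}}]
}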