ray-on-gke/tpu/kuberay-tpu-webhook/main.go
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
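// Package main implements the KubeRay TPU webhook for GKE: it validates that
// RayCluster TPU worker groups match their gke-tpu-topology nodeSelectors, and
// mutates Ray TPU worker Pods with a replicaIndex label, hostname, pod affinity,
// and the TPU_WORKER_ID, TPU_WORKER_HOSTNAMES, and TPU_NAME environment
// variables required for multi-host TPU initialization.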
package main
import (
"encoding/base64"
"encoding/json"
"errors"
"flag"
"fmt"
"math"
"net/http"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
"time"
ray "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
admissionv1 "k8s.io/api/admission/v1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels"
"k8s.io/client-go/informers"
"k8s.io/client-go/kubernetes"
listersv1 "k8s.io/client-go/listers/core/v1"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/cache"
"k8s.io/client-go/tools/clientcmd"
"k8s.io/klog/v2"
)
// slice represents a TPU Pod Slice.
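// For example (names assumed for illustration), the third replica of a 2-host
// worker group would be keyed as:
//   slice{clusterName: "my-cluster", groupName: "tpu-group", namespace: "default", replicaIndex: 2, numOfHosts: 2}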
type slice struct {
clusterName string
groupName string
namespace string
replicaIndex int
numOfHosts int32
}
// TPUWebhookServer is a KubeRay TPU webhook server instance.
type TPUWebhookServer struct {
// podLister is used to query Pods from an informer cache.
podLister listersv1.PodLister
// cacheMutex serializes Mutate calls so TPU_WORKER_ID and replicaIndex assignment is race-free.
cacheMutex sync.Mutex
// wg and waiting track admitted TPU worker Pods that have not yet appeared in
// the informer cache; addPod marks them Done as the cache catches up.
wg sync.WaitGroup
waiting int
// lastAdmitted uniquely identifies the last TPU worker Pod admitted by this server.
lastAdmitted string
}
// patch is a JSON patch describing mutate operation(s) for an incoming object.
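// As an illustration (values assumed), the patch appending TPU_WORKER_ID to a
// container's existing env array marshals to the JSON Patch operation:
//   {"op": "add", "path": "/spec/containers/0/env/-", "value": {"name": "TPU_WORKER_ID", "value": "0"}}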
type patch map[string]any
var (
certPath = "/etc/kuberay-tpu-webhook/tls/tls.crt"
keyPath = "/etc/kuberay-tpu-webhook/tls/tls.key"
tpuResourceName = corev1.ResourceName("google.com/tpu")
// Flag arguments.
BindAddr string
CACert string
KubeConfigPath string
ServerCert string
ServerKey string
)
func NewTPUWebhookServer(podLister listersv1.PodLister) *TPUWebhookServer {
return &TPUWebhookServer{
podLister: podLister,
}
}
// Mutate handles an HTTP admission request for Pod creation and writes the response
func (t *TPUWebhookServer) Mutate(w http.ResponseWriter, r *http.Request) {
t.cacheMutex.Lock()
defer t.cacheMutex.Unlock()
admissionReview := &admissionv1.AdmissionReview{}
if err := json.NewDecoder(r.Body).Decode(admissionReview); err != nil {
http.Error(w, "Error decoding request body", http.StatusBadRequest)
return
}
if admissionReview.Request == nil || admissionReview.Request.Kind.Kind != "Pod" {
http.Error(w, "Invalid Kind", http.StatusBadRequest)
return
}
klog.V(0).InfoS("Mutate", "Pod", admissionReview.Request.Name)
response, err := t.mutatePod(admissionReview)
if err != nil {
klog.Errorf("Failed to mutate Pod: %s", err)
http.Error(w, "Failed to mutate Pod", http.StatusForbidden)
w.WriteHeader(http.StatusForbidden)
return
}
admissionReview.Response = response
responseBytes, err := json.Marshal(admissionReview)
if err != nil {
klog.Errorf("Failed to encode response: %s", err)
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
w.WriteHeader(http.StatusInternalServerError)
return
}
fmt.Fprint(w, string(responseBytes))
}
// Validate handles an HTTP admission request for RayCluster creation and writes the response
func (t *TPUWebhookServer) Validate(w http.ResponseWriter, r *http.Request) {
admissionReview := &admissionv1.AdmissionReview{}
if err := json.NewDecoder(r.Body).Decode(admissionReview); err != nil {
http.Error(w, "Error decoding request body", http.StatusBadRequest)
return
}
if admissionReview.Request == nil || admissionReview.Request.Kind.Kind != "RayCluster" {
http.Error(w, "Invalid Kind", http.StatusBadRequest)
return
}
klog.V(0).InfoS("Validate", "RayCluster", admissionReview.Request.Name)
response, err := validateRayCluster(admissionReview)
if err != nil {
klog.Errorf("Failed to validate RayCluster: %s", err)
http.Error(w, "Failed to validate RayCluster", http.StatusForbidden)
w.WriteHeader(http.StatusForbidden)
return
}
admissionReview.Response = response
responseBytes, err := json.Marshal(admissionReview)
if err != nil {
klog.Errorf("Failed to encode response: %s", err)
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
w.WriteHeader(http.StatusInternalServerError)
return
}
fmt.Fprint(w, string(responseBytes))
}
// printSliceToWorkerIds logs sliceToWorkerIDs contents for debugging
func printSliceToWorkerIds(sliceToWorkerIDs map[slice][]int) {
for podSlice, workerList := range sliceToWorkerIDs {
klog.V(1).InfoS("printSliceToWorkerIds", "RayCluster", podSlice.namespace+"/"+podSlice.clusterName, "Worker Group", podSlice.groupName)
for _, workerID := range workerList {
klog.V(1).InfoS("printSliceToWorkerIds", "RayCluster", podSlice.namespace+"/"+podSlice.clusterName, "Worker ID", workerID)
}
}
}
// containerRequestingTPUs returns whether any of the given containers requests TPU resources
func containerRequestingTPUs(containers ...corev1.Container) bool {
for _, container := range containers {
if l := container.Resources.Limits; l != nil {
if resource := l[tpuResourceName]; !resource.IsZero() {
return true
}
}
if r := container.Resources.Requests; r != nil {
if resource := r[tpuResourceName]; !resource.IsZero() {
return true
}
}
}
return false
}
// getNumTPUChipsRequested returns the `google.com/tpu` resource request value for the container,
// which indicates the number of TPU chips for the container to use
func getNumTPUChipsRequested(containers ...corev1.Container) int64 {
tpuLimit := int64(0)
tpuRequest := int64(0)
for _, container := range containers {
if l := container.Resources.Limits; l != nil {
if resource := l[tpuResourceName]; !resource.IsZero() {
tpuLimit = resource.Value()
}
}
// default the request to the limit if the request is omitted; Kubernetes
// requires request == limit for extended resources like google.com/tpu
tpuRequest = tpuLimit
if r := container.Resources.Requests; r != nil {
if resource := r[tpuResourceName]; !resource.IsZero() {
tpuRequest = resource.Value()
}
}
}
return min(tpuLimit, tpuRequest)
}
// getNumTPUHostsFromTopology returns number of TPU VM hosts in Pod Slice specified by gke-tpu-topology Pod nodeSelector
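// Illustrative arithmetic (values assumed): a "2x2x4" topology yields 2*2*4 = 16
// chips; with 4 chips requested per host, hosts = max(16/4, 1) = 4.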
func getNumTPUHostsFromTopology(clusterName string, groupName string, namespace string, topology string, chipsPerHost int64) (int32, error) {
if topology == "" {
return 0, errors.New("TPU topology not specified")
}
topologyVals := strings.Split(topology, "x")
chips := 1
for _, topologyVal := range topologyVals {
dim, err := strconv.Atoi(topologyVal)
if err != nil {
klog.ErrorS(err, "getNumTPUHostsFromTopology", "RayCluster", namespace+"/"+clusterName, "Worker Group", groupName, "gke-tpu-topology", topology)
return 0, err
}
chips *= dim
}
// calculate the # of VMs using # of chips per host
hosts := max(int32(chips)/int32(chipsPerHost), 1)
klog.V(1).InfoS("getNumTPUHostsFromTopology", "RayCluster", namespace+"/"+clusterName, "Worker Group", groupName, "topology", topology, "chips", chips, "hosts", hosts)
return hosts, nil
}
// extractRayCluster returns RayCluster unmarshalled from an admission request
func extractRayCluster(admissionReview *admissionv1.AdmissionReview) (*ray.RayCluster, error) {
if admissionReview.Request.Kind.Kind != "RayCluster" {
return nil, fmt.Errorf("Expected RayCluster but got %s", admissionReview.Request.Kind.Kind)
}
rayCluster := ray.RayCluster{}
if err := json.Unmarshal(admissionReview.Request.Object.Raw, &rayCluster); err != nil {
return nil, err
}
return &rayCluster, nil
}
// generateHeadlessServiceName returns the expected TPU headless service name for a RayCluster
func generateHeadlessServiceName(clusterName string) string {
serviceName := fmt.Sprintf("%s-%s", clusterName, utils.HeadlessServiceSuffix)
// Apply the same truncation as in the RayCluster controller when generating the headless service
// name. This is to maintain the up-to 63 char compatibility guarantee for hostnames (RFC 1123).
return utils.CheckName(serviceName)
}
// genDNSHostnames returns list of DNS hostnames for TPU VM hosts as a string
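// Illustrative example (names assumed): for groupName "tpu-group", replicaIndex 0,
// numOfHosts 2, and clusterName "my-cluster", the returned value is
// "tpu-group-0-0.my-cluster-headless,tpu-group-0-1.my-cluster-headless".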
func genDNSHostnames(numOfHosts int32, groupName string, clusterName string, namespace string, replicaIndex int) (string, error) {
if numOfHosts == 0 {
err := errors.New("workerGroupSpec NumOfHosts not set")
return "", err
}
headlessServiceName := generateHeadlessServiceName(clusterName)
hostNames := make([]string, numOfHosts)
// Host names will be of the form {WORKER_GROUP_NAME}-{REPLICA_INDEX}-{HOST_INDEX}.{CLUSTER_NAME}-headless
for j := 0; j < int(numOfHosts); j++ {
hostNames[j] = fmt.Sprintf("%s-%d-%d.%s", groupName, replicaIndex, j, headlessServiceName)
}
klog.V(1).InfoS("genDNSHostnames", "RayCluster", namespace+"/"+clusterName, "NumOfHosts", numOfHosts, "Replica Index", replicaIndex)
return strings.Join(hostNames, ","), nil
}
// injectHostnames injects subdomain and TPU_WORKER_HOSTNAMES into a Pod for TPU multi-host initialization
func injectHostnames(clusterName string, hostNames string, envPath string, container corev1.Container, patches *[]patch) {
subdomainPatch, hostNamesPatch := patch{"op": "add"}, patch{"op": "add"}
subdomainPath := "/spec/subdomain"
tpuWorkerHostNames := corev1.EnvVar{
Name: "TPU_WORKER_HOSTNAMES",
Value: hostNames,
}
subdomainPatch["path"] = subdomainPath
subdomainPatch["value"] = generateHeadlessServiceName(clusterName)
// create new EnvVar array if container.Env is empty, and append hostnames if not
if len(container.Env) == 0 {
hostNamesPatch["path"] = envPath
hostNamesPatch["value"] = []corev1.EnvVar{tpuWorkerHostNames}
} else {
hostNamesPatch["path"] = fmt.Sprintf("%s/-", envPath)
hostNamesPatch["value"] = tpuWorkerHostNames
}
*patches = append(*patches, subdomainPatch, hostNamesPatch)
}
// injectReplicaLabel injects replicaIndex label into a Pod for TPU Pod scheduling and Ray multi-host autoscaling
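// For example (names assumed), a Pod assigned to replica 0 of worker group
// "tpu-group" receives the label replicaIndex: tpu-group-0.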
func injectReplicaLabel(clusterName string, namespace string, replicaIndex int, workerGroupName string, patches *[]patch) {
labelPatch := patch{"op": "replace"}
labelPath := "/metadata/labels/replicaIndex"
replicaLabelValue := workerGroupName + "-" + strconv.Itoa(replicaIndex)
klog.V(1).InfoS("injectReplicaLabel", "RayCluster", namespace+"/"+clusterName, "replicaIndex", replicaLabelValue)
labelPatch["path"] = labelPath
labelPatch["value"] = replicaLabelValue
*patches = append(*patches, labelPatch)
}
// injectPodAffinity injects pod affinity and anti-affinity scheduling constraints using replicaIndex label
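// The injected requiredDuringSchedulingIgnoredDuringExecution term requires Pods
// sharing the same replicaIndex label value to land in the same GKE node pool
// (topologyKey "cloud.google.com/gke-nodepool"), keeping each Pod Slice's workers together.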
func injectPodAffinity(pod *corev1.Pod, replicaIndex int, workerGroupName string, patches *[]patch) {
key := "replicaIndex"
value := workerGroupName + "-" + strconv.Itoa(replicaIndex)
topologyKey := "cloud.google.com/gke-nodepool"
clusterName := pod.Labels["ray.io/cluster"]
namespace := pod.Namespace
klog.V(1).InfoS("injectPodAffinity", "RayCluster", namespace+"/"+clusterName, "podAffinity match label", value)
// construct affinity value to inject - schedule pods with the same replicaIndex together
podAffinityPatch := patch{"op": "add"}
affinitySelectorRequirement := metav1.LabelSelectorRequirement{Key: key, Operator: metav1.LabelSelectorOpIn, Values: []string{value}}
affinityMatchExpressions := []metav1.LabelSelectorRequirement{affinitySelectorRequirement}
affinityLabelSelector := metav1.LabelSelector{MatchExpressions: affinityMatchExpressions}
podAffinityTerms := []corev1.PodAffinityTerm{{LabelSelector: &affinityLabelSelector, TopologyKey: topologyKey}}
podAffinity := corev1.PodAffinity{RequiredDuringSchedulingIgnoredDuringExecution: podAffinityTerms}
if pod.Spec.Affinity != nil {
podAffinityPatch["path"] = "/spec/affinity/podAffinity"
podAffinityPatch["value"] = podAffinity
} else {
podAffinityPatch["path"] = "/spec/affinity"
podAffinityPatch["value"] = corev1.Affinity{PodAffinity: &podAffinity}
}
*patches = append(*patches, podAffinityPatch)
}
// checkWorkersMatchTopology returns whether the # of Ray TPU worker pods equals the # of hosts defined in the topology key
func checkWorkersMatchTopology(clusterName string, namespace string, workerGroupSpec ray.WorkerGroupSpec) (bool, error) {
klog.V(1).InfoS("checkWorkersMatchTopology", "RayCluster", namespace+"/"+clusterName, "workerGroup", workerGroupSpec.GroupName)
numHosts := workerGroupSpec.NumOfHosts // 1 TPU VM host -> 1 Ray worker pod
if numHosts == 0 {
return false, errors.New("workerGroupSpec NumOfHosts not set")
}
groupName := workerGroupSpec.GroupName
containers := workerGroupSpec.Template.Spec.Containers
if len(containers) == 0 {
return false, errors.New("Container path not specified")
}
if containerRequestingTPUs(containers...) {
topology := workerGroupSpec.Template.Spec.NodeSelector["cloud.google.com/gke-tpu-topology"]
klog.V(1).InfoS("checkWorkersMatchTopology", "RayCluster", namespace+"/"+clusterName, "topology", topology, "NumOfHosts", numHosts)
if topology == "" {
err := errors.New("TPU topology not specified")
klog.ErrorS(err, "checkWorkersMatchTopology", "RayCluster", namespace+"/"+clusterName, "gke-tpu-topology", topology)
return false, err
}
chipsPerHost := getNumTPUChipsRequested(containers...)
if chipsPerHost == 0 {
err := errors.New("Container does not set TPU limits")
klog.ErrorS(err, "checkWorkersMatchTopology", "RayCluster", namespace+"/"+clusterName, "gke-tpu-topology", topology)
return false, err
}
expectedHosts, err := getNumTPUHostsFromTopology(clusterName, groupName, namespace, topology, chipsPerHost)
if err != nil {
return false, err
}
if expectedHosts != numHosts {
return false, nil
}
}
return true, nil
}
// validateRayCluster returns an Admission Response after checking Ray worker groups match TPU scheduling constraints
func validateRayCluster(admissionReview *admissionv1.AdmissionReview) (*admissionv1.AdmissionResponse, error) {
raycluster, err := extractRayCluster(admissionReview)
if err != nil {
return nil, err
}
admit := true
status := "Success"
message := ""
clusterName := raycluster.Name
namespace := raycluster.Namespace
klog.V(1).InfoS("validateRayCluster", "RayCluster", namespace+"/"+clusterName)
workerGroupSpecs := raycluster.Spec.WorkerGroupSpecs
for _, workerGroupSpec := range workerGroupSpecs {
workerGroupContainers := workerGroupSpec.Template.Spec.Containers
if len(workerGroupContainers) != 0 && !containerRequestingTPUs(workerGroupContainers...) {
// pass through if no TPUs are requested
continue
}
// validate NumOfHosts for worker group matches topology nodeSelector
workersMatchTopology, err := checkWorkersMatchTopology(clusterName, namespace, workerGroupSpec)
if err != nil {
return nil, err
}
if !workersMatchTopology {
admit = false
status = "Failure"
message = "Number of workers in worker group not equal to specified topology"
break
}
}
// Create AdmissionResponse
admissionResponse := &admissionv1.AdmissionResponse{
UID: admissionReview.Request.UID,
Allowed: admit,
Result: &metav1.Status{
Status: status,
Message: message,
},
}
return admissionResponse, nil
}
// getEnvironmentVariable returns value associated with a given Container environment variable
func getEnvironmentVariable(varName string, container corev1.Container) string {
for _, envVar := range container.Env {
if envVar.Name == varName {
return envVar.Value
}
}
return ""
}
// getReplicaIndex returns the next lowest-index Pod Slice (worker group replica) to assign a Pod to in the RayCluster
// there are three possible cases here:
// 1. sliceToWorkerIDs is empty, this is the first Pod the webhook intercepts
// - assign this Pod to replica 0
// 2. a Pod Slice exists in sliceToWorkerIDs with # created workers < NumOfHosts
// - assign this Pod to the lowest-index replica with # created workers < NumOfHosts,
// grouping Pods of the same slice into the same replica
// 3. sliceToWorkerIDs isn't empty, but all slices have # workers == NumOfHosts
// - this occurs when the Pod we intercept is the first Pod of a different slice in the cluster
// - we keep track of how many replicas of the same worker group have been added to sliceToWorkerIDs
// so far, and assign this Pod to the next integer replicaIndex
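// Illustrative example (values assumed): with NumOfHosts = 2 and existing slices
// {replica 0: 2 workers, replica 1: 1 worker}, replica 0 is full, so the next
// Pod is assigned replicaIndex 1.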
func getReplicaIndex(sliceToWorkerIDs map[slice][]int, clusterName string, groupName string, namespace string) int {
// first pod created in cluster
if len(sliceToWorkerIDs) == 0 {
return 0
}
nextLowestId := math.MaxInt32
numReplicas := 0 // tracks # of replicas in worker group created so far
for slice, workerList := range sliceToWorkerIDs {
if slice.clusterName == clusterName && slice.groupName == groupName && slice.namespace == namespace {
numReplicas++
createdPods := len(workerList)
if createdPods < int(slice.numOfHosts) {
if slice.replicaIndex < nextLowestId {
nextLowestId = slice.replicaIndex
}
}
}
}
// first pod of new slice in cluster
if nextLowestId == math.MaxInt32 {
nextLowestId = numReplicas
}
klog.V(1).InfoS("getReplicaIndex", "RayCluster", namespace+"/"+clusterName, "Worker Group", groupName, "Replica Index", nextLowestId)
return nextLowestId
}
// getNextWorkerID returns the next lowest TPU_WORKER_ID in the Pod Slice
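// Illustrative example (values assumed): if the slice already contains
// TPU_WORKER_IDs [0, 1, 3], iteration stops at the gap and 2 is returned.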
func getNextWorkerID(sliceToWorkerIDs map[slice][]int, podSlice slice, namespace string, replicaIndex int) (int, error) {
tpuWorkerID := 0 // defaults to 0 (first Pod in slice)
if len(sliceToWorkerIDs[podSlice]) == 0 {
return tpuWorkerID, nil
}
sort.Ints(sliceToWorkerIDs[podSlice])
// iterate through existing workers and get the next lowest, unused ID
lastID := 0
for index, workerID := range sliceToWorkerIDs[podSlice] {
// check for incorrect assignment of IDs
if index == 0 {
lastID = workerID
} else if workerID == lastID {
return 0, errors.New("Identical TPU_WORKER_ID assigned to multiple TPU workers in slice")
}
// get the next lowest, valid TPU_WORKER_ID
if workerID != tpuWorkerID {
break
}
lastID = workerID
tpuWorkerID++
}
klog.V(1).InfoS("getNextWorkerID", "RayCluster", namespace+"/"+podSlice.clusterName, "Worker Group", podSlice.groupName, "replicaIndex", replicaIndex, "TPU_WORKER_ID", tpuWorkerID)
return tpuWorkerID, nil
}
// getSliceToWorkerIDs returns a mapping representing the current RayCluster state of TPU pods using a PodLister
func (t *TPUWebhookServer) getSliceToWorkerIDs(clusterName string, groupName string, namespace string, numOfHosts int32) (map[slice][]int, error) {
sliceToWorkerIDs := make(map[slice][]int)
// we only care about workers in the same RayCluster and worker group when assigning IDs
podsInGroup, err := t.podLister.Pods(namespace).List(labels.SelectorFromSet(labels.Set{"ray.io/cluster": clusterName, "ray.io/group": groupName}))
if err != nil {
return nil, err
}
if podsInGroup == nil {
// return an empty mapping if no Pods with 'ray.io/group' label found
return sliceToWorkerIDs, nil
}
klog.V(1).InfoS("getSliceToWorkerIDs", "RayCluster", namespace+"/"+clusterName, "# Pods in Group", len(podsInGroup))
for _, existingPod := range podsInGroup {
if existingPod.DeletionTimestamp != nil {
continue
}
existingNamespace := existingPod.Namespace
// check that Pods are in the same namespace
if namespace != existingNamespace {
continue
}
if !containerRequestingTPUs(existingPod.Spec.Containers...) {
// Pod does not request TPUs, 'ray.io/group' is not a TPU worker group
return sliceToWorkerIDs, nil
}
replicaIndexLabel := existingPod.Labels["replicaIndex"]
if replicaIndexLabel == "" {
// Pod has not been intercepted by the KubeRay TPU webhook yet
continue
}
replicaIndexLabelValues := strings.Split(replicaIndexLabel, "-")
existingReplicaIndex, _ := strconv.Atoi(replicaIndexLabelValues[len(replicaIndexLabelValues)-1])
existingWorkerID := -1
for _, container := range existingPod.Spec.Containers {
if !containerRequestingTPUs(container) {
continue
}
tpuWorkerIDEnvVar := getEnvironmentVariable("TPU_WORKER_ID", container)
tempVar, err := strconv.Atoi(tpuWorkerIDEnvVar)
if err != nil {
klog.ErrorS(err, "getSliceToWorkerIDs", "RayCluster", namespace+"/"+clusterName, "TPU_WORKER_ID", tpuWorkerIDEnvVar)
continue
}
existingWorkerID = tempVar
break
}
if existingPod.Status.Phase == corev1.PodRunning && existingWorkerID == -1 {
return nil, errors.New("existing TPU worker missing TPU_WORKER_ID")
}
if existingWorkerID != -1 {
// Pod has been intercepted by the webhook
podSlice := slice{clusterName, groupName, namespace, existingReplicaIndex, numOfHosts}
sliceToWorkerIDs[podSlice] = append(sliceToWorkerIDs[podSlice], existingWorkerID)
klog.V(1).InfoS("getSliceToWorkerIDs", "RayCluster", namespace+"/"+clusterName, "ReplicaIndex", existingReplicaIndex, "TPU_WORKER_ID", existingWorkerID)
}
}
return sliceToWorkerIDs, nil
}
// extractPod returns a Pod unmarshalled from an Admission Request
func extractPod(admissionReview *admissionv1.AdmissionReview) (*corev1.Pod, error) {
if admissionReview.Request.Kind.Kind != "Pod" {
return nil, fmt.Errorf("Expected Pod but got %s", admissionReview.Request.Kind.Kind)
}
pod := corev1.Pod{}
if admissionReview.Request.Operation == admissionv1.Create {
if err := json.Unmarshal(admissionReview.Request.Object.Raw, &pod); err != nil {
return nil, err
}
}
return &pod, nil
}
// waitTimeout waits on the sync.WaitGroup until Wait returns or the timeout elapses, returning true if Wait returned
func waitTimeout(wait *sync.WaitGroup, timeout time.Duration) bool {
waitChan := make(chan struct{})
go func() {
defer close(waitChan)
wait.Wait()
}()
select {
case <-waitChan:
return true // Wait() returned
case <-time.After(timeout):
return false // Request timed out
}
}
// mutatePod returns an Admission Response after injecting TPU related fields to a given Pod
func (t *TPUWebhookServer) mutatePod(admissionReview *admissionv1.AdmissionReview) (*admissionv1.AdmissionResponse, error) {
pod, err := extractPod(admissionReview)
if err != nil {
return nil, err
}
var patches []patch
admissionResponse := &admissionv1.AdmissionResponse{
UID: admissionReview.Request.UID,
Allowed: true,
}
containers := pod.Spec.Containers
if len(containers) == 0 {
return nil, errors.New("Container path not specified")
}
if !containerRequestingTPUs(containers...) {
// if no TPUs are requested, simply admit the Pod
return admissionResponse, nil
}
// ray operator only sets GenerateName field - doesn't include random suffix until after admission request
// use mapping of {cluster name, group name, replicaIndex} -> workers to extract next TPU_WORKER_ID
clusterName := pod.Labels["ray.io/cluster"]
if clusterName == "" {
return nil, errors.New("Ray Pod created by KubeRay missing RayCluster label")
}
groupName := pod.Labels["ray.io/group"]
if groupName == "" {
return nil, errors.New("Ray Pod created by KubeRay missing Group label")
}
namespace := pod.Namespace
topology := pod.Spec.NodeSelector["cloud.google.com/gke-tpu-topology"]
if topology == "" {
return nil, errors.New("Ray Pod created by KubeRay missing TPU topology nodeSelector")
}
// assign worker to the next unique ID in the Pod Slice and update map
chipsPerHost := getNumTPUChipsRequested(containers...)
numOfHosts, _ := getNumTPUHostsFromTopology(clusterName, groupName, namespace, topology, chipsPerHost) // topology has already been validated as non-empty above
// Wait for PodInformer cache to update from previous requests or timeout
if waitTimeout(&t.wg, time.Second*1) {
klog.V(1).Info("mutatePod: PodInformer AddFunc called for prior admission request")
} else {
klog.V(1).Info("mutatePod: Timed out waiting for PodInformer AddFunc")
}
// add to the WaitGroup to represent the admitted Pod pending addition to the informer cache
defer t.wg.Add(1)
t.waiting++
// query k8s client to populate sliceToWorkerIDs to then calculate the next TPU_WORKER_ID and replicaIndex
sliceToWorkerIDs, err := t.getSliceToWorkerIDs(clusterName, groupName, namespace, numOfHosts)
if err != nil {
return nil, err
}
replicaIndex := getReplicaIndex(sliceToWorkerIDs, clusterName, groupName, namespace)
podSlice := slice{clusterName, groupName, namespace, replicaIndex, numOfHosts}
tpuWorkerID, err := getNextWorkerID(sliceToWorkerIDs, podSlice, namespace, replicaIndex) // defaults to 0 for single-host
if err != nil {
return nil, err
}
// set the unique identifier for the last admitted Pod by this TPUWebhookServer;
// the format must match the ID constructed from Pod metadata in isLastAdmittedPod
t.lastAdmitted = fmt.Sprintf("%s-%s-%s-%d-%d", namespace, clusterName, groupName, replicaIndex, tpuWorkerID)
// inject replica index label
injectReplicaLabel(clusterName, namespace, replicaIndex, groupName, &patches)
if numOfHosts > 1 {
// inject hostname into pod spec for DNS records
hostname := fmt.Sprintf("%s-%d-%d", groupName, replicaIndex, tpuWorkerID)
klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "hostname", hostname)
hostnamePatch := patch{"op": "add"}
hostnamePatch["path"] = "/spec/hostname"
hostnamePatch["value"] = hostname
patches = append(patches, hostnamePatch)
// inject pod affinity/anti-affinity for scheduling
injectPodAffinity(pod, replicaIndex, groupName, &patches)
}
// inject all environment variables into the container requesting TPUs
for i, container := range containers {
if containerRequestingTPUs(container) {
path := fmt.Sprintf("/spec/containers/%d/env", i)
if numOfHosts > 1 {
// inject TPU_WORKER_HOSTNAMES
hostnames, err := genDNSHostnames(numOfHosts, groupName, clusterName, namespace, replicaIndex)
if err != nil {
return nil, err
}
klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "TPU_WORKER_HOSTNAMES", hostnames)
klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "subdomain", generateHeadlessServiceName(clusterName))
injectHostnames(clusterName, hostnames, path, container, &patches)
}
// inject TPU_WORKER_ID
if getEnvironmentVariable("TPU_WORKER_ID", container) == "" {
klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "TPU_WORKER_ID", tpuWorkerID, "Replica Index", replicaIndex)
workerID := corev1.EnvVar{
Name: "TPU_WORKER_ID",
Value: fmt.Sprint(tpuWorkerID),
}
idPatch := patch{"op": "add"}
// create new EnvVar array if container.Env is empty, and append new EnvVars if not
if len(container.Env) == 0 {
idPatch["path"] = path
idPatch["value"] = []corev1.EnvVar{workerID}
} else {
idPatch["path"] = fmt.Sprintf("%s/-", path)
idPatch["value"] = workerID
}
patches = append(patches, idPatch)
}
// inject TPU_NAME
if getEnvironmentVariable("TPU_NAME", container) == "" {
tpuNameValue := fmt.Sprintf("%s-%d", groupName, replicaIndex)
klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "TPU_NAME", tpuNameValue, "Replica Index", replicaIndex)
tpuName := corev1.EnvVar{
Name: "TPU_NAME",
Value: tpuNameValue,
}
namePatch := patch{"op": "add"}
// create new EnvVar array if container.Env is empty, and append new EnvVars if not
if len(container.Env) == 0 {
namePatch["path"] = path
namePatch["value"] = []corev1.EnvVar{tpuName}
} else {
namePatch["path"] = fmt.Sprintf("%s/-", path)
namePatch["value"] = tpuName
}
patches = append(patches, namePatch)
}
}
}
patchBytes, err := json.Marshal(patches)
if err != nil {
return nil, err
}
admissionResponse.Patch = patchBytes
patchType := admissionv1.PatchTypeJSONPatch
admissionResponse.PatchType = &patchType
return admissionResponse, nil
}
func writeCertfile(filename string, encodedData string) error {
data, err := base64.StdEncoding.DecodeString(encodedData)
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(filename), 0755); err != nil {
return err
}
return os.WriteFile(filename, data, 0644)
}
func init() {
flag.StringVar(&BindAddr, "bind-address", ":443", "Address to bind HTTPS service to")
flag.StringVar(&CACert, "ca-cert", "", "base64-encoded root certificate for TLS")
flag.StringVar(&ServerCert, "server-cert", "", "base64-encoded server certificate for TLS")
flag.StringVar(&ServerKey, "server-key", "", "base64-encoded server key for TLS")
flag.StringVar(&KubeConfigPath, "kube-config-path", "", "Kubernetes config path for k8s client")
// set klog verbosity level
klog.InitFlags(nil)
}
// isLastAdmittedPod returns true if the Pod matches the last Pod admitted by the webhook server
func (t *TPUWebhookServer) isLastAdmittedPod(pod *corev1.Pod) (bool, error) {
if !containerRequestingTPUs(pod.Spec.Containers...) {
// Pod does not use TPUs
return false, nil
}
replicaIndex := pod.Labels["replicaIndex"]
if replicaIndex == "" {
// Pod was not mutated by the webhook
return false, nil
}
clusterName := pod.Labels["ray.io/cluster"]
if clusterName == "" {
return false, errors.New("Ray Pod created by KubeRay missing RayCluster label")
}
namespace := pod.Namespace
for _, container := range pod.Spec.Containers {
if !containerRequestingTPUs(container) {
// Skip to the next container
continue
}
tpuWorkerID := getEnvironmentVariable("TPU_WORKER_ID", container)
if tpuWorkerID == "" {
// TPU pod was not intercepted by the webhook
return false, nil
}
uniquePodID := fmt.Sprintf("%s-%s-%s-%s", namespace, clusterName, replicaIndex, tpuWorkerID)
if uniquePodID == t.lastAdmitted {
// Pod matches the last TPU worker Pod intercepted by the webhook server
return true, nil
}
}
return false, nil
}
// addPod allows next goroutine to start once the webhook PodInformer cache updates
func (t *TPUWebhookServer) addPod(obj interface{}) {
pod := obj.(*corev1.Pod)
klog.V(1).InfoS("addPod", "Pod", pod.Namespace+"/"+pod.Name, "Time", time.Now())
if t.lastAdmitted == "" {
// There is not a pending TPU worker Pod to the informer cache, unblock if waiting and return
for t.waiting > 0 {
t.wg.Done()
t.waiting -= 1
}
return
}
if t.waiting == 0 {
// Webhook is not waiting, no-op
return
}
// Check if Pod in cache is the last admitted TPU Pod
isLastAdmitted, err := t.isLastAdmittedPod(pod)
if err != nil {
klog.Errorf("Invalid addPod: %s", err)
return
}
if isLastAdmitted {
// Informer cache has been updated, unblock the next Mutate call
t.wg.Done()
t.waiting--
}
}
func main() {
flag.Parse()
// use in-cluster config if kubeConfig path is not passed as a flag
var config *rest.Config
var err error
if KubeConfigPath == "" {
config, err = rest.InClusterConfig()
} else {
config, err = clientcmd.BuildConfigFromFlags("", KubeConfigPath)
}
if err != nil {
panic(err)
}
client := kubernetes.NewForConfigOrDie(config)
// instantiate PodInformer for Ray worker pods in the GKE cluster
tweakListOptionsFunc := func(options *metav1.ListOptions) {
options.LabelSelector = "ray.io/node-type=worker,app.kubernetes.io/created-by=kuberay-operator"
}
factory := informers.NewFilteredSharedInformerFactory(client, 1*time.Minute, metav1.NamespaceAll, tweakListOptionsFunc)
podInformer := factory.Core().V1().Pods().Informer()
// start the PodInformer and wait for cache sync
stopCh := make(chan struct{})
factory.Start(stopCh)
if !cache.WaitForCacheSync(stopCh, podInformer.HasSynced) {
klog.Fatal("Timed out waiting for PodInformer to sync")
}
podLister := factory.Core().V1().Pods().Lister()
if podLister == nil {
klog.Fatal("Failed to initialize Pod Lister")
}
// close the PodInformer on exit
defer close(stopCh)
tpuWebhookServer := NewTPUWebhookServer(podLister)
// Add custom event handler for Pod creation
podInformer.AddEventHandler(
cache.ResourceEventHandlerFuncs{
AddFunc: tpuWebhookServer.addPod,
},
)
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "kuberay-tpu-webhook")
})
mux.HandleFunc("/mutate", tpuWebhookServer.Mutate)
mux.HandleFunc("/validate", tpuWebhookServer.Validate)
srv := &http.Server{
Addr: BindAddr,
Handler: mux,
}
if ServerCert != "" && ServerKey != "" {
if err := writeCertfile(certPath, ServerCert); err != nil {
klog.Fatalf("write server cert: %v", err)
}
if err := writeCertfile(keyPath, ServerKey); err != nil {
klog.Fatalf("write server key: %v", err)
}
}
if err := srv.ListenAndServeTLS(certPath, keyPath); err != nil {
if errors.Is(err, http.ErrServerClosed) {
klog.V(0).Info("Server closed")
return
}
klog.Errorf("Failed to start server: %v", err)
}
}