tpu-provisioner/internal/cloud/gke.go
package cloud
import (
"context"
"encoding/json"
"errors"
"fmt"
"hash/fnv"
"math"
"net/http"
"strconv"
"strings"
"sync"
"time"
containerv1beta1 "google.golang.org/api/container/v1beta1"
"google.golang.org/api/googleapi"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/rand"
"k8s.io/client-go/tools/record"
"sigs.k8s.io/controller-runtime/pkg/client"
logf "sigs.k8s.io/controller-runtime/pkg/log"
jobset "sigs.k8s.io/jobset/api/jobset/v1alpha2"
)
var log = logf.Log.WithName("provider")
const (
// GKE labels
GKETPUNodeSelector = "cloud.google.com/gke-tpu-topology"
GKEAcceleratorNodeSelector = "cloud.google.com/gke-tpu-accelerator"
GKENodePoolNameLabel = "cloud.google.com/gke-nodepool"
// ICIResiliencyLabel is used to disable ICI resiliency. By default, if not specified, the TPU
// slice is created in ICI-resilient mode. To disable ICI resiliency, the workload needs to use
// the node selector or affinity cloud.google.com/gke-tpu-ici-resiliency=false.
ICIResiliencyLabel = "cloud.google.com/gke-tpu-ici-resiliency"
// LocationHintLabel is used to pass in the desired borg cell that the node pool MIG should be
// provisioned in.
LocationHintLabel = "cloud.google.com/gke-location-hint"
// Supported accelerator types
V4PodSliceAccelerator = "tpu-v4-podslice"
V5ePodSliceAccelerator = "tpu-v5-lite-podslice"
V5pPodSliceAccelerator = "tpu-v5p-slice"
V6eSliceAccelerator = "tpu-v6e-slice"
// Resource type labels
GoogleTPUResource = "google.com/tpu"
gcpLabelPrefix = "cloud.google.com/"
googleLabelPrefix = "google.com/"
// The default max pods per node is 110, but a lower value is necessary for large clusters,
// otherwise we'll run out of IP space and provisioning will fail.
// 15 pods per node works for both small and large cluster sizes, given the TPU constraint of
// 1 pod per TPU node plus kube-system pods.
// TODO: move this to an environment variable.
maxPodsPerNode = 15
// Constants for node pool naming conventions.
maxJobSetPrefixLength = 34
jobKeySuffixLength = 5
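// Together with the "-" separator, these yield node pool names of at most
// 34 + 1 + 5 = 40 characters, matching the 40-character limit noted on
// podToNodePoolName below.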
)
var _ Provider = &GKE{}
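// GKE provisions and deletes TPU node pools in a GKE cluster on behalf of
// pending TPU Pods. The sync.Map fields below deduplicate concurrent create
// and delete requests triggered by overlapping reconcile cycles.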
type GKE struct {
NodePools NodePoolService
ClusterContext GKEContext
Recorder record.EventRecorder
inProgressDeletesNPName sync.Map
inProgressCreatesNPName sync.Map
inProgressCreatesJobKey sync.Map
}
func (g *GKE) NodePoolLabelKey() string { return GKENodePoolNameLabel }
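// EnsureNodePoolForPod makes sure a node pool matching the given Pod's
// requirements exists: it creates one if missing, recreates one whose
// configuration hash no longer matches, and deduplicates concurrent
// creation attempts for the same node pool name or job-key.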
func (g *GKE) EnsureNodePoolForPod(p *corev1.Pod, why string) error {
np, err := g.nodePoolForPod(p)
if err != nil {
return fmt.Errorf("determining node pool for pod: %w", err)
}
existingNPState, err := g.checkExistingNodePool(context.TODO(), np)
if err != nil {
return fmt.Errorf("checking if node pool exists: %w", err)
}
log.Info("Checked existing node pool state",
"nodePoolName", np.Name, "existingNodePoolState", existingNPState.String(),
)
switch existingNPState {
case nodePoolStateNotExists:
// Create the node pool.
case nodePoolStateExistsAndMatches:
return nil
case nodePoolStateExistsAndNotMatches:
// Recreate the node pool.
const why = "existing node pool did not match pod and needed to be recreated"
if err := g.DeleteNodePool(np.Name, p, why); err != nil {
return fmt.Errorf("failed to delete node pool: %s: %w", why, err)
}
// Allow another reconcile cycle to create the new node pool.
return ErrNodePoolDeletedToBeRecreated
case nodePoolStateExistsAndStopping:
// Node pool is stopping, so we need to wait for it to be deleted before creating a new one.
return ErrNodePoolStopping
default:
return fmt.Errorf("unexpected node pool state: %v", existingNPState)
}
req := &containerv1beta1.CreateNodePoolRequest{
NodePool: np,
Parent: g.ClusterContext.ClusterName(),
}
// Due to concurrent reconciles, multiple creates for the same
// Node Pool can occur at the same time. The result is an error:
// "do: googleapi: Error 400: Cluster is running incompatible operation ..."
// To avoid a flood of failed requests, we deduplicate here.
if _, inProgress := g.inProgressCreatesNPName.Load(np.Name); inProgress {
return fmt.Errorf("creation ongoing for node pool name: %v: %w", np.Name, ErrDuplicateRequest)
}
g.inProgressCreatesNPName.Store(np.Name, struct{}{})
defer g.inProgressCreatesNPName.Delete(np.Name)
// A restarting JobSet will trigger a new Node Pool creation.
// The current creation attempt might overlap with the previous one,
// which could still be ongoing, so we need to deduplicate.
// This works because job-key remains constant across restarts.
// NOTE: These checks don't work across controller restarts.
if jobKey := p.Labels[jobset.JobKey]; jobKey != "" {
if _, inProgress := g.inProgressCreatesJobKey.Load(jobKey); inProgress {
return fmt.Errorf("creation ongoing for job-key: %v: %w", jobKey, ErrDuplicateRequest)
}
g.inProgressCreatesJobKey.Store(jobKey, struct{}{})
defer g.inProgressCreatesJobKey.Delete(jobKey)
}
// Get the JobSet this pod is part of from the pod labels and log it.
jobSetName := p.Labels[jobset.JobSetNameKey]
g.Recorder.Eventf(p, corev1.EventTypeNormal, EventNodePoolCreationStarted, "Starting creation of Node Pool %s (size = %v) for JobSet %s because %s", np.Name, np.InitialNodeCount, jobSetName, why)
log.Info(fmt.Sprintf("creating node pool %s for jobset %s", np.Name, jobSetName))
if err := g.NodePools.Create(context.TODO(), req, OpCallbacks{
ReqFailure: func(err error) {
g.Recorder.Eventf(p, corev1.EventTypeWarning, EventNodePoolCreationFailed, "Request to create Node Pool %s failed: %v.", np.Name, err)
},
OpFailure: func(err error) {
g.Recorder.Eventf(p, corev1.EventTypeWarning, EventNodePoolCreationFailed, "Operation to create Node Pool %s failed: %v.", np.Name, err)
},
Success: func() {
g.Recorder.Eventf(p, corev1.EventTypeNormal, EventNodePoolCreationSucceeded, "Successfully created Node Pool %s.", np.Name)
},
}); err != nil {
return err
}
return nil
}
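// ListNodePools returns a reference for each node pool in the cluster,
// including the JobSet it was created for (falling back to the legacy
// provisioner node pool ID label and the "default" namespace).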
func (g *GKE) ListNodePools() ([]NodePoolRef, error) {
var refs []NodePoolRef
resp, err := g.NodePools.List(context.TODO())
if err != nil {
return nil, fmt.Errorf("listing node pools: %w", err)
}
for _, np := range resp.NodePools {
jsName, exists := np.Config.Labels[LabelJobSetName]
if !exists {
jsName = np.Config.Labels[LabelProvisionerNodepoolID]
}
jsNamespace, exists := np.Config.Labels[LabelJobSetNamespace]
if !exists {
jsNamespace = "default"
}
refs = append(refs, NodePoolRef{
Name: np.Name,
Error: np.Status == "ERROR",
Message: np.StatusMessage,
CreatedForJobSet: types.NamespacedName{
Name: jsName,
Namespace: jsNamespace,
},
})
}
return refs, nil
}
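// DeleteNodePoolForNode resolves the node pool name from the node's
// cloud.google.com/gke-nodepool label and deletes that node pool.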
func (g *GKE) DeleteNodePoolForNode(node *corev1.Node, why string) error {
name, ok := node.GetLabels()[g.NodePoolLabelKey()]
if !ok {
return fmt.Errorf("node %q does not have node pool label", node.Name)
}
return g.DeleteNodePool(name, node, why)
}
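// DeleteNodePool deletes the named node pool, deduplicating concurrent
// deletion attempts and recording events on the given object for the
// request, operation failure, and success cases.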
func (g *GKE) DeleteNodePool(name string, eventObj client.Object, why string) error {
// Due to concurrent reconciles, multiple deletes for the same
// Node Pool can occur at the same time, resulting in failed API calls.
// To avoid a flood of failed requests, we deduplicate here.
if _, inProgress := g.inProgressDeletesNPName.Load(name); inProgress {
return ErrDuplicateRequest
}
g.inProgressDeletesNPName.Store(name, struct{}{})
defer g.inProgressDeletesNPName.Delete(name)
g.Recorder.Eventf(eventObj, corev1.EventTypeNormal, EventNodePoolDeletionStarted, "Starting deletion of Node Pool %s because %s", name, why)
if err := g.NodePools.Delete(context.TODO(), name, OpCallbacks{
NotFound: func() {
g.Recorder.Eventf(eventObj, corev1.EventTypeNormal, EventNodePoolNotFound, "Node Pool %s not found - ignoring deletion attempt.", name)
},
ReqFailure: func(err error) {
g.Recorder.Eventf(eventObj, corev1.EventTypeWarning, EventNodePoolDeletionFailed, "Request to delete Node Pool %s failed: %v.", name, err)
},
OpFailure: func(err error) {
g.Recorder.Eventf(eventObj, corev1.EventTypeWarning, EventNodePoolDeletionFailed, "Operation to delete Node Pool %s failed: %v.", name, err)
},
Success: func() {
g.Recorder.Eventf(eventObj, corev1.EventTypeNormal, EventNodePoolDeletionSucceeded, "Successfully deleted Node Pool %s.", name)
},
}); err != nil {
return err
}
return nil
}
var ErrNodePoolStopping = errors.New("node pool stopping")
var ErrNodePoolDeletedToBeRecreated = errors.New("node pool deleted to be recreated")
type nodePoolState int
func (s nodePoolState) String() string {
switch s {
case nodePoolStateNotExists:
return "NotExists"
case nodePoolStateExistsAndMatches:
return "ExistsAndMatches"
case nodePoolStateExistsAndNotMatches:
return "ExistsAndNotMatches"
case nodePoolStateExistsAndStopping:
return "ExistsAndStopping"
default:
return "Unknown"
}
}
const (
nodePoolStateUnknown nodePoolState = iota
nodePoolStateNotExists
nodePoolStateExistsAndMatches
nodePoolStateExistsAndNotMatches
nodePoolStateExistsAndStopping
)
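// checkExistingNodePool looks up the desired node pool by name and reports
// whether it does not exist, matches the desired configuration hash,
// differs from it, or is currently being deleted (status STOPPING).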
func (g *GKE) checkExistingNodePool(ctx context.Context, desired *containerv1beta1.NodePool) (nodePoolState, error) {
existing, err := g.NodePools.Get(ctx, desired.Name)
if err == nil {
if existing.Status == "STOPPING" {
return nodePoolStateExistsAndStopping, nil
}
match, err := nodePoolHashesMatch(desired, existing)
if err != nil {
return nodePoolStateUnknown, fmt.Errorf("comparing node pools: %w", err)
}
if match {
return nodePoolStateExistsAndMatches, nil
}
return nodePoolStateExistsAndNotMatches, nil
}
if gerr, ok := err.(*googleapi.Error); ok && gerr.Code == http.StatusNotFound {
return nodePoolStateNotExists, nil
}
return nodePoolStateUnknown, err
}
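// nodePoolHashesMatch compares the provisioner-managed hash label of the
// desired node pool with that of the existing one. A missing hash on the
// existing pool is treated as a match so that pools created by older
// provisioner versions are not needlessly recreated.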
func nodePoolHashesMatch(desired, existing *containerv1beta1.NodePool) (bool, error) {
desiredHash, ok := desired.Config.Labels[LabelNodePoolHash]
if !ok {
return false, fmt.Errorf("missing hash in desired node pool")
}
if existing.Config != nil && existing.Config.Labels != nil {
existingHash, ok := existing.Config.Labels[LabelNodePoolHash]
if !ok {
// Avoid recreating node pool if hash is missing.
// Node pool was likely provisioned by a legacy version of the provisioner.
return true, nil
}
return existingHash == desiredHash, nil
}
return true, nil
}
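// nodePoolForPod builds the desired node pool spec for a TPU Pod, deriving
// the machine type, node count, labels, taints, reservation affinity, spot
// settings, and network configuration from the Pod's labels, annotations,
// and node selectors, and stamping the result with a selective hash used
// for change detection.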
func (g *GKE) nodePoolForPod(p *corev1.Pod) (*containerv1beta1.NodePool, error) {
ref := metav1.GetControllerOf(p)
if ref == nil {
// TODO: Allow for standalone Pods?
return nil, errors.New("no owner reference")
}
jobSetName := p.Labels[jobset.JobSetNameKey]
if jobSetName == "" {
// This should never be reached due to the event filters in the reconciler, but is checked just in case.
return nil, fmt.Errorf("pod %s is not part of a jobset, not constructing node pool config for it", p.Name)
}
labels := map[string]string{
// Used to keep track of what Node Pools this provisioner is responsible for.
LabelNodepoolManager: LabelNodepoolManagerTPUPodinator,
// Leave some bread crumbs:
LabelParentKind: strings.ToLower(ref.Kind),
LabelParentName: strings.ToLower(ref.Name),
// Assuming a Namespaced parent here...
LabelParentNamespace: strings.ToLower(p.Namespace),
LabelJobSetName: jobSetName,
LabelJobSetNamespace: p.Namespace,
}
// Copy configured labels from the Pod to the Node.
for _, key := range g.ClusterContext.PodToNodeLabels {
if val, ok := p.Labels[key]; ok {
labels[key] = val
}
}
// Copy labels specified by annotation to the Node.
for _, key := range strings.Split(getAnnotation(p, AnnotationCopyLabels), ",") {
key = strings.TrimSpace(key)
if key == "" {
continue
}
if val, ok := p.Labels[key]; ok {
labels[key] = val
}
}
for labelKey, labelValue := range p.Spec.NodeSelector {
switch labelKey {
case ICIResiliencyLabel:
labels[labelKey] = labelValue
case LocationHintLabel:
labels[labelKey] = labelValue
default:
// Don't copy GCP/Google labels onto the node.
if !strings.HasPrefix(labelKey, gcpLabelPrefix) && !strings.HasPrefix(labelKey, googleLabelPrefix) {
labels[labelKey] = labelValue
}
}
}
// Pod should already be filtered for this Node Selector at this point.
tpuTopo, ok := p.Spec.NodeSelector[GKETPUNodeSelector]
if !ok {
return nil, fmt.Errorf("missing node selector key: %v", GKETPUNodeSelector)
}
accel, ok := p.Spec.NodeSelector[GKEAcceleratorNodeSelector]
if !ok {
return nil, fmt.Errorf("missing node selector key: %v", GKEAcceleratorNodeSelector)
}
tpuRequest, err := sumTPURequests(p)
if err != nil {
return nil, fmt.Errorf("summing TPU requests: %w", err)
}
nodeCount, err := tpuTopologyToNodeCount(accel, tpuTopo)
if err != nil {
return nil, fmt.Errorf("determining node count: %w", err)
}
machineType, err := tpuMachineType(accel, tpuRequest)
if err != nil {
return nil, fmt.Errorf("determining machine type: %w", err)
}
var reservation *containerv1beta1.ReservationAffinity
var taints []*containerv1beta1.NodeTaint
var spot bool
if !g.ClusterContext.ForceOnDemand {
if resName, ok := p.Spec.NodeSelector["cloud.google.com/reservation-name"]; ok {
var resVal string
resProj, ok := p.Spec.NodeSelector["cloud.google.com/reservation-project"]
if ok {
resVal = fmt.Sprintf("projects/%s/reservations/%s", resProj, resName)
} else {
resVal = resName
}
reservation = &containerv1beta1.ReservationAffinity{
ConsumeReservationType: "SPECIFIC_RESERVATION",
Key: "compute.googleapis.com/reservation-name",
Values: []string{
resVal,
},
}
}
spot = p.Spec.NodeSelector["cloud.google.com/gke-spot"] == "true"
if spot {
// Add the taint that NAP would add.
// https://cloud.google.com/kubernetes-engine/docs/concepts/spot-vms#spotvms-nap
taints = append(taints, &containerv1beta1.NodeTaint{
Key: "cloud.google.com/gke-spot",
Value: "true",
Effect: "NO_SCHEDULE",
})
}
}
var secondaryDisks []*containerv1beta1.SecondaryBootDisk
if g.ClusterContext.NodeSecondaryDisk != "" {
secondaryDisks = []*containerv1beta1.SecondaryBootDisk{
{
// Example: "projects/my-gcp-project/global/images/my-disk-image"
DiskImage: g.ClusterContext.NodeSecondaryDisk,
Mode: "CONTAINER_IMAGE_CACHE",
},
}
}
var networkConfig *containerv1beta1.NodeNetworkConfig
var additionalNodeNetworks []*containerv1beta1.AdditionalNodeNetworkConfig
// additional-node-networks: "vpc1:subnet1, vpc2:subnet2"
additionalNodeNetworksCSV := g.ClusterContext.NodeAdditionalNetworks
if getAnnotation(p, AnnotationAdditionalNodeNetworks) != "" {
additionalNodeNetworksCSV = getAnnotation(p, AnnotationAdditionalNodeNetworks)
}
for _, pair := range strings.Split(additionalNodeNetworksCSV, ",") {
pair = strings.TrimSpace(pair)
if pair == "" {
continue
}
netAndSubnet := strings.SplitN(pair, ":", 2)
if len(netAndSubnet) != 2 {
return nil, fmt.Errorf("invalid additional network annotation: %v", pair)
}
additionalNodeNetworks = append(additionalNodeNetworks, &containerv1beta1.AdditionalNodeNetworkConfig{
Network: strings.TrimSpace(netAndSubnet[0]),
Subnetwork: strings.TrimSpace(netAndSubnet[1]),
})
}
if len(additionalNodeNetworks) > 0 {
networkConfig = &containerv1beta1.NodeNetworkConfig{
AdditionalNodeNetworkConfigs: additionalNodeNetworks,
}
}
nodeServiceAccount := g.ClusterContext.NodeServiceAccount
if sa, ok := p.Annotations[AnnotationNodeServiceAccount]; ok {
nodeServiceAccount = sa
}
// Placement policy is only valid in GKE for non-"1t" shapes.
placementPolicy := &containerv1beta1.PlacementPolicy{}
if !strings.HasSuffix(machineType, "1t") {
placementPolicy.TpuTopology = tpuTopo
placementPolicy.Type = "COMPACT"
}
var diskType string
if g.ClusterContext.NodeDiskType != "" {
diskType = g.ClusterContext.NodeDiskType
}
name, err := podToNodePoolName(p)
if err != nil {
return nil, err
}
np := &containerv1beta1.NodePool{
Name: name,
Config: &containerv1beta1.NodeConfig{
ServiceAccount: nodeServiceAccount,
ShieldedInstanceConfig: &containerv1beta1.ShieldedInstanceConfig{
EnableIntegrityMonitoring: true,
EnableSecureBoot: g.ClusterContext.NodeSecureBoot,
},
Tags: g.ClusterContext.NodeTags,
// NOTE: vendor/ was manually updated to include this field because
// it was not yet available at the time of writing.
SecondaryBootDisks: secondaryDisks,
MachineType: machineType,
ReservationAffinity: reservation,
Labels: labels,
Spot: spot,
Taints: taints,
BootDiskKmsKey: g.ClusterContext.NodeBootDiskKMSKey,
DiskType: diskType,
EnableConfidentialStorage: g.ClusterContext.NodeConfidentialStorage,
},
InitialNodeCount: int64(nodeCount),
Locations: []string{g.ClusterContext.NodeZone},
PlacementPolicy: placementPolicy,
Management: &containerv1beta1.NodeManagement{
AutoRepair: true,
AutoUpgrade: false,
},
UpgradeSettings: &containerv1beta1.UpgradeSettings{
MaxSurge: 1,
},
MaxPodsConstraint: &containerv1beta1.MaxPodsConstraint{MaxPodsPerNode: maxPodsPerNode},
NetworkConfig: networkConfig,
}
hash, err := nodePoolSelectiveHash(np)
if err != nil {
return nil, fmt.Errorf("hashing node pool: %w", err)
}
np.Config.Labels[LabelNodePoolHash] = hash
return np, nil
}
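// sumTPURequests returns the total number of google.com/tpu chips requested
// across all containers of the Pod.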
func sumTPURequests(p *corev1.Pod) (int, error) {
var n int
for _, c := range p.Spec.Containers {
if c.Resources.Requests == nil {
continue
}
req, ok := c.Resources.Requests[corev1.ResourceName(GoogleTPUResource)]
if !ok {
continue
}
v, ok := req.AsInt64()
if !ok {
return 0, fmt.Errorf("invalid TPU request: %v", req.String())
}
n += int(v)
}
return n, nil
}
// podToNodePoolName deterministically generates a node pool name for a given pod,
// by using the JobSet name and job-key (SHA1 hash of namespaced job key), as
// given in the pod labels.
// These labels are stable through JobSet restarts, so the node pool name
// generated here will be the same if the JobSet is restarted.
// Node pool name format is: {first 34 chars of JobSet name}-{first 5 chars of job-key}
// This keeps node pool names within the 40-character limit on node pool name length.
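//
// Illustrative example (hypothetical names): a JobSet named
// "very-long-jobset-name-that-exceeds-the-prefix-limit" with job-key
// "a1b2c3..." yields the node pool name
// "very-long-jobset-name-that-exceeds-a1b2c" (34-char prefix + "-" + 5-char suffix).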
func podToNodePoolName(p *corev1.Pod) (string, error) {
jobSetName, exists := p.Labels[jobset.JobSetNameKey]
if !exists {
return "", fmt.Errorf("%s label not found on pod %s", jobset.JobSetNameKey, p.Name)
}
jobKey, exists := p.Labels[jobset.JobKey]
if !exists {
return "", fmt.Errorf("%s label not found on pod %s", jobset.JobKey, p.Name)
}
prefixLength := min(maxJobSetPrefixLength, len(jobSetName))
prefix := jobSetName[:prefixLength]
suffix := jobKey[:jobKeySuffixLength]
nodePoolName := fmt.Sprintf("%s-%s", prefix, suffix)
return nodePoolName, nil
}
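// tpuTopologyToNodeCount converts an accelerator type and topology string
// into the number of nodes in the slice, assuming 4 TPU chips per node
// (hence the division by 4). v4 and v5p use 3-dimensional topologies, while
// v5e and v6e use 2-dimensional ones. For example, a "2x2x4" topology on
// tpu-v4-podslice is 16 chips, i.e. 4 nodes.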
func tpuTopologyToNodeCount(accelerator, topo string) (int, error) {
var expectedDims int
switch accelerator {
case V4PodSliceAccelerator, V5pPodSliceAccelerator:
expectedDims = 3
case V5ePodSliceAccelerator, V6eSliceAccelerator:
expectedDims = 2
default:
return 0, fmt.Errorf("invalid accelerator: %v", accelerator)
}
split := strings.Split(topo, "x")
if len(split) != expectedDims {
return 0, fmt.Errorf("invalid topology: %v, expected %v dimensions", topo, expectedDims)
}
product := 1
for _, s := range split {
x, err := strconv.Atoi(s)
if err != nil {
return 0, fmt.Errorf("invalid topology: %v, could not convert %q to int: %w", topo, s, err)
}
product *= x
}
return int(math.Ceil(float64(product) / 4)), nil
}
// tpuMachineType takes an accelerator type (from nodeSelector) and a TPU request
// from container requests and returns the corresponding machine type.
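// For example, tpu-v5-lite-podslice with a request of 4 chips maps to
// "ct5lp-hightpu-4t", and tpu-v4-podslice with 4 chips maps to
// "ct4p-hightpu-4t".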
func tpuMachineType(accel string, tpuRequest int) (string, error) {
if tpuRequest < 1 {
return "", fmt.Errorf("invalid TPU request: %v", tpuRequest)
}
switch accel {
case V4PodSliceAccelerator: // v4
return fmt.Sprintf("ct4p-hightpu-%vt", tpuRequest), nil
case V5ePodSliceAccelerator: // v5e
return fmt.Sprintf("ct5lp-hightpu-%vt", tpuRequest), nil
case V5pPodSliceAccelerator: // v5p
return fmt.Sprintf("ct5p-hightpu-%vt", tpuRequest), nil
case V6eSliceAccelerator: // v6e
return fmt.Sprintf("ct6e-standard-%vt", tpuRequest), nil
}
return "", fmt.Errorf("invalid accelerator: %v", accel)
}
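// waitForGkeOp polls the given GKE operation every 5 seconds and returns
// once it reports DONE, or an error if the operation cannot be fetched or
// does not complete within 30 minutes.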
func waitForGkeOp(svc *containerv1beta1.Service, c GKEContext, operation *containerv1beta1.Operation) error {
operationWaitTimeout := 30 * time.Minute
operationPollInterval := 5 * time.Second
for start := time.Now(); time.Since(start) < operationWaitTimeout; time.Sleep(operationPollInterval) {
if op, err := svc.Projects.Locations.Operations.Get(c.OpName(operation.Name)).Do(); err == nil {
if op.Status == "DONE" {
return nil
}
} else {
return fmt.Errorf("waiting for operation: %w", err)
}
}
return fmt.Errorf("timeout while waiting for operation %s on %s to complete", operation.Name, operation.TargetLink)
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
func getAnnotation(p *corev1.Pod, key string) string {
if p.Annotations == nil {
return ""
}
return p.Annotations[key]
}
// nodePoolSelectiveHash attempts to hash information specific to workload requirements.
// A selective approach is taken to avoid overzealous node pool recreation under circumstances
// where values might change due to a config or code change in the provisioner.
// Example scenario where selective hashing is useful:
// 1. Provisioner is updated to include new upgrade settings.
// 2. Some node pool goes into a repairing state.
// 3. The workload Pod goes into an unschedulable state.
// 4. The code path for ensuring a matching node pool exists is executed.
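// 5. With selective hashing, the workload-relevant fields still match and the
// existing node pool is reused rather than needlessly recreated.
//
// Only fields that encode workload requirements (Spot, Labels, MachineType,
// ReservationAffinity) are included in the hash; if a hash label is already
// present it is returned as-is.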
func nodePoolSelectiveHash(np *containerv1beta1.NodePool) (string, error) {
h := fnv.New32a()
if np.Config != nil && np.Config.Labels != nil {
hash, ok := np.Config.Labels[LabelNodePoolHash]
if ok {
return hash, nil
}
}
npToHash := &containerv1beta1.NodePool{
Config: &containerv1beta1.NodeConfig{
Spot: np.Config.Spot,
Labels: np.Config.Labels,
MachineType: np.Config.MachineType,
ReservationAffinity: np.Config.ReservationAffinity,
},
}
jsn, err := json.Marshal(npToHash)
if err != nil {
return "", err
}
h.Write(jsn)
return rand.SafeEncodeString(fmt.Sprint(h.Sum32())), nil
}