tpu-provisioner/internal/controller/deletion_controller.go

package controller

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/GoogleCloudPlatform/ai-on-gke/tpu-provisioner/internal/cloud"
	corev1 "k8s.io/api/core/v1"
	apierrors "k8s.io/apimachinery/pkg/api/errors"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/runtime"
	"k8s.io/apimachinery/pkg/types"
	"k8s.io/client-go/tools/record"
	ctrl "sigs.k8s.io/controller-runtime"
	"sigs.k8s.io/controller-runtime/pkg/client"
	"sigs.k8s.io/controller-runtime/pkg/controller"
	"sigs.k8s.io/controller-runtime/pkg/handler"
	ctrllog "sigs.k8s.io/controller-runtime/pkg/log"
	"sigs.k8s.io/controller-runtime/pkg/predicate"
	"sigs.k8s.io/controller-runtime/pkg/reconcile"
	jobset "sigs.k8s.io/jobset/api/jobset/v1alpha2"
)

// DeletionReconciler watches Pods and Nodes and deletes Node Pools.
type DeletionReconciler struct {
	client.Client
	Scheme   *runtime.Scheme
	Recorder record.EventRecorder
	Provider cloud.Provider

	NodeCriteria NodeCriteria

	Concurrency int
}

type NodeCriteria struct {
	MinLifetime time.Duration

	// PoolDeletionDelay is the interval between the first and
	// second node pool deletion checks. Once the node pool deletion check
	// has passed twice, the node pool can be safely deleted. This second
	// check is to ensure the node pool is not prematurely deleted in the case
	// where a JobSet is restarted but no pods have been created yet.
	PoolDeletionDelay time.Duration
}

//+kubebuilder:rbac:groups="",resources=nodes,verbs=get;list;watch;create;update;patch;delete
//+kubebuilder:rbac:groups="",resources=nodes/status,verbs=get;update;patch
//+kubebuilder:rbac:groups="",resources=nodes/finalizers,verbs=update
//+kubebuilder:rbac:groups="jobset.x-k8s.io",resources=jobsets,verbs=get;list;watch

func (r *DeletionReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
	lg := ctrllog.FromContext(ctx)

	lg.V(3).Info("Reconciling Node")

	var node corev1.Node
	if err := r.Get(ctx, req.NamespacedName, &node); err != nil {
		if apierrors.IsNotFound(err) {
			// Don't requeue, Node no longer exists (or does not exist in the cache).
			return ctrl.Result{}, nil
		}
		return ctrl.Result{}, fmt.Errorf("getting node: %w", err)
	}

	// NOTE: Because of the cache filter in main.go, this check should always evaluate to false.
	if node.GetLabels()[cloud.LabelNodepoolManager] != cloud.LabelNodepoolManagerTPUPodinator {
		lg.V(3).Info("Node was not provisioned by this controller, ignoring")
		return ctrl.Result{}, nil
	}

	// Avoid noisy reconciliation when nodes are shutting down.
	for _, c := range node.Status.Conditions {
		if c.Type == corev1.NodeReady &&
			c.Status == corev1.ConditionFalse &&
			c.Reason == "KubeletNotReady" &&
			c.Message == "node is shutting down" {
			lg.V(3).Info("Node is shutting down, ignoring")
			return ctrl.Result{}, nil
		}
	}

	// Ensure the node was not just created, so that Pods have had time to schedule.
	if since := time.Since(node.GetCreationTimestamp().Time); since < r.NodeCriteria.MinLifetime {
		wait := r.NodeCriteria.MinLifetime - since + time.Second
		lg.V(3).Info("Node was just created, ignoring", "waiting", wait)
		return ctrl.Result{RequeueAfter: wait}, nil
	}

	nodePoolLabelKey := r.Provider.NodePoolLabelKey()
	nodePoolName, ok := node.GetLabels()[nodePoolLabelKey]
	if !ok {
		lg.V(3).Info("No node pool label found on node, ignoring", "labelKey", nodePoolLabelKey)
		return ctrl.Result{}, nil
	}

	// Ensure the JobSet whose pods created this node pool is either gone, completed, or failed before
	// deleting the node pool.
	jobSetName, exists := node.Labels[cloud.LabelJobSetName]
	if !exists {
		jobSetName, exists = node.Labels[cloud.LabelProvisionerNodepoolID]
		if !exists {
			lg.V(3).Info("Node missing jobset name label", "node", node.Name)
			return ctrl.Result{}, nil
		}
	}

	jobSetNamespace, exists := node.Labels[cloud.LabelJobSetNamespace]
	if !exists {
		lg.V(3).Info("Node missing jobset namespace label, using default", "node", node.Name)
		jobSetNamespace = "default"
	}

	var js jobset.JobSet
	if err := r.Get(ctx, types.NamespacedName{Name: jobSetName, Namespace: jobSetNamespace}, &js); err != nil {
		// Case 1: If the JobSet no longer exists, delete the node pool.
		if apierrors.IsNotFound(err) {
			return r.deleteNodePool(ctx, &node, fmt.Sprintf("JobSet %s no longer exists", jobSetName))
		}
		return ctrl.Result{}, err
	}

	// Case 2: If the JobSet is in a completed or failed state, delete the node pool.
	if jobSetCompleted(&js) || jobSetFailed(&js) {
		return r.deleteNodePool(ctx, &node, fmt.Sprintf("JobSet %s execution has ended (completed or failed)", jobSetName))
	}

	// No need to check all the other nodes, which will have the same jobset name label; we can end
	// the loop early.
	// Log the fact that we are not deleting at a high verbosity level, to avoid polluting logs
	// while still allowing for improved debuggability.
	lg.V(5).Info("Node pool for JobSet is still in use, not deleting", "nodePoolName", nodePoolName, "jobSetName", jobSetName)
	return ctrl.Result{}, nil
}

// SetupWithManager sets up the controller with the Manager.
func (r *DeletionReconciler) SetupWithManager(mgr ctrl.Manager) error {
	// Index Pods by their node name so that nodeForPod lookups and Pod watches can map back to Nodes.
	if err := mgr.GetFieldIndexer().IndexField(context.Background(), &corev1.Pod{}, ".spec.nodeName", func(rawObj client.Object) []string {
		pod := rawObj.(*corev1.Pod)
		return []string{pod.Spec.NodeName}
	}); err != nil {
		return err
	}

	if r.NodeCriteria.MinLifetime == 0 {
		return fmt.Errorf("NodeCriteria.MinLifetime must be set")
	}

	return ctrl.NewControllerManagedBy(mgr).
		For(&corev1.Node{}).
		Watches(&corev1.Pod{}, handler.EnqueueRequestsFromMapFunc(handler.MapFunc(nodeForPod))).
		WithOptions(controller.Options{
			MaxConcurrentReconciles: r.Concurrency,
		}).
		WithEventFilter(predicate.NewPredicateFuncs(func(object client.Object) bool {
			node, ok := object.(*corev1.Node)
			return ok && nodeManagedByProvisioner(node)
		})).
		Complete(r)
}

func (r *DeletionReconciler) deleteNodePool(ctx context.Context, node *corev1.Node, reason string) (ctrl.Result, error) {
	lg := ctrllog.FromContext(ctx)

	if err := r.Provider.DeleteNodePoolForNode(node, reason); err != nil {
		if errors.Is(err, cloud.ErrDuplicateRequest) {
			lg.V(3).Info("Ignoring duplicate request to delete node pool")
			return ctrl.Result{}, nil
		}
		// Surface any other deletion error so the request is requeued and retried.
		return ctrl.Result{}, err
	}

	return ctrl.Result{}, nil
}

// nodeForPod maps a Pod event to a reconcile request for the Node it is scheduled on.
func nodeForPod(_ context.Context, obj client.Object) []reconcile.Request {
	pod := obj.(*corev1.Pod)
	if nodeName := pod.Spec.NodeName; nodeName != "" {
		return []reconcile.Request{
			{NamespacedName: types.NamespacedName{Name: nodeName}},
		}
	}
	return []reconcile.Request{}
}

// nodeManagedByProvisioner returns true if the given node is managed by the
// TPU provisioner, otherwise it returns false.
func nodeManagedByProvisioner(node *corev1.Node) bool {
	return node.Labels[cloud.LabelNodepoolManager] == cloud.LabelNodepoolManagerTPUPodinator
}

// jobSetCompleted returns true if the JobSet has completed, otherwise it returns false.
func jobSetCompleted(js *jobset.JobSet) bool {
	for _, condition := range js.Status.Conditions {
		if condition.Type == string(jobset.JobSetCompleted) && condition.Status == metav1.ConditionTrue {
			return true
		}
	}
	return false
}

// jobSetFailed returns true if the JobSet has failed, otherwise it returns false.
func jobSetFailed(js *jobset.JobSet) bool {
	for _, condition := range js.Status.Conditions {
		if condition.Type == string(jobset.JobSetFailed) && condition.Status == metav1.ConditionTrue {
			return true
		}
	}
	return false
}
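
The NOTE in Reconcile refers to a cache filter set up in main.go so that only provisioner-managed Nodes are cached and handed to this controller. A minimal sketch of what such a filter could look like is below; it assumes controller-runtime v0.15+ (cache.Options.ByObject with a label selector), and the buildManagerOptions helper name is made up for illustration rather than taken from this repository.

package main

import (
	corev1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/labels"
	"k8s.io/apimachinery/pkg/runtime"
	ctrl "sigs.k8s.io/controller-runtime"
	"sigs.k8s.io/controller-runtime/pkg/cache"
	"sigs.k8s.io/controller-runtime/pkg/client"

	"github.com/GoogleCloudPlatform/ai-on-gke/tpu-provisioner/internal/cloud"
)

// buildManagerOptions is a hypothetical helper (not part of this repo) showing
// how a manager's cache could be restricted to Nodes carrying the provisioner's
// manager label, so the label check in Reconcile normally never fires.
func buildManagerOptions(scheme *runtime.Scheme) ctrl.Options {
	return ctrl.Options{
		Scheme: scheme,
		Cache: cache.Options{
			ByObject: map[client.Object]cache.ByObject{
				// Only Nodes labeled as managed by the TPU provisioner are cached.
				&corev1.Node{}: {
					Label: labels.SelectorFromSet(labels.Set{
						cloud.LabelNodepoolManager: cloud.LabelNodepoolManagerTPUPodinator,
					}),
				},
			},
		},
	}
}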