internal/node/pods.go (102 lines of code) (raw):

package node import ( "context" "fmt" "io/fs" "os" "path/filepath" "github.com/pkg/errors" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/kubernetes" "k8s.io/utils/strings/slices" "sigs.k8s.io/yaml" ) type podFilter func([]corev1.Pod) ([]corev1.Pod, error) func daemonSetFilter(pods []corev1.Pod) ([]corev1.Pod, error) { var filteredPods []corev1.Pod for _, pod := range pods { controllerRef := metav1.GetControllerOf(&pod) if controllerRef == nil || controllerRef.Kind != appsv1.SchemeGroupVersion.WithKind("DaemonSet").Kind { filteredPods = append(filteredPods, pod) } } return filteredPods, nil } func staticPodsFilter(pods []corev1.Pod) ([]corev1.Pod, error) { var filteredPods []corev1.Pod staticPods, err := getStaticPodsOnNode() if err != nil { return nil, err } // If there are no static pods, there is nothing to filter if len(staticPods) == 0 { return pods, nil } for _, pod := range pods { if !slices.Contains(staticPods, pod.Name) { filteredPods = append(filteredPods, pod) } } return filteredPods, nil } func getDrainedPodFilters() []podFilter { return []podFilter{ daemonSetFilter, staticPodsFilter, } } func getStaticPodsOnNode() ([]string, error) { var staticPodNames []string files, err := os.ReadDir(defaultStaticPodManifestPath) if err != nil { // If manifest directory doesn't exist, there are no static pods. if errors.Is(err, fs.ErrNotExist) { return []string{}, nil } return nil, errors.Wrap(err, "failed to read static manifest directory") } for _, file := range files { extension := filepath.Ext(file.Name()) if extension == ".yaml" || extension == ".yml" { fileData, err := os.ReadFile(filepath.Join(defaultStaticPodManifestPath, file.Name())) if err != nil { return nil, err } var obj metav1.ObjectMeta if err := yaml.Unmarshal(fileData, &obj); err != nil { return nil, errors.Wrapf(err, "failed to unmarshal static pod manifest file: %s", file.Name()) } staticPodNames = append(staticPodNames, obj.Name) } } return staticPodNames, nil } // GetPodsOnNode makes 5 attempts by default to list pods before erroring unless it times out. func GetPodsOnNode(ctx context.Context, nodeName string, clientset kubernetes.Interface, options ...NodeValidationOption) ([]corev1.Pod, error) { opts := DefaultNodeValidationOptions() for _, option := range options { option(&opts) } var pods *corev1.PodList var err error consecutiveErrors := 0 err = wait.PollUntilContextTimeout(ctx, opts.ValidationInterval, opts.ValidationTimeout, true, func(ctx context.Context) (bool, error) { pods, err = clientset.CoreV1().Pods("").List(ctx, metav1.ListOptions{ FieldSelector: fmt.Sprintf("spec.nodeName=%s", nodeName), }, ) if err != nil { consecutiveErrors += 1 if consecutiveErrors == opts.MaxRetries { return false, errors.Wrap(err, "failed to list all pods running on the node") } return false, nil // continue polling } return true, nil }) if err != nil { return nil, err } return pods.Items, nil }