cmd/node-termination-handler.go

// Copyright 2016-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
//     http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package main

import (
	"context"
	"fmt"
	"os"
	"os/signal"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/aws/aws-node-termination-handler/pkg/config"
	"github.com/aws/aws-node-termination-handler/pkg/ec2metadata"
	"github.com/aws/aws-node-termination-handler/pkg/interruptionevent/asg/launch"
	"github.com/aws/aws-node-termination-handler/pkg/interruptionevent/draincordon"
	"github.com/aws/aws-node-termination-handler/pkg/interruptioneventstore"
	"github.com/aws/aws-node-termination-handler/pkg/logging"
	"github.com/aws/aws-node-termination-handler/pkg/monitor"
	"github.com/aws/aws-node-termination-handler/pkg/monitor/asglifecycle"
	"github.com/aws/aws-node-termination-handler/pkg/monitor/rebalancerecommendation"
	"github.com/aws/aws-node-termination-handler/pkg/monitor/scheduledevent"
	"github.com/aws/aws-node-termination-handler/pkg/monitor/spotitn"
	"github.com/aws/aws-node-termination-handler/pkg/monitor/sqsevent"
	"github.com/aws/aws-node-termination-handler/pkg/node"
	"github.com/aws/aws-node-termination-handler/pkg/observability"
	"github.com/aws/aws-node-termination-handler/pkg/webhook"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/endpoints"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/autoscaling"
	"github.com/aws/aws-sdk-go/service/ec2"
	"github.com/go-logr/zerologr"
	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
	"k8s.io/apimachinery/pkg/util/wait"
	"k8s.io/client-go/kubernetes"
	"k8s.io/client-go/rest"
	"k8s.io/klog/v2"
)

const (
	scheduledMaintenance    = "Scheduled Maintenance"
	spotITN                 = "Spot ITN"
	asgLifecycle            = "ASG Lifecycle"
	rebalanceRecommendation = "Rebalance Recommendation"
	sqsEvents               = "SQS Event"
	timeFormat              = "2006/01/02 15:04:05"
	duplicateErrThreshold   = 3
)

type interruptionEventHandler interface {
	HandleEvent(*monitor.InterruptionEvent) error
}

func main() {
	// Zerolog uses json formatting by default, so change that to a human-readable format instead
	log.Logger = log.Output(logging.RoutingLevelWriter{
		Writer:    &zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: timeFormat, NoColor: true},
		ErrWriter: &zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: timeFormat, NoColor: true},
	})

	signalChan := make(chan os.Signal, 1)
	signal.Notify(signalChan, syscall.SIGTERM)
	defer signal.Stop(signalChan)

	nthConfig, err := config.ParseCliArgs()
	if err != nil {
		log.Fatal().Err(err).Msg("Failed to parse cli args,")
	}

	if nthConfig.JsonLogging {
		log.Logger = zerolog.New(os.Stderr).With().Timestamp().Logger()
	}
	switch strings.ToLower(nthConfig.LogLevel) {
	case "info":
		zerolog.SetGlobalLevel(zerolog.InfoLevel)
	case "debug":
		zerolog.SetGlobalLevel(zerolog.DebugLevel)
	case "error":
		zerolog.SetGlobalLevel(zerolog.ErrorLevel)
	}
	klog.SetLogger(zerologr.New(&log.Logger))

	log.Info().Msgf("Using log format version %d", nthConfig.LogFormatVersion)
	if err = logging.SetFormatVersion(nthConfig.LogFormatVersion); err != nil {
		log.Warn().Err(err).Send()
	}
	if err = observability.SetReasonForKindVersion(nthConfig.LogFormatVersion); err != nil {
		log.Warn().Err(err).Send()
	}

	err = webhook.ValidateWebhookConfig(nthConfig)
	if err != nil {
		nthConfig.Print()
		log.Fatal().Err(err).Msg("Webhook validation failed,")
	}

	clusterConfig, err := rest.InClusterConfig()
	if err != nil {
		log.Fatal().Err(err).Msgf("retrieving cluster config")
	}

	clientset, err := kubernetes.NewForConfig(clusterConfig)
	if err != nil {
		log.Fatal().Err(err).Msgf("creating new clientset with config: %v", err)
	}

	node, err := node.New(nthConfig, clientset)
	if err != nil {
		nthConfig.Print()
		log.Fatal().Err(err).Msg("Unable to instantiate a node for various kubernetes node functions,")
	}

	metrics, initMetricsErr := observability.InitMetrics(nthConfig.EnablePrometheus, nthConfig.PrometheusPort)
	if initMetricsErr != nil {
		nthConfig.Print()
		log.Fatal().Err(initMetricsErr).Msg("Unable to instantiate observability metrics,")
	}

	err = observability.InitProbes(nthConfig.EnableProbes, nthConfig.ProbesPort, nthConfig.ProbesEndpoint)
	if err != nil {
		nthConfig.Print()
		log.Fatal().Err(err).Msg("Unable to instantiate probes service,")
	}

	imdsDisabled := nthConfig.EnableSQSTerminationDraining

	interruptionEventStore := interruptioneventstore.New(nthConfig)

	var imds *ec2metadata.Service
	var nodeMetadata ec2metadata.NodeMetadata
	if !imdsDisabled {
		imds = ec2metadata.New(nthConfig.MetadataURL, nthConfig.MetadataTries)
		nodeMetadata = imds.GetNodeMetadata()
	}

	// Populate the aws region if available from node metadata and not already explicitly configured
	if nthConfig.AWSRegion == "" && nodeMetadata.Region != "" {
		nthConfig.AWSRegion = nodeMetadata.Region
	} else if nthConfig.AWSRegion == "" && nthConfig.QueueURL != "" {
		nthConfig.AWSRegion = getRegionFromQueueURL(nthConfig.QueueURL)
		log.Debug().Str("Retrieved AWS region from queue-url: \"%s\"", nthConfig.AWSRegion)
	}
	if nthConfig.AWSRegion == "" && nthConfig.EnableSQSTerminationDraining {
		nthConfig.Print()
		log.Fatal().Msgf("Unable to find the AWS region to process queue events.")
	}

	recorder, err := observability.InitK8sEventRecorder(nthConfig.EmitKubernetesEvents, nthConfig.NodeName, nthConfig.EnableSQSTerminationDraining, nodeMetadata, nthConfig.KubernetesEventsExtraAnnotations, clientset)
	if err != nil {
		nthConfig.Print()
		log.Fatal().Err(err).Msg("Unable to create Kubernetes event recorder,")
	}

	nthConfig.Print()

	if !imdsDisabled && nthConfig.EnableScheduledEventDraining {
		//will retry 4 times with an interval of 2 seconds.
		pollCtx, cancelPollCtx := context.WithTimeout(context.Background(), 8*time.Second)
		err = wait.PollUntilContextCancel(pollCtx, 2*time.Second, true, func(context.Context) (done bool, err error) {
			err = handleRebootUncordon(nthConfig.NodeName, interruptionEventStore, *node)
			if err != nil {
				log.Warn().Err(err).Msgf("Unable to complete the uncordon after reboot workflow on startup, retrying")
				return false, nil
			}
			return true, nil
		})
		if err != nil {
			log.Warn().Err(err).Msgf("All retries failed, unable to complete the uncordon after reboot workflow")
		}
		cancelPollCtx()
	}

	interruptionChan := make(chan monitor.InterruptionEvent)
	defer close(interruptionChan)
	cancelChan := make(chan monitor.InterruptionEvent)
	defer close(cancelChan)

	monitoringFns := map[string]monitor.Monitor{}
	if !imdsDisabled {
		if nthConfig.EnableSpotInterruptionDraining {
			imdsSpotMonitor := spotitn.NewSpotInterruptionMonitor(imds, interruptionChan, cancelChan, nthConfig.NodeName)
			monitoringFns[spotITN] = imdsSpotMonitor
		}
		if nthConfig.EnableASGLifecycleDraining {
			asgLifecycleMonitor := asglifecycle.NewASGLifecycleMonitor(imds, interruptionChan, cancelChan, nthConfig.NodeName)
			monitoringFns[asgLifecycle] = asgLifecycleMonitor
		}
		if nthConfig.EnableScheduledEventDraining {
			imdsScheduledEventMonitor := scheduledevent.NewScheduledEventMonitor(imds, interruptionChan, cancelChan, nthConfig.NodeName)
			monitoringFns[scheduledMaintenance] = imdsScheduledEventMonitor
		}
		if nthConfig.EnableRebalanceMonitoring || nthConfig.EnableRebalanceDraining {
			imdsRebalanceMonitor := rebalancerecommendation.NewRebalanceRecommendationMonitor(imds, interruptionChan, nthConfig.NodeName)
			monitoringFns[rebalanceRecommendation] = imdsRebalanceMonitor
		}
	}
	if nthConfig.EnableSQSTerminationDraining {
		cfg := aws.NewConfig().WithRegion(nthConfig.AWSRegion).WithEndpoint(nthConfig.AWSEndpoint).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)
		sess := session.Must(session.NewSessionWithOptions(session.Options{
			Config:            *cfg,
			SharedConfigState: session.SharedConfigEnable,
		}))
		creds, err := sess.Config.Credentials.Get()
		if err != nil {
			log.Fatal().Err(err).Msg("Unable to get AWS credentials")
		}
		log.Debug().Msgf("AWS Credentials retrieved from provider: %s", creds.ProviderName)

		ec2Client := ec2.New(sess)
		if initMetricsErr == nil && nthConfig.EnablePrometheus {
			go metrics.InitNodeMetrics(nthConfig, node, ec2Client)
		}

		completeLifecycleActionDelay := time.Duration(nthConfig.CompleteLifecycleActionDelaySeconds) * time.Second
		sqsMonitor := sqsevent.SQSMonitor{
			CheckIfManaged:                nthConfig.CheckTagBeforeDraining,
			ManagedTag:                    nthConfig.ManagedTag,
			QueueURL:                      nthConfig.QueueURL,
			InterruptionChan:              interruptionChan,
			CancelChan:                    cancelChan,
			SQS:                           sqsevent.GetSqsClient(sess),
			ASG:                           autoscaling.New(sess),
			EC2:                           ec2Client,
			BeforeCompleteLifecycleAction: func() { <-time.After(completeLifecycleActionDelay) },
		}
		monitoringFns[sqsEvents] = sqsMonitor
	}

	// Start each enabled monitor in its own goroutine, polling every 2 seconds; repeated identical errors eventually stop NTH.
	for _, fn := range monitoringFns {
		go func(monitor monitor.Monitor) {
			logging.VersionedMsgs.MonitoringStarted(monitor.Kind())
			var previousErr error
			var duplicateErrCount int
			for range time.Tick(time.Second * 2) {
				err := monitor.Monitor()
				if err != nil {
					logging.VersionedMsgs.ProblemMonitoringForEvents(monitor.Kind(), err)
					metrics.ErrorEventsInc(monitor.Kind())
					recorder.Emit(nthConfig.NodeName, observability.Warning, observability.MonitorErrReason, observability.MonitorErrMsgFmt, monitor.Kind())
					if previousErr != nil && err.Error() == previousErr.Error() {
						duplicateErrCount++
					} else {
						duplicateErrCount = 0
						previousErr = err
					}
					if duplicateErrCount >= duplicateErrThreshold {
						log.Warn().Msg("Stopping NTH - Duplicate Error Threshold hit.")
						panic(fmt.Sprintf("%v", err))
					}
				}
			}
		}(fn)
	}

	go watchForInterruptionEvents(interruptionChan, interruptionEventStore)
	log.Info().Msg("Started watching for interruption events")
	log.Info().Msg("Kubernetes AWS Node Termination Handler has started successfully!")

	go watchForCancellationEvents(cancelChan, interruptionEventStore, node, metrics, recorder)
	log.Info().Msg("Started watching for event cancellations")

	var wg sync.WaitGroup

	asgLaunchHandler := launch.New(interruptionEventStore, *node, nthConfig, metrics, recorder, clientset)
	drainCordonHandler := draincordon.New(interruptionEventStore, *node, nthConfig, nodeMetadata, metrics, recorder)

	// Main event loop: dispatch active interruption events to the event handlers, one worker slot per event.
	for range time.NewTicker(1 * time.Second).C {
		select {
		case <-signalChan:
			// Exit interruption loop if a SIGTERM is received or the channel is closed
			break
		default:
		EventLoop:
			for event, ok := interruptionEventStore.GetActiveEvent(); ok; event, ok = interruptionEventStore.GetActiveEvent() {
				select {
				case interruptionEventStore.Workers <- 1:
					logging.VersionedMsgs.RequestingInstanceDrain(event)
					event.InProgress = true
					wg.Add(1)
					recorder.Emit(event.NodeName, observability.Normal, observability.GetReasonForKind(event.Kind, event.Monitor), event.Description)
					go processInterruptionEvent(interruptionEventStore, event, []interruptionEventHandler{asgLaunchHandler, drainCordonHandler}, &wg)
				default:
					log.Warn().Msg("all workers busy, waiting")
					break EventLoop
				}
			}
		}
	}
	log.Info().Msg("AWS Node Termination Handler is shutting down")
	wg.Wait()
	log.Debug().Msg("all event processors finished")
}

func handleRebootUncordon(nodeName string, interruptionEventStore *interruptioneventstore.Store, node node.Node) error {
	isLabeled, err := node.IsLabeledWithAction(nodeName)
	if err != nil {
		return err
	}
	if !isLabeled {
		return nil
	}
	eventID, err := node.GetEventID(nodeName)
	if err != nil {
		return err
	}
	err = node.UncordonIfRebooted(nodeName)
	if err != nil {
		return fmt.Errorf("Unable to complete node label actions: %w", err)
	}
	interruptionEventStore.IgnoreEvent(eventID)
	return nil
}

func watchForInterruptionEvents(interruptionChan <-chan monitor.InterruptionEvent, interruptionEventStore *interruptioneventstore.Store) {
	for {
		interruptionEvent := <-interruptionChan
		interruptionEventStore.AddInterruptionEvent(&interruptionEvent)
	}
}

func watchForCancellationEvents(cancelChan <-chan monitor.InterruptionEvent, interruptionEventStore *interruptioneventstore.Store, node *node.Node, metrics observability.Metrics, recorder observability.K8sEventRecorder) {
	for {
		interruptionEvent := <-cancelChan
		nodeName := interruptionEvent.NodeName
		eventID := interruptionEvent.EventID
		interruptionEventStore.CancelInterruptionEvent(interruptionEvent.EventID)
		if interruptionEventStore.ShouldUncordonNode(nodeName) {
			log.Info().Msg("Uncordoning the node due to a cancellation event")
			err := node.Uncordon(nodeName)
			if err != nil {
				log.Err(err).Msg("Uncordoning the node failed")
				recorder.Emit(nodeName, observability.Warning, observability.UncordonErrReason, observability.UncordonErrMsgFmt, err.Error())
			} else {
				recorder.Emit(nodeName, observability.Normal, observability.UncordonReason, observability.UncordonMsg)
			}
			metrics.NodeActionsInc("uncordon", nodeName, eventID, err)
			err = node.RemoveNTHLabels(nodeName)
			if err != nil {
				log.Warn().Err(err).Msg("There was an issue removing NTH labels from node")
			}
			err = node.RemoveNTHTaints(nodeName)
			if err != nil {
				log.Warn().Err(err).Msg("There was an issue removing NTH taints from node")
			}
		} else {
			log.Info().Msg("Another interruption event is active, not uncordoning the node")
		}
	}
}

func processInterruptionEvent(interruptionEventStore *interruptioneventstore.Store, event *monitor.InterruptionEvent, eventHandlers []interruptionEventHandler, wg *sync.WaitGroup) {
	defer wg.Done()
	if event == nil {
		log.Error().Msg("processing nil interruption event")
		<-interruptionEventStore.Workers
		return
	}
	var err error
	for _, eventHandler := range eventHandlers {
		err = eventHandler.HandleEvent(event)
		if err != nil {
			log.Error().Err(err).Interface("event", event).Msg("handling event")
		}
	}
	<-interruptionEventStore.Workers
}

func getRegionFromQueueURL(queueURL string) string {
	for _, partition := range endpoints.DefaultPartitions() {
		for regionID := range partition.Regions() {
			if strings.Contains(queueURL, regionID) {
				return regionID
			}
		}
	}
	return ""
}