cmd/server.go (126 lines of code) (raw):

package cmd import ( "context" "fmt" "os" "os/signal" "sync" "syscall" "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/eksauth" "github.com/sirupsen/logrus" "github.com/spf13/cobra" "go.amzn.com/eks/eks-pod-identity-agent/configuration" "go.amzn.com/eks/eks-pod-identity-agent/internal/middleware/logger" "go.amzn.com/eks/eks-pod-identity-agent/internal/sharedcredsrotater" "go.amzn.com/eks/eks-pod-identity-agent/pkg/handlers" "go.amzn.com/eks/eks-pod-identity-agent/pkg/server" ) var ( serverPort uint16 probePort uint16 metricsAddress string metricsPort uint16 bindHosts []string clusterName string overrideEksAuthEndpoint string maxCredentialRenewal time.Duration maxCacheSize int refreshQps int rotateCredentials bool ) var serverCmd = &cobra.Command{ Use: "server", Short: "A proxy server that exchanges kubernetes service account token with temporary AWS credentials by calling EKS Auth APIs", Long: fmt.Sprintf(`This command initalizes a proxy server that will listen by default on port %d. Request that are sent to the credential path (/v1/credentials) will be proxied to EKS to fetch temporary AWS credentials. The AWS SDKs used from within EKS workloads can be configured to invoke this endpoint for granular IAM permissions. Example use: './eks-pod-identity-agent server'`, serverPort), Run: func(cmd *cobra.Command, args []string) { ctx := context.Background() log := logger.FromContext(ctx) cfg, err := config.LoadDefaultConfig(ctx) if overrideEksAuthEndpoint != "" { overrideEndpointInCfg(log, &cfg, overrideEksAuthEndpoint) } if err != nil { log.Fatal("Unable to initialize aws configuration, exiting") } if rotateCredentials { log.Info("Credentials rotation enabled. Creds will be fetched and rotated from shared credentials file") cfg.Credentials = aws.NewCredentialsCache(sharedcredsrotater.NewRotatingSharedCredentialsProvider()) } startServers(ctx, cfg) }, } func startServers(pCtx context.Context, cfg aws.Config) { ctx, cancel := context.WithCancel(pCtx) wg := sync.WaitGroup{} servers := createServers(cfg) // start servers for _, srv := range servers { wg.Add(1) go func(server *server.Server, childCtx context.Context) { defer wg.Done() server.ListenUntilContextCancelled(childCtx) }(srv, logger.ContextWithField(ctx, "bind-addr", srv.Addr())) } // Create a channel to listen for an interrupt or terminate signal from the operating system // syscall.SIGTERM is equivalent to kill which allows the process time to cleanup quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGTERM, syscall.SIGINT) <-quit cancel() wg.Wait() } func createServers(cfg aws.Config) []*server.Server { servers := make([]*server.Server, len(bindHosts)) // listen on all bindHosts for i, ip := range bindHosts { addr := fmt.Sprintf("%s:%d", ip, serverPort) servers[i] = server.NewEksCredentialServer(addr, handlers.EksCredentialHandlerOpts{ Cfg: cfg, ClusterName: clusterName, CredentialRenewal: maxCredentialRenewal, MaxCacheSize: maxCacheSize, RefreshQPS: refreshQps, }) } // add health probes listening on host's network servers = append(servers, server.NewProbeServer(fmt.Sprintf("localhost:%d", probePort), bindHosts, serverPort)) servers = append(servers, server.NewMetricsServer(fmt.Sprintf("%s:%d", metricsAddress, metricsPort), bindHosts, serverPort)) return servers } func overrideEndpointInCfg(log *logrus.Entry, cfg *aws.Config, endpoint string) { log.Printf("Overriding %s default endpoint with %s\n", eksauth.ServiceID, endpoint) cfg.EndpointResolverWithOptions = aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) { if service == eksauth.ServiceID { return aws.Endpoint{ PartitionID: "aws", URL: endpoint, SigningRegion: region, }, nil } return aws.Endpoint{}, &aws.EndpointNotFoundError{} }) } func init() { rootCmd.AddCommand(serverCmd) // Read cluster name for CLI. This flag must be provided serverCmd.Flags().StringVarP(&clusterName, "cluster-name", "c", "", "Name of the EKS Cluster the agent will run on") err := serverCmd.MarkFlagRequired("cluster-name") if err != nil { panic(fmt.Sprintf("Unable to configure server command flags: %v", err)) } // Setup the port where the proxy server will listen to connections serverCmd.Flags().Uint16VarP(&serverPort, "port", "p", 80, "Listening port of the proxy server") serverCmd.Flags().Uint16Var(&probePort, "probe-port", 2703, "Health and readiness listening port") serverCmd.Flags().StringVar(&metricsAddress, "metrics-address", "0.0.0.0", "Metrics listening address") serverCmd.Flags().Uint16Var(&metricsPort, "metrics-port", 2705, "Metrics listening port") serverCmd.Flags().DurationVar(&maxCredentialRenewal, "max-credential-retention-before-renewal", 3*time.Hour, "Maximum amount of time that agent waits before renewing credentials. Set 0 to disable caching.") serverCmd.Flags().IntVar(&maxCacheSize, "max-cache-size", 2000, "Maximum amount of unique credentials to cache. Set 0 to disable caching.") serverCmd.Flags().IntVar(&refreshQps, "max-service-qps", 3, "Maximum amount of queries per second to EKS Auth") serverCmd.Flags().StringArrayVarP(&bindHosts, "bind-hosts", "b", []string{configuration.DefaultIpv4TargetHost, "[" + configuration.DefaultIpv6TargetHost + "]"}, "Hosts to bind server to") serverCmd.Flags().BoolVar(&rotateCredentials, "rotate-credentials", false, "Enable credentials rotation from shared credentials file") serverCmd.Flags().StringVar(&overrideEksAuthEndpoint, "endpoint", "", "Override for EKS auth endpoint") }