internal/validation/request.go (80 lines of code) (raw):

package validation

import (
	"context"
	"fmt"
	"net"

	"github.com/golang-jwt/jwt/v5"
	_ "github.com/golang-jwt/jwt/v5"
	"go.amzn.com/eks/eks-pod-identity-agent/configuration"
	"go.amzn.com/eks/eks-pod-identity-agent/internal/middleware/logger"
	"go.amzn.com/eks/eks-pod-identity-agent/pkg/credentials"
	"go.amzn.com/eks/eks-pod-identity-agent/pkg/errors"
)

// A RequestValidator validates the requests that are expected by the agent.
type RequestValidator interface {
	// ValidateEksCredentialRequest returns nil when credsRequest is acceptable
	// and a non-nil error describing the first failed check otherwise.
	ValidateEksCredentialRequest(ctx context.Context, credsRequest *credentials.EksCredentialsRequest) error
}

// DefaultCredentialValidator is a RequestValidator that checks the address the
// request was sent to and that the service account token is a parseable JWT.
type DefaultCredentialValidator struct {
	// TargetHosts lists the host addresses the request is allowed to target.
	// If not specified (nil or empty), configuration.DefaultIpv4TargetHost and
	// configuration.DefaultIpv6TargetHost are used.
	TargetHosts []string
}

// Compile-time check that DefaultCredentialValidator satisfies RequestValidator.
var _ RequestValidator = DefaultCredentialValidator{}

var (
	jwtParser    = jwt.NewParser()
	jwtValidator = jwt.NewValidator()

	// defaultValidTargetHosts is the allow-list used when
	// DefaultCredentialValidator.TargetHosts is not specified.
	defaultValidTargetHosts = []string{
		configuration.DefaultIpv4TargetHost,
		configuration.DefaultIpv6TargetHost,
	}
)

// ValidateEksCredentialRequest is called to validate whether a request from
// the user is valid or not. It first checks the request's target host against
// the allowed hosts and then checks the service account token, returning the
// first error encountered. A nil credsRequest is rejected with a request
// validation error rather than causing a nil-pointer panic.
func (cv DefaultCredentialValidator) ValidateEksCredentialRequest(ctx context.Context, credsRequest *credentials.EksCredentialsRequest) error {
	if credsRequest == nil {
		return errors.NewRequestValidationError("Credentials request cannot be empty")
	}

	log := logger.FromContext(ctx)
	log.Debugf("validating call to requested target host %s", credsRequest.RequestTargetHost)

	if err := cv.validateRequestTargetHost(ctx, credsRequest.RequestTargetHost); err != nil {
		return err
	}

	if err := cv.validateToken(credsRequest); err != nil {
		return err
	}

	log.Debug("validation passed")
	return nil
}

// validateToken checks that the service account token is non-empty, can be
// parsed as a JWT, and passes the claim checks of the package-level
// jwtValidator. The signature is NOT verified here (ParseUnverified is used);
// as the original author notes, real validity is determined by the service.
func (cv DefaultCredentialValidator) validateToken(credsRequest *credentials.EksCredentialsRequest) error {
	// just verify the token is parseable, we will detect if it's valid or not on the service
	if credsRequest.ServiceAccountToken == "" {
		return errors.NewRequestValidationError("Service account token cannot be empty")
	}

	parsedToken, _, err := jwtParser.ParseUnverified(credsRequest.ServiceAccountToken, &jwt.RegisteredClaims{})
	if err != nil {
		return errors.NewRequestValidationError(fmt.Sprintf("Service account token cannot be parsed: %v", err))
	}

	if err := jwtValidator.Validate(parsedToken.Claims); err != nil {
		return errors.NewRequestValidationError(fmt.Sprintf("Service account token failed basic claim validations: %v", err))
	}

	return nil
}

// validateRequestTargetHost checks whether the address the request was sent to
// matches one of the addresses the agent is expected to be reached through.
// An optional ":port" suffix and IPv6 brackets are stripped before the exact
// string comparison. It returns an access denied error when nothing matches.
func (cv DefaultCredentialValidator) validateRequestTargetHost(ctx context.Context, requestTargetHost string) error {
	log := logger.FromContext(ctx).WithField("target-host", requestTargetHost)

	// Sometimes the port is included in the requestTargetHost (e.g. when the
	// port we are listening on is not HTTP's default 80). SplitHostPort fails
	// for inputs without a port, in which case the value is kept untouched.
	if host, port, err := net.SplitHostPort(requestTargetHost); err == nil {
		log.WithFields(map[string]interface{}{
			"host": host,
			"port": port,
		}).Tracef("Parsing request target host as host-port addr")
		requestTargetHost = host
	}

	// Sometimes an IPv6 host is expressed as "[fe00::]", so drop the brackets.
	if len(requestTargetHost) > 1 &&
		requestTargetHost[0] == '[' &&
		requestTargetHost[len(requestTargetHost)-1] == ']' {
		requestTargetHost = requestTargetHost[1 : len(requestTargetHost)-1]
	}

	// If all else fails we may have some custom target host that we don't know
	// how to parse, e.g. localhost or some DNS name. Unit tests bind to
	// localhost, so the value is compared as-is.
	log.Trace("Interpreting request target host without port")

	// Use len() rather than a nil check so that an empty, non-nil slice also
	// counts as "not specified" instead of rejecting every request.
	desiredTargetHosts := defaultValidTargetHosts
	if len(cv.TargetHosts) > 0 {
		desiredTargetHosts = cv.TargetHosts
	}

	for _, desiredTargetHost := range desiredTargetHosts {
		if desiredTargetHost == requestTargetHost {
			return nil
		}
	}

	return errors.NewAccessDeniedError(
		fmt.Sprintf(
			"Called agent through invalid address, please use either %s address not %s",
			desiredTargetHosts, requestTargetHost))
}