router/pkg/authentication/jwks_token_decoder.go (114 lines of code) (raw):

package authentication

import (
	"context"
	"fmt"
	"net/http"
	"time"

	"github.com/MicahParks/jwkset"
	"github.com/MicahParks/keyfunc/v3"
	"github.com/golang-jwt/jwt/v5"
	"go.uber.org/zap"
	"golang.org/x/time/rate"

	"github.com/wundergraph/cosmo/router/internal/httpclient"
)

// TokenDecoder validates a raw token string and returns the claims it carries.
type TokenDecoder interface {
	Decode(token string) (Claims, error)
}

// jwksTokenDecoder is a TokenDecoder that looks up verification keys through
// a keyfunc.Keyfunc.
type jwksTokenDecoder struct {
	jwks keyfunc.Keyfunc
}

// Decode implements TokenDecoder.
//
// It parses tokenString with jwt.Parse, using the decoder's key function to
// select the verification key. An error is returned when parsing fails, when
// the parsed token is not marked valid, or when the claims are not a
// jwt.MapClaims.
func (j *jwksTokenDecoder) Decode(tokenString string) (Claims, error) {
	token, err := jwt.Parse(tokenString, j.jwks.Keyfunc)
	if err != nil {
		return nil, fmt.Errorf("could not validate token: %w", err)
	}
	if !token.Valid {
		return nil, fmt.Errorf("token is invalid")
	}

	// jwt.Parse is documented to produce jwt.MapClaims. The assertion is still
	// checked so that an unexpected claims type surfaces as an error instead
	// of a panic.
	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return nil, fmt.Errorf("unexpected claims type %T", token.Claims)
	}
	return Claims(claims), nil
}

// JWKSConfig describes a single key source for NewJwksTokenDecoder.
//
// When URL is non-empty the entry is treated as a remote JWK Set and the
// Secret, Algorithm and KeyId fields are not consulted. Otherwise, when Secret
// is non-empty, a single JWK is built from it. Entries with neither field set
// are skipped.
type JWKSConfig struct {
	// URL is the location of a remote JWK Set fetched over HTTP.
	URL string
	// RefreshInterval is passed to the HTTP storage as its refresh interval.
	RefreshInterval time.Duration
	// AllowedAlgorithms is handed to the validation store that backs a remote
	// JWK Set.
	AllowedAlgorithms []string

	// Secret is the raw key material used to build a JWK when URL is empty.
	Secret string
	// Algorithm is recorded as the "alg" metadata of the JWK built from Secret.
	Algorithm string
	// KeyId is recorded as the "kid" metadata of the JWK built from Secret.
	KeyId string
}

// NewJwksTokenDecoder builds a TokenDecoder from the given key sources.
//
// Each config with a URL becomes an HTTP-backed JWK Set storage keyed by that
// URL. Each config with only a Secret becomes a JWK written to an in-memory
// storage. Both kinds are combined into one jwkset HTTP client which is used
// as the storage of the keyfunc that validates tokens.
//
// ctx is passed to the HTTP storages and to keyfunc; cancelling it ends the
// background refresh goroutine of the HTTP storages. logger receives warnings
// and refresh errors.
//
// The first failure while building any key source aborts construction and is
// returned as an error.
func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKSConfig) (TokenDecoder, error) {
	remoteJWKSets := make(map[string]jwkset.Storage)
	given := jwkset.NewMemoryStorage()

	for _, c := range configs {
		switch {
		case c.URL != "":
			store, err := newRemoteJWKSStorage(ctx, logger, c)
			if err != nil {
				return nil, err
			}
			remoteJWKSets[c.URL] = store

		case c.Secret != "":
			jwk, err := newJWKFromSecret(logger, c)
			if err != nil {
				return nil, err
			}
			if err := given.KeyWrite(ctx, jwk); err != nil {
				return nil, fmt.Errorf("failed to write JWK to storage: %w", err)
			}
		}
	}

	jwksetHTTPClientOptions := jwkset.HTTPClientOptions{
		Given:          given,
		HTTPURLs:       remoteJWKSets,
		PrioritizeHTTP: false,
		// The limiter allows one event per five minutes with a burst of one.
		// NOTE(review): presumably this throttles refreshes triggered by
		// tokens carrying an unknown key ID; confirm against jwkset docs.
		RefreshUnknownKID: rate.NewLimiter(rate.Every(5*time.Minute), 1),
	}

	combined, err := jwkset.NewHTTPClient(jwksetHTTPClientOptions)
	if err != nil {
		return nil, fmt.Errorf("failed to create HTTP client storage for JWK provider: %w", err)
	}

	keyfuncOptions := keyfunc.Options{
		Ctx:          ctx,
		Storage:      combined,
		UseWhitelist: []jwkset.USE{jwkset.UseSig},
	}

	jwks, err := keyfunc.New(keyfuncOptions)
	if err != nil {
		return nil, fmt.Errorf("error initializing JWK: %w", err)
	}

	return &jwksTokenDecoder{
		jwks: jwks,
	}, nil
}

// newRemoteJWKSStorage creates the HTTP-backed JWK Set storage for c.URL.
func newRemoteJWKSStorage(ctx context.Context, logger *zap.Logger, c JWKSConfig) (jwkset.Storage, error) {
	l := logger.With(zap.String("url", c.URL))

	jwksetHTTPStorageOptions := jwkset.HTTPClientStorageOptions{
		Client:             newOIDCDiscoveryClient(httpclient.NewRetryableHTTPClient(l)),
		Ctx:                ctx, // Used to end background refresh goroutine.
		HTTPExpectedStatus: http.StatusOK,
		HTTPMethod:         http.MethodGet,
		HTTPTimeout:        15 * time.Second,
		RefreshErrorHandler: func(_ context.Context, err error) {
			l.Error("Failed to refresh HTTP JWK Set from remote HTTP resource.", zap.Error(err))
		},
		RefreshInterval: c.RefreshInterval,
		Storage:         NewValidationStore(logger, nil, c.AllowedAlgorithms),
	}

	store, err := jwkset.NewStorageFromHTTP(c.URL, jwksetHTTPStorageOptions)
	if err != nil {
		return nil, fmt.Errorf("failed to create HTTP client storage for JWK provider: %w", err)
	}
	return store, nil
}

// newJWKFromSecret builds a signature-use JWK from c.Secret, tagged with
// c.Algorithm and c.KeyId.
func newJWKFromSecret(logger *zap.Logger, c JWKSConfig) (jwkset.JWK, error) {
	// A short secret is only warned about, not rejected.
	if len(c.Secret) < 32 {
		logger.Warn("Using a short secret for JWKs may lead to weak security. Consider using a longer secret.")
	}

	alg := jwkset.ALG(c.Algorithm)
	if !alg.IANARegistered() {
		return jwkset.JWK{}, fmt.Errorf("unsupported algorithm: %s", c.Algorithm)
	}

	jwkOptions := jwkset.JWKOptions{
		Marshal: jwkset.JWKMarshalOptions{
			Private: true,
		},
		Metadata: jwkset.JWKMetadataOptions{
			ALG: alg,
			KID: c.KeyId,
			USE: jwkset.UseSig,
		},
	}

	jwk, err := jwkset.NewJWKFromKey([]byte(c.Secret), jwkOptions)
	if err != nil {
		return jwkset.JWK{}, fmt.Errorf("failed to create JWK from secret: %w", err)
	}
	return jwk, nil
}