router/pkg/authentication/jwks_token_decoder.go (84 lines of code) (raw):

package authentication

import (
	"context"
	"fmt"
	"net/http"
	"net/url"
	"time"

	"github.com/MicahParks/jwkset"
	"github.com/MicahParks/keyfunc/v3"
	"github.com/golang-jwt/jwt/v5"
	"go.uber.org/zap"
	"golang.org/x/time/rate"

	"github.com/wundergraph/cosmo/router/internal/httpclient"
)

// TokenDecoder turns a raw token string into its Claims, or reports why the
// token could not be accepted.
type TokenDecoder interface {
	Decode(token string) (Claims, error)
}

// jwksTokenDecoder is a TokenDecoder that resolves verification keys through
// a keyfunc.Keyfunc.
type jwksTokenDecoder struct {
	jwks keyfunc.Keyfunc
}

// Compile-time check that jwksTokenDecoder satisfies TokenDecoder.
var _ TokenDecoder = (*jwksTokenDecoder)(nil)

// Decode implements TokenDecoder.
//
// The token is parsed with jwt.Parse using the decoder's key function. An
// error is returned when parsing fails, when the parsed token is not marked
// valid, or when the parsed claims are not jwt.MapClaims.
func (j *jwksTokenDecoder) Decode(tokenString string) (Claims, error) {
	token, err := jwt.Parse(tokenString, j.jwks.Keyfunc)
	if err != nil {
		return nil, fmt.Errorf("could not validate token: %w", err)
	}
	if !token.Valid {
		return nil, fmt.Errorf("token is invalid")
	}

	// Use the comma-ok form so an unexpected claims implementation becomes an
	// error for the caller instead of a panic.
	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return nil, fmt.Errorf("unexpected claims type %T", token.Claims)
	}
	return Claims(claims), nil
}

// JWKSConfig describes one remote JWK Set that NewJwksTokenDecoder loads.
type JWKSConfig struct {
	// URL is the location of the JWK Set; it must be accepted by
	// url.ParseRequestURI.
	URL string
	// RefreshInterval is handed to the jwkset HTTP storage as its refresh
	// interval.
	RefreshInterval time.Duration
	// AllowedAlgorithms is handed to NewValidationStore.
	// NOTE(review): presumably restricts which key algorithms are accepted;
	// confirm against NewValidationStore.
	AllowedAlgorithms []string
}

// NewJwksTokenDecoder builds a TokenDecoder backed by the remote JWK Sets
// listed in configs.
//
// For every config an HTTP-backed jwkset storage is created; all storages are
// then combined into a single jwkset HTTP client that feeds the key function
// used by Decode. ctx is passed to the underlying libraries, which use it to
// end the background refresh goroutine.
//
// An error is returned if a URL cannot be parsed or if any of the underlying
// jwkset/keyfunc constructors fails.
func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKSConfig) (TokenDecoder, error) {
	remoteJWKSets := make(map[string]jwkset.Storage, len(configs))

	for _, c := range configs {
		// Logger scoped to this JWKS endpoint; captured by the refresh error
		// handler below.
		l := logger.With(zap.String("url", c.URL))

		ur, err := url.ParseRequestURI(c.URL)
		if err != nil {
			return nil, fmt.Errorf("failed to parse given URL %q: %w", c.URL, err)
		}

		jwksetHTTPStorageOptions := jwkset.HTTPClientStorageOptions{
			Client:             newOIDCDiscoveryClient(httpclient.NewRetryableHTTPClient(l)),
			Ctx:                ctx, // Used to end background refresh goroutine.
			HTTPExpectedStatus: http.StatusOK,
			HTTPMethod:         http.MethodGet,
			HTTPTimeout:        15 * time.Second,
			// The outer ctx is intentionally not shadowed here; the handler
			// only logs the refresh failure.
			RefreshErrorHandler: func(_ context.Context, err error) {
				l.Error("Failed to refresh HTTP JWK Set from remote HTTP resource.", zap.Error(err))
			},
			RefreshInterval: c.RefreshInterval,
			Storage:         NewValidationStore(logger, nil, c.AllowedAlgorithms),
		}

		store, err := jwkset.NewStorageFromHTTP(ur, jwksetHTTPStorageOptions)
		if err != nil {
			return nil, fmt.Errorf("failed to create HTTP client storage for JWK provider: %w", err)
		}

		// Keyed by the parsed URL's string form.
		remoteJWKSets[ur.String()] = store
	}

	jwksetHTTPClientOptions := jwkset.HTTPClientOptions{
		HTTPURLs:       remoteJWKSets,
		PrioritizeHTTP: false,
		// Rate limiter for refreshes triggered by unknown key IDs: one event
		// every five minutes with a burst of one.
		RefreshUnknownKID: rate.NewLimiter(rate.Every(5*time.Minute), 1),
	}

	combined, err := jwkset.NewHTTPClient(jwksetHTTPClientOptions)
	if err != nil {
		// Distinct from the per-provider storage error above so the failing
		// step can be identified from the message.
		return nil, fmt.Errorf("failed to create combined HTTP client for JWK providers: %w", err)
	}

	keyfuncOptions := keyfunc.Options{
		Ctx:          ctx,
		Storage:      combined,
		UseWhitelist: []jwkset.USE{jwkset.UseSig},
	}

	jwks, err := keyfunc.New(keyfuncOptions)
	if err != nil {
		return nil, fmt.Errorf("error initializing JWK: %w", err)
	}

	return &jwksTokenDecoder{
		jwks: jwks,
	}, nil
}