router/core/graph_server.go (1,429 lines of code) (raw):

package core import ( "bytes" "context" "crypto/ecdsa" "errors" "fmt" "net/http" "net/url" "path/filepath" "runtime" "strings" "sync" "time" "github.com/cespare/xxhash/v2" "github.com/cloudflare/backoff" "github.com/dgraph-io/ristretto/v2" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/golang-jwt/jwt/v5" "github.com/klauspost/compress/gzhttp" "github.com/klauspost/compress/gzip" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" "go.opentelemetry.io/otel/attribute" otelmetric "go.opentelemetry.io/otel/metric" oteltrace "go.opentelemetry.io/otel/trace" "go.uber.org/atomic" "go.uber.org/zap" "go.uber.org/zap/zapcore" "golang.org/x/exp/maps" "golang.org/x/sync/errgroup" "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/common" nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" "github.com/wundergraph/cosmo/router/internal/circuit" "github.com/wundergraph/cosmo/router/internal/expr" rjwt "github.com/wundergraph/cosmo/router/internal/jwt" rmiddleware "github.com/wundergraph/cosmo/router/internal/middleware" "github.com/wundergraph/cosmo/router/internal/recoveryhandler" "github.com/wundergraph/cosmo/router/internal/requestlogger" "github.com/wundergraph/cosmo/router/internal/retrytransport" "github.com/wundergraph/cosmo/router/pkg/config" "github.com/wundergraph/cosmo/router/pkg/cors" "github.com/wundergraph/cosmo/router/pkg/execution_config" "github.com/wundergraph/cosmo/router/pkg/grpcconnector" "github.com/wundergraph/cosmo/router/pkg/health" "github.com/wundergraph/cosmo/router/pkg/logging" rmetric "github.com/wundergraph/cosmo/router/pkg/metric" "github.com/wundergraph/cosmo/router/pkg/otel" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/cosmo/router/pkg/statistics" rtrace "github.com/wundergraph/cosmo/router/pkg/trace" ) const ( featureFlagHeader = "X-Feature-Flag" featureFlagCookie = "feature_flag" ) type ( // Server is the public interface of the server. Server interface { HttpServer() *http.Server HealthChecks() health.Checker } // graphServer is the swappable implementation of a Graph instance which is an HTTP mux with middlewares. // Everytime a schema is updated, the old graph server is shutdown and a new graph server is created. // For feature flags, a graphql server has multiple mux and is dynamically switched based on the feature flag header or cookie. // All fields are shared between all feature muxes. On shutdown, all graph instances are shutdown. graphServer struct { *Config context context.Context cancelFunc context.CancelFunc storageProviders *config.StorageProviders engineStats statistics.EngineStatistics playgroundHandler func(http.Handler) http.Handler publicKey *ecdsa.PublicKey baseTransport *http.Transport subgraphTransports map[string]*http.Transport baseOtelAttributes []attribute.KeyValue baseRouterConfigVersion string mux *chi.Mux // inFlightRequests is used to track the number of requests currently being processed // does not include websocket (hijacked) connections inFlightRequests *atomic.Uint64 graphMuxList []*graphMux graphMuxListLock sync.Mutex runtimeMetrics *rmetric.RuntimeMetrics otlpEngineMetrics *rmetric.EngineMetrics prometheusEngineMetrics *rmetric.EngineMetrics connectionMetrics *rmetric.ConnectionMetrics instanceData InstanceData pubSubProviders []datasource.Provider traceDialer *TraceDialer connector *grpcconnector.Connector circuitBreakerManager *circuit.Manager } ) // BuildGraphMuxOptions contains the configuration options for building a graph mux. type BuildGraphMuxOptions struct { FeatureFlagName string RouterConfigVersion string EngineConfig *nodev1.EngineConfiguration ConfigSubgraphs []*nodev1.Subgraph RoutingUrlGroupings map[string]map[string]bool } func (b BuildGraphMuxOptions) IsBaseGraph() bool { return b.FeatureFlagName == "" } // newGraphServer creates a new server instance. func newGraphServer(ctx context.Context, r *Router, routerConfig *nodev1.RouterConfig, proxy ProxyFunc) (*graphServer, error) { /* Older versions of composition will not populate a compatibility version. * Currently, all "old" router execution configurations are compatible as there have been no breaking * changes. * Upon the first breaking change to the execution config, an unpopulated compatibility version will * also be unsupported (and the logic for IsRouterCompatibleWithExecutionConfig will need to be updated). */ if !execution_config.IsRouterCompatibleWithExecutionConfig(r.logger, routerConfig.CompatibilityVersion) { return nil, fmt.Errorf(`the compatibility version "%s" is not compatible with this router version`, routerConfig.CompatibilityVersion) } isConnStoreEnabled := r.Config.metricConfig.OpenTelemetry.ConnectionStats || r.Config.metricConfig.Prometheus.ConnectionStats var traceDialer *TraceDialer if isConnStoreEnabled { traceDialer = NewTraceDialer() } // Base transport baseTransport := newHTTPTransport(r.subgraphTransportOptions.TransportRequestOptions, proxy, traceDialer, "") // Subgraph transports subgraphTransports := map[string]*http.Transport{} for subgraph, subgraphOpts := range r.subgraphTransportOptions.SubgraphMap { subgraphBaseTransport := newHTTPTransport(subgraphOpts, proxy, traceDialer, subgraph) subgraphTransports[subgraph] = subgraphBaseTransport } ctx, cancel := context.WithCancel(ctx) s := &graphServer{ context: ctx, cancelFunc: cancel, Config: &r.Config, engineStats: r.EngineStats, baseTransport: baseTransport, subgraphTransports: subgraphTransports, playgroundHandler: r.playgroundHandler, traceDialer: traceDialer, baseRouterConfigVersion: routerConfig.GetVersion(), inFlightRequests: &atomic.Uint64{}, graphMuxList: make([]*graphMux, 0, 1), instanceData: InstanceData{ HostName: r.hostName, ListenAddress: r.listenAddr, }, storageProviders: &r.storageProviders, } baseOtelAttributes := []attribute.KeyValue{ otel.WgRouterVersion.String(Version), otel.WgRouterClusterName.String(r.clusterName), } if s.graphApiToken != "" { claims, err := rjwt.ExtractFederatedGraphTokenClaims(s.graphApiToken) if err != nil { return nil, err } baseOtelAttributes = append(baseOtelAttributes, otel.WgFederatedGraphID.String(claims.FederatedGraphID)) } s.baseOtelAttributes = baseOtelAttributes baseDefaultMuxAttributes := append([]attribute.KeyValue{otel.WgRouterConfigVersion.String(s.baseRouterConfigVersion)}, baseOtelAttributes...) mapper := newAttributeMapper(!rmetric.IsUsingDefaultCloudExporter(s.metricConfig), s.metricConfig.Attributes) mappedMetricAttributes := mapper.mapAttributes(baseDefaultMuxAttributes) if s.metricConfig.OpenTelemetry.RouterRuntime { // We track runtime metrics with base router config version s.runtimeMetrics = rmetric.NewRuntimeMetrics( s.logger, s.otlpMeterProvider, mappedMetricAttributes, s.processStartTime, ) // Start runtime metrics if err := s.runtimeMetrics.Start(); err != nil { return nil, err } } if isConnStoreEnabled { connStore, err := rmetric.NewConnectionMetricStore( s.logger, nil, s.otlpMeterProvider, s.promMeterProvider, s.metricConfig, s.traceDialer.connectionPoolStats, ) if err != nil { return nil, err } s.connectionMetrics = connStore } if err := s.setupEngineStatistics(mappedMetricAttributes); err != nil { return nil, fmt.Errorf("failed to setup engine statistics: %w", err) } if s.registrationInfo != nil { publicKey, err := jwt.ParseECPublicKeyFromPEM([]byte(s.registrationInfo.GetGraphPublicKey())) if err != nil { return nil, fmt.Errorf("failed to parse router public key: %w", err) } s.publicKey = publicKey } httpRouter := chi.NewRouter() /** * Middlewares */ // This recovery handler is used for everything before the graph mux to ensure that // we can recover from panics and log them properly. httpRouter.Use(recoveryhandler.New(recoveryhandler.WithLogHandler(func(w http.ResponseWriter, r *http.Request, err any) { s.logger.Error("[Recovery from panic]", zap.Any("error", err), ) }))) // Request traffic shaping related middlewares httpRouter.Use(rmiddleware.RequestSize(int64(s.routerTrafficConfig.MaxRequestBodyBytes))) if s.routerTrafficConfig.DecompressionEnabled { httpRouter.Use(rmiddleware.HandleCompression(s.logger)) } httpRouter.Use(middleware.RequestID) httpRouter.Use(middleware.RealIP) if s.corsOptions.Enabled { httpRouter.Use(cors.New(*s.corsOptions)) } if s.subgraphCircuitBreakerOptions.IsEnabled() { manager, err := circuit.NewManager(s.subgraphCircuitBreakerOptions.CircuitBreaker) if err != nil { return nil, err } s.circuitBreakerManager = manager } routingUrlGroupings, err := getRoutingUrlGroupingForCircuitBreakers(routerConfig, s.overrideRoutingURLConfiguration, s.overrides) if err != nil { return nil, err } gm, err := s.buildGraphMux(ctx, BuildGraphMuxOptions{ RouterConfigVersion: s.baseRouterConfigVersion, EngineConfig: routerConfig.GetEngineConfig(), ConfigSubgraphs: routerConfig.GetSubgraphs(), RoutingUrlGroupings: routingUrlGroupings, }) if err != nil { return nil, fmt.Errorf("failed to build base mux: %w", err) } featureFlagConfigMap := routerConfig.FeatureFlagConfigs.GetConfigByFeatureFlagName() if len(featureFlagConfigMap) > 0 { s.logger.Info("Feature flags enabled", zap.Strings("flags", maps.Keys(featureFlagConfigMap))) } multiGraphHandler, err := s.buildMultiGraphHandler(ctx, gm.mux, featureFlagConfigMap) if err != nil { return nil, fmt.Errorf("failed to build feature flag handler: %w", err) } wrapper, err := gzhttp.NewWrapper( gzhttp.MinSize(1024*4), // 4KB gzhttp.CompressionLevel(gzip.DefaultCompression), gzhttp.ContentTypes(CompressibleContentTypes), ) if err != nil { return nil, fmt.Errorf("failed to create gzip wrapper: %w", err) } if s.traceConfig.Enabled { handler := rtrace.NewTracingHandler(rtrace.TracingHandlerOpts{ TraceConfig: s.traceConfig, HealthCheckPath: s.healthCheckPath, ReadinessCheckPath: s.readinessCheckPath, LivenessCheckPath: s.livenessCheckPath, CompositePropagator: s.compositePropagator, TracerProvider: s.tracerProvider, SpanNameFormatter: SpanNameFormatter, }) httpRouter.Use(handler) } if s.batchingConfig.Enabled { if s.batchingConfig.MaxConcurrentRoutines <= 0 { return nil, errors.New("maxConcurrent must be greater than 0") } if s.batchingConfig.MaxEntriesPerBatch <= 0 { return nil, errors.New("maxEntriesPerBatch must be greater than 0") } } /** * A group where we can selectively apply middlewares to the graphql endpoint */ httpRouter.Group(func(cr chi.Router) { // We are applying it conditionally because compressing 3MB playground is still slow even with stdlib gzip cr.Use(func(h http.Handler) http.Handler { return wrapper(h) }) if s.headerRules != nil { cr.Use(rmiddleware.CookieWhitelist(s.headerRules.CookieWhitelist, []string{featureFlagCookie})) } // Mount the feature flag handler. It calls the base mux if no feature flag is set. if s.batchingConfig.Enabled { handler := Handler( HandlerOpts{ MaxEntriesPerBatch: s.batchingConfig.MaxEntriesPerBatch, MaxRoutines: s.batchingConfig.MaxConcurrentRoutines, OmitExtensions: s.batchingConfig.OmitExtensions, HandlerSent: multiGraphHandler, Tracer: r.tracerProvider.Tracer( "wundergraph/cosmo/router/internal/batch", oteltrace.WithInstrumentationVersion("0.0.1"), ), Digest: xxhash.New(), ClientHeader: s.clientHeader, BaseOtelAttributes: s.baseOtelAttributes, RouterConfigVersion: s.baseRouterConfigVersion, Logger: s.logger, }, ) cr.Handle(r.graphqlPath, handler) } else { cr.Handle(r.graphqlPath, multiGraphHandler) } if r.webSocketConfiguration != nil && r.webSocketConfiguration.Enabled && r.webSocketConfiguration.AbsintheProtocol.Enabled { // Mount the Absinthe protocol handler for WebSockets httpRouter.Handle(r.webSocketConfiguration.AbsintheProtocol.HandlerPath, multiGraphHandler) } }) /** * Routes */ // We mount the playground once here when we don't have a conflict with the websocket handler // If we have a conflict, we mount the playground during building the individual muxes if s.playgroundHandler != nil && s.graphqlPath != s.playgroundConfig.Path { httpRouter.Get(r.playgroundConfig.Path, s.playgroundHandler(nil).ServeHTTP) } httpRouter.Get(s.healthCheckPath, r.healthcheck.Liveness()) httpRouter.Get(s.livenessCheckPath, r.healthcheck.Liveness()) httpRouter.Get(s.readinessCheckPath, r.healthcheck.Readiness()) s.mux = httpRouter return s, nil } func getRoutingUrlGroupingForCircuitBreakers( routerConfig *nodev1.RouterConfig, overrideRoutingURLConfiguration config.OverrideRoutingURLConfiguration, overridesConfiguration config.OverridesConfiguration, ) (map[string]map[string]bool, error) { routingUrlGroupings := make(map[string]map[string]bool) overwrites, err := configureSubgraphOverwrites( routerConfig.GetEngineConfig(), routerConfig.GetSubgraphs(), overrideRoutingURLConfiguration, overridesConfiguration, true, ) if err != nil { return nil, err } for _, subgraph := range overwrites { if _, ok := routingUrlGroupings[subgraph.UrlString]; !ok { routingUrlGroupings[subgraph.UrlString] = make(map[string]bool) } routingUrlGroupings[subgraph.UrlString][subgraph.Name] = true } if routerConfig.FeatureFlagConfigs != nil { for _, ffConfig := range routerConfig.FeatureFlagConfigs.ConfigByFeatureFlagName { ffOverwrites, err := configureSubgraphOverwrites( ffConfig.GetEngineConfig(), ffConfig.GetSubgraphs(), overrideRoutingURLConfiguration, overridesConfiguration, true, ) if err != nil { return nil, err } for _, subgraph := range ffOverwrites { if _, ok := routingUrlGroupings[subgraph.UrlString]; !ok { routingUrlGroupings[subgraph.UrlString] = make(map[string]bool) } routingUrlGroupings[subgraph.UrlString][subgraph.Name] = true } } } return routingUrlGroupings, nil } func (s *graphServer) buildMultiGraphHandler( ctx context.Context, baseMux *chi.Mux, featureFlagConfigs map[string]*nodev1.FeatureFlagRouterExecutionConfig, ) (http.HandlerFunc, error) { if len(featureFlagConfigs) == 0 { return baseMux.ServeHTTP, nil } featureFlagToMux := make(map[string]*chi.Mux, len(featureFlagConfigs)) // Build all the muxes for the feature flags in serial to avoid any race conditions for featureFlagName, executionConfig := range featureFlagConfigs { gm, err := s.buildGraphMux(ctx, BuildGraphMuxOptions{ FeatureFlagName: featureFlagName, RouterConfigVersion: executionConfig.GetVersion(), EngineConfig: executionConfig.GetEngineConfig(), ConfigSubgraphs: executionConfig.Subgraphs, }) if err != nil { return nil, fmt.Errorf("failed to build mux for feature flag '%s': %w", featureFlagName, err) } featureFlagToMux[featureFlagName] = gm.mux } return func(w http.ResponseWriter, r *http.Request) { // Extract the feature flag and run the corresponding mux // 1. From the request header // 2. From the cookie ff := strings.TrimSpace(r.Header.Get(featureFlagHeader)) if ff == "" { cookie, err := r.Cookie(featureFlagCookie) if err == nil && cookie != nil { ff = strings.TrimSpace(cookie.Value) } } if mux, ok := featureFlagToMux[ff]; ok { w.Header().Set(featureFlagHeader, ff) mux.ServeHTTP(w, r) return } baseMux.ServeHTTP(w, r) }, nil } // setupEngineStatistics creates the engine statistics for the server. // It creates the OTLP and Prometheus metrics for the engine statistics. func (s *graphServer) setupEngineStatistics(baseAttributes []attribute.KeyValue) (err error) { // We only include the base router config version in the attributes for the engine statistics. // Same approach is used for the runtime metrics. s.otlpEngineMetrics, err = rmetric.NewEngineMetrics( s.logger, baseAttributes, s.otlpMeterProvider, s.engineStats, &s.metricConfig.OpenTelemetry.EngineStats, ) if err != nil { return err } s.prometheusEngineMetrics, err = rmetric.NewEngineMetrics( s.logger, baseAttributes, s.promMeterProvider, s.engineStats, &s.metricConfig.Prometheus.EngineStats, ) if err != nil { return err } return nil } type graphMux struct { mux *chi.Mux planCache *ristretto.Cache[uint64, *planWithMetaData] persistedOperationCache *ristretto.Cache[uint64, NormalizationCacheEntry] normalizationCache *ristretto.Cache[uint64, NormalizationCacheEntry] complexityCalculationCache *ristretto.Cache[uint64, ComplexityCacheEntry] validationCache *ristretto.Cache[uint64, bool] operationHashCache *ristretto.Cache[uint64, string] accessLogsFileLogger *logging.BufferedLogger metricStore rmetric.Store prometheusCacheMetrics *rmetric.CacheMetrics otelCacheMetrics *rmetric.CacheMetrics } // buildOperationCaches creates the caches for the graph mux. // The caches are created based on the engine configuration. func (s *graphMux) buildOperationCaches(srv *graphServer) (computeSha256 bool, err error) { // We create a new execution plan cache for each operation planner which is coupled to // the specific engine configuration. This is necessary because otherwise we would return invalid plans. // // when an execution plan was generated, which can be quite expensive, we want to cache it // this means that we can hash the input and cache the generated plan // the next time we get the same input, we can just return the cached plan // the engine is smart enough to first do normalization and then hash the input // this means that we can cache the normalized input and don't have to worry about // different inputs that would generate the same execution plan if srv.engineExecutionConfiguration.ExecutionPlanCacheSize > 0 { planCacheConfig := &ristretto.Config[uint64, *planWithMetaData]{ Metrics: srv.metricConfig.OpenTelemetry.GraphqlCache || srv.metricConfig.Prometheus.GraphqlCache, MaxCost: srv.engineExecutionConfiguration.ExecutionPlanCacheSize, NumCounters: srv.engineExecutionConfiguration.ExecutionPlanCacheSize * 10, IgnoreInternalCost: true, BufferItems: 64, } s.planCache, err = ristretto.NewCache[uint64, *planWithMetaData](planCacheConfig) if err != nil { return computeSha256, fmt.Errorf("failed to create planner cache: %w", err) } } if srv.engineExecutionConfiguration.EnablePersistedOperationsCache || srv.automaticPersistedQueriesConfig.Enabled { cacheSize := int64(1024) persistedOperationCacheConfig := &ristretto.Config[uint64, NormalizationCacheEntry]{ MaxCost: cacheSize, NumCounters: cacheSize * 10, IgnoreInternalCost: true, BufferItems: 64, Metrics: true, } s.persistedOperationCache, _ = ristretto.NewCache[uint64, NormalizationCacheEntry](persistedOperationCacheConfig) } if srv.engineExecutionConfiguration.EnableNormalizationCache && srv.engineExecutionConfiguration.NormalizationCacheSize > 0 { normalizationCacheConfig := &ristretto.Config[uint64, NormalizationCacheEntry]{ Metrics: srv.metricConfig.OpenTelemetry.GraphqlCache || srv.metricConfig.Prometheus.GraphqlCache, MaxCost: srv.engineExecutionConfiguration.NormalizationCacheSize, NumCounters: srv.engineExecutionConfiguration.NormalizationCacheSize * 10, IgnoreInternalCost: true, BufferItems: 64, } s.normalizationCache, err = ristretto.NewCache[uint64, NormalizationCacheEntry](normalizationCacheConfig) if err != nil { return computeSha256, fmt.Errorf("failed to create normalization cache: %w", err) } } if srv.engineExecutionConfiguration.EnableValidationCache && srv.engineExecutionConfiguration.ValidationCacheSize > 0 { validationCacheConfig := &ristretto.Config[uint64, bool]{ Metrics: srv.metricConfig.OpenTelemetry.GraphqlCache || srv.metricConfig.Prometheus.GraphqlCache, MaxCost: srv.engineExecutionConfiguration.ValidationCacheSize, NumCounters: srv.engineExecutionConfiguration.ValidationCacheSize * 10, IgnoreInternalCost: true, BufferItems: 64, } s.validationCache, err = ristretto.NewCache[uint64, bool](validationCacheConfig) if err != nil { return computeSha256, fmt.Errorf("failed to create validation cache: %w", err) } } if srv.securityConfiguration.ComplexityCalculationCache != nil && srv.securityConfiguration.ComplexityCalculationCache.Enabled && srv.securityConfiguration.ComplexityCalculationCache.CacheSize > 0 { complexityCalculationCacheConfig := &ristretto.Config[uint64, ComplexityCacheEntry]{ Metrics: srv.metricConfig.OpenTelemetry.GraphqlCache || srv.metricConfig.Prometheus.GraphqlCache, MaxCost: srv.securityConfiguration.ComplexityCalculationCache.CacheSize, NumCounters: srv.securityConfiguration.ComplexityCalculationCache.CacheSize * 10, IgnoreInternalCost: true, BufferItems: 64, } s.complexityCalculationCache, err = ristretto.NewCache[uint64, ComplexityCacheEntry](complexityCalculationCacheConfig) if err != nil { return computeSha256, fmt.Errorf("failed to create query depth cache: %w", err) } } // Currently, we only support custom attributes from the context for OTLP metrics if len(srv.metricConfig.Attributes) > 0 { for _, customAttribute := range srv.metricConfig.Attributes { if customAttribute.ValueFrom != nil && customAttribute.ValueFrom.ContextField == ContextFieldOperationSha256 { computeSha256 = true break } } } else if srv.accessLogsConfig != nil { for _, customAttribute := range append(srv.accessLogsConfig.Attributes, srv.accessLogsConfig.SubgraphAttributes...) { if customAttribute.ValueFrom != nil && customAttribute.ValueFrom.ContextField == ContextFieldOperationSha256 { computeSha256 = true break } } } else if srv.persistedOperationsConfig.Safelist.Enabled || srv.persistedOperationsConfig.LogUnknown { // In these case, we'll want to compute the sha256 for every operation, in order to check that the operation // is present in the Persisted Operation cache computeSha256 = true } // Prometheus schema field usage metrics can use sha256, so we need to ensure it is computed if srv.metricConfig.Prometheus.PromSchemaFieldUsage.Enabled && srv.metricConfig.Prometheus.PromSchemaFieldUsage.IncludeOperationSha { computeSha256 = true } if computeSha256 { operationHashCacheConfig := &ristretto.Config[uint64, string]{ MaxCost: srv.engineExecutionConfiguration.OperationHashCacheSize, NumCounters: srv.engineExecutionConfiguration.OperationHashCacheSize * 10, IgnoreInternalCost: true, BufferItems: 64, Metrics: srv.metricConfig.OpenTelemetry.GraphqlCache || srv.metricConfig.Prometheus.GraphqlCache, } s.operationHashCache, err = ristretto.NewCache[uint64, string](operationHashCacheConfig) if err != nil { return computeSha256, fmt.Errorf("failed to create operation hash cache: %w", err) } } return computeSha256, nil } // configureCacheMetrics sets up the cache metrics for this mux if enabled in the config. func (s *graphMux) configureCacheMetrics(srv *graphServer, baseOtelAttributes []attribute.KeyValue) error { if srv.metricConfig.OpenTelemetry.GraphqlCache { cacheMetrics, err := rmetric.NewCacheMetrics( srv.logger, baseOtelAttributes, srv.otlpMeterProvider) if err != nil { return fmt.Errorf("failed to create cache metrics for OTLP: %w", err) } s.otelCacheMetrics = cacheMetrics } if srv.metricConfig.Prometheus.GraphqlCache { cacheMetrics, err := rmetric.NewCacheMetrics( srv.logger, baseOtelAttributes, srv.promMeterProvider) if err != nil { return fmt.Errorf("failed to create cache metrics for Prometheus: %w", err) } s.prometheusCacheMetrics = cacheMetrics } var metricInfos []rmetric.CacheMetricInfo if s.planCache != nil { metricInfos = append(metricInfos, rmetric.NewCacheMetricInfo("plan", srv.engineExecutionConfiguration.ExecutionPlanCacheSize, s.planCache.Metrics)) } if s.normalizationCache != nil { metricInfos = append(metricInfos, rmetric.NewCacheMetricInfo("query_normalization", srv.engineExecutionConfiguration.NormalizationCacheSize, s.normalizationCache.Metrics)) } if s.persistedOperationCache != nil { metricInfos = append(metricInfos, rmetric.NewCacheMetricInfo("persisted_query_normalization", 1024, s.persistedOperationCache.Metrics)) } if s.validationCache != nil { metricInfos = append(metricInfos, rmetric.NewCacheMetricInfo("validation", srv.engineExecutionConfiguration.ValidationCacheSize, s.validationCache.Metrics)) } if s.operationHashCache != nil { metricInfos = append(metricInfos, rmetric.NewCacheMetricInfo("query_hash", srv.engineExecutionConfiguration.OperationHashCacheSize, s.operationHashCache.Metrics)) } if s.otelCacheMetrics != nil { if err := s.otelCacheMetrics.RegisterObservers(metricInfos); err != nil { return fmt.Errorf("failed to register observer for OTLP cache metrics: %w", err) } } if s.prometheusCacheMetrics != nil { if err := s.prometheusCacheMetrics.RegisterObservers(metricInfos); err != nil { return fmt.Errorf("failed to register observer for Prometheus cache metrics: %w", err) } } return nil } func (s *graphMux) Shutdown(ctx context.Context) error { var err error if s.planCache != nil { s.planCache.Close() } if s.persistedOperationCache != nil { s.persistedOperationCache.Close() } if s.normalizationCache != nil { s.normalizationCache.Close() } if s.complexityCalculationCache != nil { s.complexityCalculationCache.Close() } if s.validationCache != nil { s.validationCache.Close() } if s.operationHashCache != nil { s.operationHashCache.Close() } if s.accessLogsFileLogger != nil { if aErr := s.accessLogsFileLogger.Close(); aErr != nil { err = errors.Join(err, aErr) } } if s.otelCacheMetrics != nil { if aErr := s.otelCacheMetrics.Shutdown(); aErr != nil { err = errors.Join(err, aErr) } } if s.prometheusCacheMetrics != nil { if aErr := s.prometheusCacheMetrics.Shutdown(); aErr != nil { err = errors.Join(err, aErr) } } if s.metricStore != nil { if aErr := s.metricStore.Shutdown(ctx); aErr != nil { err = errors.Join(err, aErr) } } if err != nil { return fmt.Errorf("shutdown graph mux: %w", err) } return nil } // buildGraphMux creates a new graph mux with the given feature flags and engine configuration. // It also creates a new execution plan cache for the mux. The mux is not mounted on the server. // The mux is appended internally to the graph server's list of muxes to clean up later when the server is swapped. func (s *graphServer) buildGraphMux( ctx context.Context, opts BuildGraphMuxOptions, ) (*graphMux, error) { gm := &graphMux{ metricStore: rmetric.NewNoopMetrics(), } httpRouter := chi.NewRouter() // we only enable the attribute mapper if we are not using the default cloud exporter baseMuxAttributes := append([]attribute.KeyValue{otel.WgRouterConfigVersion.String(opts.RouterConfigVersion)}, s.baseOtelAttributes...) if !opts.IsBaseGraph() { baseMuxAttributes = append(baseMuxAttributes, otel.WgFeatureFlag.String(opts.FeatureFlagName)) } metricsEnabled := s.metricConfig.IsEnabled() exprManager := expr.CreateNewExprManager() // We might want to remap or exclude known attributes based on the configuration for metrics mapper := newAttributeMapper(!rmetric.IsUsingDefaultCloudExporter(s.metricConfig), s.metricConfig.Attributes) baseMetricAttributes := mapper.mapAttributes(baseMuxAttributes) metricAttExpressions, attErr := newAttributeExpressions(s.metricConfig.Attributes, exprManager) if attErr != nil { return nil, attErr } var telemetryAttExpressions *attributeExpressions if len(s.telemetryAttributes) > 0 { var telemetryAttErr error telemetryAttExpressions, telemetryAttErr = newAttributeExpressions(s.telemetryAttributes, exprManager) if telemetryAttErr != nil { return nil, telemetryAttErr } } var tracingAttExpressions *attributeExpressions if len(s.tracingAttributes) > 0 { var tracingAttrErr error tracingAttExpressions, tracingAttrErr = newAttributeExpressions(s.tracingAttributes, exprManager) if tracingAttrErr != nil { return nil, tracingAttrErr } } // Prometheus metricStore rely on OTLP metricStore if metricsEnabled { attrKeyValues := []attribute.KeyValue{ otel.WgRouterConfigVersion.String(opts.RouterConfigVersion), otel.WgRouterVersion.String(Version), } if !opts.IsBaseGraph() { attrKeyValues = append(attrKeyValues, otel.WgFeatureFlag.String(opts.FeatureFlagName)) } routerInfoBaseAttrs := otelmetric.WithAttributeSet(attribute.NewSet(attrKeyValues...)) // From a users perspective this is similar to engine metrics, etc // but in this case we use the same metric store otlpOpts := rmetric.MetricOpts{ EnableCircuitBreaker: s.metricConfig.OpenTelemetry.Enabled, } promOpts := rmetric.MetricOpts{ EnableCircuitBreaker: s.metricConfig.Prometheus.Enabled, } m, err := rmetric.NewStore(otlpOpts, promOpts, rmetric.WithPromMeterProvider(s.promMeterProvider), rmetric.WithOtlpMeterProvider(s.otlpMeterProvider), rmetric.WithBaseAttributes(baseMetricAttributes), rmetric.WithLogger(s.logger), rmetric.WithProcessStartTime(s.processStartTime), rmetric.WithCardinalityLimit(s.metricConfig.CardinalityLimit), rmetric.WithRouterInfoAttributes(routerInfoBaseAttrs), ) if err != nil { return nil, fmt.Errorf("failed to create metric handler: %w", err) } gm.metricStore = m } // We initialize circuit breakers for all subgraphs in the base configuration (non-ff) // so we don't duplicate circuit breakers for subgraphs and they can be used in the feature flags even // We initialize it in the buildGraphMux because we want to use the base metric configuration if opts.IsBaseGraph() && s.subgraphCircuitBreakerOptions.IsEnabled() { // If either otel or prom metrics are enabled for circuit breakers // we will enable circuit breaker metric collections isCircuitBreakerMetricsEnabled := s.metricConfig.OpenTelemetry.CircuitBreaker || s.metricConfig.Prometheus.CircuitBreaker err := s.circuitBreakerManager.Initialize(circuit.ManagerOpts{ SubgraphCircuitBreakers: s.subgraphCircuitBreakerOptions.SubgraphMap, MetricStore: gm.metricStore, UseMetrics: metricsEnabled && isCircuitBreakerMetricsEnabled, BaseOtelAttributes: baseMetricAttributes, AllGroupings: opts.RoutingUrlGroupings, }) if err != nil { return nil, err } } subgraphs, err := configureSubgraphOverwrites( opts.EngineConfig, opts.ConfigSubgraphs, s.overrideRoutingURLConfiguration, s.overrides, false, ) if err != nil { return nil, err } computeSha256, err := gm.buildOperationCaches(s) if err != nil { return nil, err } if err = gm.configureCacheMetrics(s, baseMetricAttributes); err != nil { return nil, err } metrics := NewRouterMetrics(&routerMetricsConfig{ metrics: gm.metricStore, gqlMetricsExporter: s.gqlMetricsExporter, exportEnabled: s.graphqlMetricsConfig.Enabled, routerConfigVersion: opts.RouterConfigVersion, logger: s.logger, promSchemaUsageEnabled: s.metricConfig.Prometheus.PromSchemaFieldUsage.Enabled, promSchemaUsageIncludeOperationSha: s.metricConfig.Prometheus.PromSchemaFieldUsage.IncludeOperationSha, }) baseLogFields := []zapcore.Field{ zap.String("config_version", opts.RouterConfigVersion), } if !opts.IsBaseGraph() { baseLogFields = append(baseLogFields, zap.String("feature_flag", opts.FeatureFlagName)) } // Currently, we only support custom attributes from the context for OTLP metrics b := buildAttributesMap(s.metricConfig.Attributes) // Enrich the request context with the subgraph information which is required for custom modules and tracing subgraphResolver := NewSubgraphResolver(subgraphs) httpRouter.Use(func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r = r.WithContext(withSubgraphResolver(r.Context(), subgraphResolver)) requestLogger := s.logger.With(logging.WithRequestID(middleware.GetReqID(r.Context()))) // If this is a batched request attach id to the logger if batchedOperationId, ok := r.Context().Value(BatchedOperationId{}).(string); ok { requestLogger = requestLogger.With(logging.WithBatchedRequestOperationID(batchedOperationId)) } reqContext := buildRequestContext(requestContextOptions{ operationContext: nil, requestLogger: requestLogger, metricSetAttributes: b, metricsEnabled: metricsEnabled, traceEnabled: s.traceConfig.Enabled, mapper: mapper, metricAttributeExpressions: metricAttExpressions, telemetryAttributeExpressions: telemetryAttExpressions, tracingAttributeExpressions: tracingAttExpressions, w: w, r: r, }) r = r.WithContext(withRequestContext(r.Context(), reqContext)) // For debugging purposes, we can validate from what version of the config the request is coming from if s.setConfigVersionHeader { w.Header().Set("X-Router-Config-Version", opts.RouterConfigVersion) } h.ServeHTTP(w, r) }) }) var recoverOpts []recoveryhandler.Option // If we have no access logger configured, we log the panic in the recovery handler to avoid losing the panic information if s.accessLogsConfig == nil { recoverOpts = append(recoverOpts, recoveryhandler.WithLogHandler(func(w http.ResponseWriter, r *http.Request, err any) { reqContext := getRequestContext(r.Context()) if reqContext != nil { reqContext.logger.Error("[Recovery from panic]", zap.Any("error", err), ) } })) } recoveryHandler := recoveryhandler.New(recoverOpts...) httpRouter.Use(recoveryHandler) // Setup any router on request middlewares so that they can be used to manipulate // other downstream internal middlewares such as tracing or authentication httpRouter.Use(s.routerOnRequestHandlers...) /** * Initialize base attributes from headers and other sources */ var commonAttrRequestMapper func(r *http.Request) []attribute.KeyValue if len(s.telemetryAttributes) > 0 { // Common attributes across traces and metrics commonAttrRequestMapper = buildHeaderAttributesMapper(s.telemetryAttributes) } var tracingAttrRequestMapper func(r *http.Request) []attribute.KeyValue if len(s.tracingAttributes) > 0 { tracingAttrRequestMapper = buildHeaderAttributesMapper(s.tracingAttributes) } var metricAttrRequestMapper func(r *http.Request) []attribute.KeyValue if s.metricConfig.IsEnabled() { // Metric attributes are only used for OTLP metrics and Prometheus metrics metricAttrRequestMapper = buildHeaderAttributesMapper(s.metricConfig.Attributes) } httpRouter.Use(func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { reqContext := getRequestContext(r.Context()) reqContext.telemetry.addCommonTraceAttribute(baseMuxAttributes...) if commonAttrRequestMapper != nil { reqContext.telemetry.addCommonAttribute(commonAttrRequestMapper(r)...) } if tracingAttrRequestMapper != nil { reqContext.telemetry.addCommonTraceAttribute(tracingAttrRequestMapper(r)...) } if metricAttrRequestMapper != nil { reqContext.telemetry.addMetricAttribute(metricAttrRequestMapper(r)...) } h.ServeHTTP(w, r) }) }) if s.traceConfig.Enabled { f := func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { reqContext := getRequestContext(r.Context()) traceID := rtrace.GetTraceID(r.Context()) requestLogger := reqContext.Logger().With(logging.WithTraceID(traceID)) reqContext.logger = requestLogger span := oteltrace.SpanFromContext(r.Context()) span.SetAttributes(reqContext.telemetry.traceAttrs...) // Set if the trace is sampled in the expression context isSampled := span.SpanContext().IsSampled() reqContext.expressionContext.Request.Trace.Sampled = isSampled // Set the trace ID in the response header if s.traceConfig.ResponseTraceHeader.Enabled { w.Header().Set(s.traceConfig.ResponseTraceHeader.HeaderName, traceID) } h.ServeHTTP(w, r) }) } httpRouter.Use(f) } var subgraphAccessLogger *requestlogger.SubgraphAccessLogger if s.accessLogsConfig != nil && s.accessLogsConfig.Logger != nil { exprAttributes, err := requestlogger.GetAccessLogConfigExpressions(s.accessLogsConfig.Attributes, exprManager) if err != nil { return nil, fmt.Errorf("failed building router access log expressions: %w", err) } accessLogAttributes := requestlogger.CleanupExpressionAttributes(s.accessLogsConfig.Attributes) requestLoggerOpts := []requestlogger.Option{ requestlogger.WithDefaultOptions(), requestlogger.WithNoTimeField(), requestlogger.WithFields(baseLogFields...), requestlogger.WithAttributes(accessLogAttributes), requestlogger.WithExprAttributes(exprAttributes), requestlogger.WithFieldsHandler(RouterAccessLogsFieldHandler), } var ipAnonConfig *requestlogger.IPAnonymizationConfig if s.ipAnonymization.Enabled { ipAnonConfig = &requestlogger.IPAnonymizationConfig{ Enabled: s.ipAnonymization.Enabled, Method: requestlogger.IPAnonymizationMethod(s.ipAnonymization.Method), } requestLoggerOpts = append(requestLoggerOpts, requestlogger.WithAnonymization(ipAnonConfig)) } requestLogger := requestlogger.New( s.accessLogsConfig.Logger, requestLoggerOpts..., ) httpRouter.Use(requestLogger) if s.accessLogsConfig.SubgraphEnabled { subgraphExprAttributes, err := requestlogger.GetAccessLogConfigExpressions(s.accessLogsConfig.SubgraphAttributes, exprManager) if err != nil { return nil, fmt.Errorf("failed building router access log expressions: %w", err) } subgraphAttributes := requestlogger.CleanupExpressionAttributes(s.accessLogsConfig.SubgraphAttributes) subgraphAccessLogger = requestlogger.NewSubgraphAccessLogger( s.accessLogsConfig.Logger, requestlogger.SubgraphOptions{ IPAnonymizationConfig: ipAnonConfig, FieldsHandler: SubgraphAccessLogsFieldHandler, Fields: baseLogFields, Attributes: subgraphAttributes, ExprAttributes: subgraphExprAttributes, }) } if exprManager.VisitorManager.IsResponseBodyUsedInExpressions() { httpRouter.Use(func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { buf := bytes.Buffer{} ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) ww.Tee(&buf) h.ServeHTTP(ww, r) reqContext := getRequestContext(r.Context()) reqContext.expressionContext.Response.Body.Raw = buf.String() }) }) } } routerEngineConfig := &RouterEngineConfiguration{ Execution: s.engineExecutionConfiguration, Headers: s.headerRules, Events: s.eventsConfig, SubgraphErrorPropagation: s.subgraphErrorPropagation, } // map[string]*http.Transport cannot be coerced into map[string]http.RoundTripper, unfortunately subgraphTippers := map[string]http.RoundTripper{} for subgraph, subgraphTransport := range s.subgraphTransports { subgraphTippers[subgraph] = subgraphTransport } if err := s.setupConnector(ctx, opts.EngineConfig, opts.ConfigSubgraphs); err != nil { return nil, fmt.Errorf("failed to setup plugin host: %w", err) } enableTraceClient := s.connectionMetrics != nil || exprManager.VisitorManager.IsSubgraphTraceUsedInExpressions() var baseConnMetricStore rmetric.ConnectionMetricStore = &rmetric.NoopConnectionMetricStore{} if s.connectionMetrics != nil { baseConnMetricStore = s.connectionMetrics } ecb := &ExecutorConfigurationBuilder{ introspection: s.introspection, baseURL: s.baseURL, baseTripper: s.baseTransport, subgraphTrippers: subgraphTippers, pluginHost: s.connector, logger: s.logger, trackUsageInfo: s.graphqlMetricsConfig.Enabled || s.metricConfig.Prometheus.PromSchemaFieldUsage.Enabled, subscriptionClientOptions: &SubscriptionClientOptions{ PingInterval: s.engineExecutionConfiguration.WebSocketClientPingInterval, PingTimeout: s.engineExecutionConfiguration.WebSocketClientPingTimeout, ReadTimeout: s.engineExecutionConfiguration.WebSocketClientReadTimeout, FrameTimeout: s.engineExecutionConfiguration.WebSocketClientFrameTimeout, }, transportOptions: &TransportOptions{ SubgraphTransportOptions: s.subgraphTransportOptions, PreHandlers: s.preOriginHandlers, PostHandlers: s.postOriginHandlers, MetricStore: gm.metricStore, ConnectionMetricStore: baseConnMetricStore, RetryOptions: retrytransport.RetryOptions{ Enabled: s.retryOptions.Enabled, MaxRetryCount: s.retryOptions.MaxRetryCount, MaxDuration: s.retryOptions.MaxDuration, Interval: s.retryOptions.Interval, ShouldRetry: func(err error, req *http.Request, resp *http.Response) bool { return retrytransport.IsRetryableError(err, resp) && !isMutationRequest(req.Context()) }, }, TracerProvider: s.tracerProvider, TracePropagators: s.compositePropagator, LocalhostFallbackInsideDocker: s.localhostFallbackInsideDocker, Logger: s.logger, EnableTraceClient: enableTraceClient, CircuitBreaker: s.circuitBreakerManager, }, } executor, providers, err := ecb.Build( ctx, &ExecutorBuildOptions{ EngineConfig: opts.EngineConfig, Subgraphs: opts.ConfigSubgraphs, RouterEngineConfig: routerEngineConfig, Reporter: s.engineStats, ApolloCompatibilityFlags: s.apolloCompatibilityFlags, ApolloRouterCompatibilityFlags: s.apolloRouterCompatibilityFlags, HeartbeatInterval: s.multipartHeartbeatInterval, PluginsEnabled: s.plugins.Enabled, InstanceData: s.instanceData, }, ) if err != nil { return nil, fmt.Errorf("failed to build plan configuration: %w", err) } s.pubSubProviders = providers if pubSubStartupErr := s.startupPubSubProviders(ctx); pubSubStartupErr != nil { return nil, pubSubStartupErr } operationProcessor := NewOperationProcessor(OperationProcessorOptions{ Executor: executor, MaxOperationSizeInBytes: int64(s.routerTrafficConfig.MaxRequestBodyBytes), PersistedOperationClient: s.persistedOperationClient, AutomaticPersistedOperationCacheTtl: s.automaticPersistedQueriesConfig.Cache.TTL, EnablePersistedOperationsCache: s.engineExecutionConfiguration.EnablePersistedOperationsCache, PersistedOpsNormalizationCache: gm.persistedOperationCache, NormalizationCache: gm.normalizationCache, ValidationCache: gm.validationCache, QueryDepthCache: gm.complexityCalculationCache, OperationHashCache: gm.operationHashCache, ParseKitPoolSize: s.engineExecutionConfiguration.ParseKitPoolSize, IntrospectionEnabled: s.Config.introspection, ParserTokenizerLimits: astparser.TokenizerLimits{ MaxDepth: s.Config.securityConfiguration.ParserLimits.ApproximateDepthLimit, MaxFields: s.Config.securityConfiguration.ParserLimits.TotalFieldsLimit, }, ApolloCompatibilityFlags: s.apolloCompatibilityFlags, ApolloRouterCompatibilityFlags: s.apolloRouterCompatibilityFlags, DisableExposingVariablesContentOnValidationError: s.engineExecutionConfiguration.DisableExposingVariablesContentOnValidationError, ComplexityLimits: s.securityConfiguration.ComplexityLimits, }) operationPlanner := NewOperationPlanner(executor, gm.planCache) // We support the MCP only on the base graph. Feature flags are not supported yet. if opts.IsBaseGraph() && s.mcpServer != nil { if mErr := s.mcpServer.Reload(executor.ClientSchema); mErr != nil { return nil, fmt.Errorf("failed to reload MCP server: %w", mErr) } } if s.Config.cacheWarmup != nil && s.Config.cacheWarmup.Enabled { if s.graphApiToken == "" { return nil, fmt.Errorf("graph token is required for cache warmup in order to communicate with the CDN") } processor := NewCacheWarmupPlanningProcessor(&CacheWarmupPlanningProcessorOptions{ OperationProcessor: operationProcessor, OperationPlanner: operationPlanner, ComplexityLimits: s.securityConfiguration.ComplexityLimits, RouterSchema: executor.RouterSchema, TrackSchemaUsage: s.graphqlMetricsConfig.Enabled, DisableVariablesRemapping: s.engineExecutionConfiguration.DisableVariablesRemapping, }) warmupConfig := &CacheWarmupConfig{ Log: s.logger, Processor: processor, Workers: s.Config.cacheWarmup.Workers, ItemsPerSecond: s.Config.cacheWarmup.ItemsPerSecond, Timeout: s.Config.cacheWarmup.Timeout, } warmupConfig.AfterOperation = func(item *CacheWarmupOperationPlanResult) { gm.metricStore.MeasureOperationPlanningTime(ctx, item.PlanningTime, nil, otelmetric.WithAttributes( append([]attribute.KeyValue{ otel.WgOperationName.String(item.OperationName), otel.WgClientName.String(item.ClientName), otel.WgClientVersion.String(item.ClientVersion), otel.WgFeatureFlag.String(opts.FeatureFlagName), otel.WgOperationHash.String(item.OperationHash), otel.WgOperationType.String(item.OperationType), otel.WgEnginePlanCacheHit.Bool(false), }, baseMetricAttributes...)..., ), ) } if s.Config.cacheWarmup.Source.Filesystem != nil { warmupConfig.Source = NewFileSystemSource(&FileSystemSourceConfig{ RootPath: s.Config.cacheWarmup.Source.Filesystem.Path, }) } else { cdnSource, err := NewCDNSource(s.Config.cdnConfig.URL, s.graphApiToken, s.logger) if err != nil { return nil, fmt.Errorf("failed to create cdn source: %w", err) } warmupConfig.Source = cdnSource } err = WarmupCaches(ctx, warmupConfig) if err != nil { // We don't want to fail the server if the cache warmup fails s.logger.Error("Failed to warmup caches. It will retry after server restart or graph execution config update", zap.Error(err)) } } authorizerOptions := &CosmoAuthorizerOptions{ FieldConfigurations: opts.EngineConfig.FieldConfigurations, RejectOperationIfUnauthorized: false, } if s.Config.authorization != nil { authorizerOptions.RejectOperationIfUnauthorized = s.authorization.RejectOperationIfUnauthorized } handlerOpts := HandlerOptions{ Executor: executor, Log: s.logger, EnableExecutionPlanCacheResponseHeader: s.engineExecutionConfiguration.EnableExecutionPlanCacheResponseHeader, EnablePersistedOperationCacheResponseHeader: s.engineExecutionConfiguration.Debug.EnablePersistedOperationsCacheResponseHeader, EnableNormalizationCacheResponseHeader: s.engineExecutionConfiguration.Debug.EnableNormalizationCacheResponseHeader, EnableResponseHeaderPropagation: s.headerRules != nil, EngineStats: s.engineStats, TracerProvider: s.tracerProvider, Authorizer: NewCosmoAuthorizer(authorizerOptions), SubgraphErrorPropagation: s.subgraphErrorPropagation, EngineLoaderHooks: NewEngineRequestHooks( gm.metricStore, subgraphAccessLogger, s.tracerProvider, tracingAttExpressions, telemetryAttExpressions, metricAttExpressions, exprManager.VisitorManager.IsSubgraphResponseBodyUsedInExpressions(), ), } if s.redisClient != nil { handlerOpts.RateLimitConfig = s.rateLimit handlerOpts.RateLimiter, err = NewCosmoRateLimiter(&CosmoRateLimiterOptions{ RedisClient: s.redisClient, Debug: s.rateLimit.Debug, RejectStatusCode: s.rateLimit.SimpleStrategy.RejectStatusCode, KeySuffixExpression: s.rateLimit.KeySuffixExpression, ExprManager: exprManager, }) if err != nil { return nil, fmt.Errorf("failed to create rate limiter: %w", err) } } if s.apolloCompatibilityFlags.SubscriptionMultipartPrintBoundary.Enabled { handlerOpts.ApolloSubscriptionMultipartPrintBoundary = s.apolloCompatibilityFlags.SubscriptionMultipartPrintBoundary.Enabled } graphqlHandler := NewGraphQLHandler(handlerOpts) executor.Resolver.SetAsyncErrorWriter(graphqlHandler) operationBlocker, err := NewOperationBlocker(&OperationBlockerOptions{ BlockMutations: BlockMutationOptions{ Enabled: s.securityConfiguration.BlockMutations.Enabled, Condition: s.securityConfiguration.BlockMutations.Condition, }, BlockSubscriptions: BlockSubscriptionOptions{ Enabled: s.securityConfiguration.BlockSubscriptions.Enabled, Condition: s.securityConfiguration.BlockSubscriptions.Condition, }, BlockNonPersisted: BlockNonPersistedOptions{ Enabled: s.securityConfiguration.BlockNonPersistedOperations.Enabled, Condition: s.securityConfiguration.BlockNonPersistedOperations.Condition, }, PersistedOperationsDisabled: s.persistedOperationsConfig.Disabled, SafelistEnabled: s.persistedOperationsConfig.Safelist.Enabled, LogUnknownOperationsEnabled: s.persistedOperationsConfig.LogUnknown, exprManager: exprManager, }) if err != nil { return nil, fmt.Errorf("failed to create operation blocker: %w", err) } graphqlPreHandler := NewPreHandler(&PreHandlerOptions{ Logger: s.logger, Executor: executor, Metrics: metrics, OperationProcessor: operationProcessor, Planner: operationPlanner, AccessController: s.accessController, OperationBlocker: operationBlocker, RouterPublicKey: s.publicKey, EnableRequestTracing: s.engineExecutionConfiguration.EnableRequestTracing, DevelopmentMode: s.developmentMode, TracerProvider: s.tracerProvider, FlushTelemetryAfterResponse: s.awsLambda, TraceExportVariables: s.traceConfig.ExportGraphQLVariables.Enabled, FileUploadEnabled: s.fileUploadConfig.Enabled, MaxUploadFiles: s.fileUploadConfig.MaxFiles, MaxUploadFileSize: int(s.fileUploadConfig.MaxFileSizeBytes), ComplexityLimits: s.securityConfiguration.ComplexityLimits, AlwaysIncludeQueryPlan: s.engineExecutionConfiguration.Debug.AlwaysIncludeQueryPlan, AlwaysSkipLoader: s.engineExecutionConfiguration.Debug.AlwaysSkipLoader, QueryPlansEnabled: s.Config.queryPlansEnabled, QueryPlansLoggingEnabled: s.engineExecutionConfiguration.Debug.PrintQueryPlans, TrackSchemaUsageInfo: s.graphqlMetricsConfig.Enabled || s.metricConfig.Prometheus.PromSchemaFieldUsage.Enabled, ClientHeader: s.clientHeader, ComputeOperationSha256: computeSha256, ApolloCompatibilityFlags: &s.apolloCompatibilityFlags, DisableVariablesRemapping: s.engineExecutionConfiguration.DisableVariablesRemapping, ExprManager: exprManager, OmitBatchExtensions: s.batchingConfig.OmitExtensions, }) if s.webSocketConfiguration != nil && s.webSocketConfiguration.Enabled { wsMiddleware := NewWebsocketMiddleware(ctx, WebsocketMiddlewareOptions{ OperationProcessor: operationProcessor, OperationBlocker: operationBlocker, Planner: operationPlanner, GraphQLHandler: graphqlHandler, PreHandler: graphqlPreHandler, Metrics: metrics, AccessController: s.accessController, Logger: s.logger, Stats: s.engineStats, ReadTimeout: s.engineExecutionConfiguration.WebSocketClientReadTimeout, WriteTimeout: s.engineExecutionConfiguration.WebSocketClientWriteTimeout, EnableNetPoll: s.engineExecutionConfiguration.EnableNetPoll, NetPollTimeout: s.engineExecutionConfiguration.WebSocketClientPollTimeout, NetPollConnBufferSize: s.engineExecutionConfiguration.WebSocketClientConnBufferSize, WebSocketConfiguration: s.webSocketConfiguration, ClientHeader: s.clientHeader, DisableVariablesRemapping: s.engineExecutionConfiguration.DisableVariablesRemapping, ApolloCompatibilityFlags: s.apolloCompatibilityFlags, }) // When the playground path is equal to the graphql path, we need to handle // ws upgrades and html requests on the same route. if s.playgroundConfig.Enabled && s.graphqlPath == s.playgroundConfig.Path { httpRouter.Use(s.playgroundHandler, wsMiddleware) } else { httpRouter.Use(wsMiddleware) } } httpRouter.Use( // Responsible for handling regular GraphQL requests over HTTP not WebSockets graphqlPreHandler.Handler, // Must be mounted after the websocket middleware to ensure that we only count non-hijacked requests like WebSockets func(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requestContext := getRequestContext(r.Context()) // We don't want to count any type of subscriptions e.g. SSE as in-flight requests because they are long-lived if requestContext != nil && requestContext.operation != nil && requestContext.operation.opType != OperationTypeSubscription { s.inFlightRequests.Add(1) // Counting like this is safe because according to the go http.ServeHTTP documentation // the requests is guaranteed to be finished when ServeHTTP returns defer s.inFlightRequests.Sub(1) } handler.ServeHTTP(w, r) }) }) // Mount built global and custom modules // Needs to be mounted after the pre-handler to ensure that the request was parsed and authorized httpRouter.Use(s.routerMiddlewares...) // GraphQL over POST httpRouter.Post(s.graphqlPath, graphqlHandler.ServeHTTP) // GraphQL over GET httpRouter.Get(s.graphqlPath, graphqlHandler.ServeHTTP) gm.mux = httpRouter s.graphMuxListLock.Lock() defer s.graphMuxListLock.Unlock() s.graphMuxList = append(s.graphMuxList, gm) return gm, nil } func (s *graphServer) setupConnector(ctx context.Context, config *nodev1.EngineConfiguration, configSubgraphs []*nodev1.Subgraph) error { s.connector = grpcconnector.NewConnector() for _, dsConfig := range config.DatasourceConfigurations { grpcConfig := dsConfig.GetCustomGraphql().GetGrpc() if grpcConfig == nil { continue } var sg *nodev1.Subgraph for _, subgraph := range configSubgraphs { if subgraph.Id == dsConfig.Id { sg = subgraph break } } if sg == nil { return fmt.Errorf("subgraph %s not found", dsConfig.Id) } pluginConfig := grpcConfig.GetPlugin() if pluginConfig == nil { remoteProvider, err := grpcconnector.NewRemoteGRPCProvider(grpcconnector.RemoteGRPCProviderConfig{ Logger: s.logger, Name: sg.Name, Endpoint: sg.RoutingUrl, }) if err != nil { return fmt.Errorf("failed to create standalone plugin for subgraph %s: %w", dsConfig.Id, err) } err = s.connector.RegisterClientProvider(sg.Name, remoteProvider) if err != nil { return fmt.Errorf("failed to register standalone plugin: %w", err) } continue } if !s.plugins.Enabled { continue } basePath := "" if s.plugins.Path != "" { basePath = s.plugins.Path } pluginPath, err := filepath.Abs(filepath.Join(basePath, pluginConfig.GetName(), "bin", fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH))) if err != nil { return fmt.Errorf("failed to get plugin path: %w", err) } grpcPlugin, err := grpcconnector.NewGRPCPlugin(grpcconnector.GRPCPluginConfig{ Logger: s.logger, PluginName: pluginConfig.GetName(), PluginPath: pluginPath, }) if err != nil { return fmt.Errorf("failed to create grpc plugin for subgraph %s: %w", dsConfig.Id, err) } err = s.connector.RegisterClientProvider(sg.Name, grpcPlugin) if err != nil { return fmt.Errorf("failed to register grpc plugin: %w", err) } } if err := s.connector.Run(ctx); err != nil { return fmt.Errorf("failed to run plugin host: %w", err) } return nil } // wait waits for all in-flight requests to finish. Similar to http.Server.Shutdown we wait in intervals + jitter // to make the shutdown process more efficient. func (s *graphServer) wait(ctx context.Context) error { b := backoff.New(500*time.Millisecond, time.Millisecond) defer b.Reset() timer := time.NewTimer(b.Duration()) defer timer.Stop() for { if s.inFlightRequests.Load() == 0 { return nil } select { case <-ctx.Done(): return ctx.Err() case <-timer.C: timer.Reset(b.Duration()) } } } // Shutdown gracefully shutdown the server and waits for all in-flight requests to finish. // After all requests are done, it will shut down the metric store and runtime metrics. // Shutdown does cancel the context after all non-hijacked requests such as WebSockets has been handled. func (s *graphServer) Shutdown(ctx context.Context) error { // Cancel the context after the graceful shutdown is done // to clean up resources like websocket connections, pools, etc. defer s.cancelFunc() s.logger.Debug("Shutdown of graph server initiated. Waiting for in-flight requests to finish.", zap.String("config_version", s.baseRouterConfigVersion), ) var finalErr error // Wait for all in-flight requests to finish. // In the worst case, we wait until the context is done or all requests has timed out. if err := s.wait(ctx); err != nil { finalErr = errors.Join(finalErr, fmt.Errorf("failed to wait for in-flight requests: %w", err)) } s.logger.Debug("Shutdown of graph server resources", zap.String("grace_period", s.routerGracePeriod.String()), zap.String("config_version", s.baseRouterConfigVersion), ) // Ensure that we don't wait indefinitely for shutdown if s.routerGracePeriod > 0 { newCtx, cancel := context.WithTimeout(ctx, s.routerGracePeriod) defer cancel() ctx = newCtx } if s.runtimeMetrics != nil { if err := s.runtimeMetrics.Shutdown(); err != nil { finalErr = errors.Join(finalErr, err) } } if s.connectionMetrics != nil { if aErr := s.connectionMetrics.Shutdown(ctx); aErr != nil { finalErr = errors.Join(finalErr, aErr) } } if s.otlpEngineMetrics != nil { if err := s.otlpEngineMetrics.Shutdown(); err != nil { finalErr = errors.Join(finalErr, err) } } if s.prometheusEngineMetrics != nil { if err := s.prometheusEngineMetrics.Shutdown(); err != nil { finalErr = errors.Join(finalErr, err) } } if err := s.shutdownPubSubProviders(ctx); err != nil { finalErr = errors.Join(finalErr, err) } // Shutdown all graphs muxes to release resources // e.g. planner cache s.graphMuxListLock.Lock() defer s.graphMuxListLock.Unlock() for _, mux := range s.graphMuxList { if err := mux.Shutdown(ctx); err != nil { finalErr = errors.Join(finalErr, err) } } // Close idle connections on base and subgraph transports s.baseTransport.CloseIdleConnections() for _, subgraphTransport := range s.subgraphTransports { subgraphTransport.CloseIdleConnections() } if s.connector != nil { s.logger.Debug("Stopping old plugins") if err := s.connector.StopAllProviders(); err != nil { finalErr = errors.Join(finalErr, err) } } return finalErr } // startupPubSubProviders starts all pubsub providers // It returns an error if any of the providers fail to start // or if some providers takes to long to start func (s *graphServer) startupPubSubProviders(ctx context.Context) error { // Default timeout for pubsub provider startup const defaultStartupTimeout = 5 * time.Second return s.providersActionWithTimeout(ctx, func(ctx context.Context, provider datasource.Provider) error { return provider.Startup(ctx) }, defaultStartupTimeout, "pubsub provider startup timed out") } // shutdownPubSubProviders shuts down all pubsub providers // It returns an error if any of the providers fail to shutdown // or if some providers takes to long to shutdown func (s *graphServer) shutdownPubSubProviders(ctx context.Context) error { // Default timeout for pubsub provider shutdown const defaultShutdownTimeout = 5 * time.Second return s.providersActionWithTimeout(ctx, func(ctx context.Context, provider datasource.Provider) error { return provider.Shutdown(ctx) }, defaultShutdownTimeout, "pubsub provider shutdown timed out") } func (s *graphServer) providersActionWithTimeout(ctx context.Context, action func(ctx context.Context, provider datasource.Provider) error, timeout time.Duration, timeoutMessage string) error { cancellableCtx, cancel := context.WithCancel(ctx) defer cancel() timer := time.NewTimer(timeout) defer timer.Stop() providersGroup := new(errgroup.Group) for _, provider := range s.pubSubProviders { providersGroup.Go(func() error { actionDone := make(chan error, 1) go func() { actionDone <- action(cancellableCtx, provider) }() select { case err := <-actionDone: return err case <-timer.C: return errors.New(timeoutMessage) } }) } return providersGroup.Wait() } func configureSubgraphOverwrites( engineConfig *nodev1.EngineConfiguration, configSubgraphs []*nodev1.Subgraph, overrideRoutingURLConfig config.OverrideRoutingURLConfiguration, overridesConfig config.OverridesConfiguration, skipOverrides bool, ) ([]Subgraph, error) { var err error subgraphs := make([]Subgraph, 0, len(configSubgraphs)) for _, sg := range configSubgraphs { subgraph := Subgraph{ Id: sg.Id, Name: sg.Name, } // Validate subgraph url. Note that it can be empty if the subgraph is virtual subgraph.Url, err = url.Parse(sg.RoutingUrl) if err != nil { return nil, fmt.Errorf("failed to parse subgraph url '%s': %w", sg.RoutingUrl, err) } subgraph.UrlString = subgraph.Url.String() overrideURL, ok := overrideRoutingURLConfig.Subgraphs[sg.Name] overrideSubgraph, overrideSubgraphOk := overridesConfig.Subgraphs[sg.Name] var overrideSubscriptionURL string var overrideSubscriptionProtocol *common.GraphQLSubscriptionProtocol var overrideSubscriptionWebsocketSubprotocol *common.GraphQLWebsocketSubprotocol if overrideSubgraphOk { if overrideSubgraph.RoutingURL != "" { overrideURL = overrideSubgraph.RoutingURL } if overrideSubgraph.SubscriptionURL != "" { overrideSubscriptionURL = overrideSubgraph.SubscriptionURL _, err := url.Parse(overrideSubscriptionURL) if err != nil { return nil, fmt.Errorf("failed to parse override url '%s': %w", overrideSubscriptionURL, err) } } if overrideSubgraph.SubscriptionProtocol != "" { switch overrideSubgraph.SubscriptionProtocol { case "ws": overrideSubscriptionProtocol = common.GraphQLSubscriptionProtocol_GRAPHQL_SUBSCRIPTION_PROTOCOL_WS.Enum() case "sse": overrideSubscriptionProtocol = common.GraphQLSubscriptionProtocol_GRAPHQL_SUBSCRIPTION_PROTOCOL_SSE.Enum() case "sse_post": overrideSubscriptionProtocol = common.GraphQLSubscriptionProtocol_GRAPHQL_SUBSCRIPTION_PROTOCOL_SSE_POST.Enum() default: return nil, fmt.Errorf("invalid subscription protocol '%s'", overrideSubgraph.SubscriptionProtocol) } } if overrideSubgraph.SubscriptionWebsocketSubprotocol != "" { switch overrideSubgraph.SubscriptionWebsocketSubprotocol { case "graphql-ws": overrideSubscriptionWebsocketSubprotocol = common.GraphQLWebsocketSubprotocol_GRAPHQL_WEBSOCKET_SUBPROTOCOL_WS.Enum() case "graphql-transport-ws": overrideSubscriptionWebsocketSubprotocol = common.GraphQLWebsocketSubprotocol_GRAPHQL_WEBSOCKET_SUBPROTOCOL_TRANSPORT_WS.Enum() case "auto": overrideSubscriptionWebsocketSubprotocol = common.GraphQLWebsocketSubprotocol_GRAPHQL_WEBSOCKET_SUBPROTOCOL_AUTO.Enum() default: return nil, fmt.Errorf("invalid subscription websocket subprotocol '%s'", overrideSubgraph.SubscriptionWebsocketSubprotocol) } } } // check if the subgraph is overridden if ok || overrideSubgraphOk { if overrideURL != "" { subgraph.Url, err = url.Parse(overrideURL) if err != nil { return nil, fmt.Errorf("failed to parse override url '%s': %w", overrideURL, err) } subgraph.UrlString = subgraph.Url.String() } // If skipOverrides is true we do not want to update the references and only care about // getting a subgraph result with the overridden url if !skipOverrides { // Override datasource urls for _, conf := range engineConfig.DatasourceConfigurations { if conf.Id == sg.Id { if overrideURL != "" { conf.CustomGraphql.Fetch.Url.StaticVariableContent = overrideURL sg.RoutingUrl = overrideURL } if overrideSubscriptionURL != "" { conf.CustomGraphql.Subscription.Url.StaticVariableContent = overrideSubscriptionURL } if overrideSubscriptionProtocol != nil { conf.CustomGraphql.Subscription.Protocol = overrideSubscriptionProtocol } if overrideSubscriptionWebsocketSubprotocol != nil { conf.CustomGraphql.Subscription.WebsocketSubprotocol = overrideSubscriptionWebsocketSubprotocol } break } } } } subgraphs = append(subgraphs, subgraph) } return subgraphs, nil }