router/core/graph_server.go
package core
import (
"context"
"crypto/ecdsa"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/cloudflare/backoff"
"github.com/dgraph-io/ristretto/v2"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/golang-jwt/jwt/v5"
"github.com/klauspost/compress/gzhttp"
"github.com/klauspost/compress/gzip"
"github.com/nats-io/nats.go"
"github.com/nats-io/nats.go/jetstream"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel/attribute"
otelmetric "go.opentelemetry.io/otel/metric"
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
oteltrace "go.opentelemetry.io/otel/trace"
"go.uber.org/atomic"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"golang.org/x/exp/maps"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource"
"github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/common"
nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1"
rjwt "github.com/wundergraph/cosmo/router/internal/jwt"
rmiddleware "github.com/wundergraph/cosmo/router/internal/middleware"
"github.com/wundergraph/cosmo/router/internal/recoveryhandler"
"github.com/wundergraph/cosmo/router/internal/requestlogger"
"github.com/wundergraph/cosmo/router/internal/retrytransport"
"github.com/wundergraph/cosmo/router/pkg/config"
"github.com/wundergraph/cosmo/router/pkg/cors"
"github.com/wundergraph/cosmo/router/pkg/execution_config"
"github.com/wundergraph/cosmo/router/pkg/health"
"github.com/wundergraph/cosmo/router/pkg/logging"
rmetric "github.com/wundergraph/cosmo/router/pkg/metric"
"github.com/wundergraph/cosmo/router/pkg/otel"
"github.com/wundergraph/cosmo/router/pkg/pubsub"
"github.com/wundergraph/cosmo/router/pkg/pubsub/kafka"
pubsubNats "github.com/wundergraph/cosmo/router/pkg/pubsub/nats"
"github.com/wundergraph/cosmo/router/pkg/statistics"
rtrace "github.com/wundergraph/cosmo/router/pkg/trace"
)
const (
featureFlagHeader = "X-Feature-Flag"
featureFlagCookie = "feature_flag"
)
type (
// Server is the public interface of the server.
Server interface {
HttpServer() *http.Server
HealthChecks() health.Checker
}
EnginePubSubProviders struct {
nats map[string]pubsub_datasource.NatsPubSub
kafka map[string]pubsub_datasource.KafkaPubSub
}
// graphServer is the swappable implementation of a graph instance: an HTTP mux with middlewares.
// Every time the schema is updated, the old graph server is shut down and a new one is created.
// For feature flags, a graph server holds multiple muxes and dynamically switches between them
// based on the feature flag header or cookie.
// All fields are shared between the feature muxes. On shutdown, all graph muxes are shut down.
graphServer struct {
*Config
context context.Context
cancelFunc context.CancelFunc
pubSubProviders *EnginePubSubProviders
engineStats statistics.EngineStatistics
playgroundHandler func(http.Handler) http.Handler
publicKey *ecdsa.PublicKey
executionTransport *http.Transport
executionTransportProxy ProxyFunc
baseOtelAttributes []attribute.KeyValue
baseRouterConfigVersion string
mux *chi.Mux
// inFlightRequests tracks the number of requests currently being processed.
// It does not include websocket (hijacked) connections.
inFlightRequests *atomic.Uint64
graphMuxList []*graphMux
graphMuxListLock sync.Mutex
runtimeMetrics *rmetric.RuntimeMetrics
otlpEngineMetrics *rmetric.EngineMetrics
prometheusEngineMetrics *rmetric.EngineMetrics
hostName string
routerListenAddr string
}
)
// newGraphServer creates a new server instance.
func newGraphServer(ctx context.Context, r *Router, routerConfig *nodev1.RouterConfig, proxy ProxyFunc) (*graphServer, error) {
/* Older versions of composition will not populate a compatibility version.
* Currently, all "old" router execution configurations are compatible as there have been no breaking
* changes.
* Upon the first breaking change to the execution config, an unpopulated compatibility version will
* also be unsupported (and the logic for IsRouterCompatibleWithExecutionConfig will need to be updated).
*/
if !execution_config.IsRouterCompatibleWithExecutionConfig(r.logger, routerConfig.CompatibilityVersion) {
return nil, fmt.Errorf(`the compatibility version "%s" is not compatible with this router version`, routerConfig.CompatibilityVersion)
}
ctx, cancel := context.WithCancel(ctx)
s := &graphServer{
context: ctx,
cancelFunc: cancel,
Config: &r.Config,
engineStats: r.EngineStats,
executionTransport: newHTTPTransport(r.subgraphTransportOptions.TransportRequestOptions, proxy),
executionTransportProxy: proxy,
playgroundHandler: r.playgroundHandler,
baseRouterConfigVersion: routerConfig.GetVersion(),
inFlightRequests: &atomic.Uint64{},
graphMuxList: make([]*graphMux, 0, 1),
routerListenAddr: r.listenAddr,
hostName: r.hostName,
pubSubProviders: &EnginePubSubProviders{
nats: map[string]pubsub_datasource.NatsPubSub{},
kafka: map[string]pubsub_datasource.KafkaPubSub{},
},
}
baseOtelAttributes := []attribute.KeyValue{
otel.WgRouterVersion.String(Version),
otel.WgRouterClusterName.String(r.clusterName),
}
if s.graphApiToken != "" {
claims, err := rjwt.ExtractFederatedGraphTokenClaims(s.graphApiToken)
if err != nil {
return nil, err
}
baseOtelAttributes = append(baseOtelAttributes, otel.WgFederatedGraphID.String(claims.FederatedGraphID))
}
s.baseOtelAttributes = baseOtelAttributes
if s.metricConfig.OpenTelemetry.RouterRuntime {
s.runtimeMetrics = rmetric.NewRuntimeMetrics(
s.logger,
s.otlpMeterProvider,
// We track runtime metrics with the base router config version
append([]attribute.KeyValue{
otel.WgRouterConfigVersion.String(s.baseRouterConfigVersion),
}, baseOtelAttributes...),
s.processStartTime,
)
// Start runtime metrics
if err := s.runtimeMetrics.Start(); err != nil {
return nil, err
}
}
if err := s.setupEngineStatistics(); err != nil {
return nil, fmt.Errorf("failed to setup engine statistics: %w", err)
}
if s.registrationInfo != nil {
publicKey, err := jwt.ParseECPublicKeyFromPEM([]byte(s.registrationInfo.GetGraphPublicKey()))
if err != nil {
return nil, fmt.Errorf("failed to parse router public key: %w", err)
}
s.publicKey = publicKey
}
httpRouter := chi.NewRouter()
/**
* Middlewares
*/
// This recovery handler is used for everything before the graph mux to ensure that
// we can recover from panics and log them properly.
httpRouter.Use(recoveryhandler.New(recoveryhandler.WithLogHandler(func(w http.ResponseWriter, r *http.Request, err any) {
s.logger.Error("[Recovery from panic]",
zap.Any("error", err),
)
})))
// Request traffic shaping related middlewares
httpRouter.Use(rmiddleware.RequestSize(int64(s.routerTrafficConfig.MaxRequestBodyBytes)))
if s.routerTrafficConfig.DecompressionEnabled {
httpRouter.Use(rmiddleware.HandleCompression(s.logger))
}
httpRouter.Use(middleware.RequestID)
httpRouter.Use(middleware.RealIP)
if s.corsOptions.Enabled {
httpRouter.Use(cors.New(*s.corsOptions))
}
gm, err := s.buildGraphMux(ctx, "", s.baseRouterConfigVersion, routerConfig.GetEngineConfig(), routerConfig.GetSubgraphs())
if err != nil {
return nil, fmt.Errorf("failed to build base mux: %w", err)
}
featureFlagConfigMap := routerConfig.FeatureFlagConfigs.GetConfigByFeatureFlagName()
if len(featureFlagConfigMap) > 0 {
s.logger.Info("Feature flags enabled", zap.Strings("flags", maps.Keys(featureFlagConfigMap)))
}
multiGraphHandler, err := s.buildMultiGraphHandler(ctx, gm.mux, featureFlagConfigMap)
if err != nil {
return nil, fmt.Errorf("failed to build feature flag handler: %w", err)
}
wrapper, err := gzhttp.NewWrapper(
gzhttp.MinSize(1024*4), // 4KB
gzhttp.CompressionLevel(gzip.DefaultCompression),
gzhttp.ContentTypes(CompressibleContentTypes),
)
if err != nil {
return nil, fmt.Errorf("failed to create gzip wrapper: %w", err)
}
/**
* A group where we can selectively apply middlewares to the graphql endpoint
*/
httpRouter.Group(func(cr chi.Router) {
// We apply compression conditionally because compressing the ~3MB playground is still slow, even with stdlib gzip
cr.Use(func(h http.Handler) http.Handler {
return wrapper(h)
})
if s.headerRules != nil {
cr.Use(rmiddleware.CookieWhitelist(s.headerRules.CookieWhitelist, []string{featureFlagCookie}))
}
// Mount the feature flag handler. It calls the base mux if no feature flag is set.
cr.Handle(r.graphqlPath, multiGraphHandler)
if r.webSocketConfiguration != nil && r.webSocketConfiguration.Enabled && r.webSocketConfiguration.AbsintheProtocol.Enabled {
// Mount the Absinthe protocol handler for WebSockets
httpRouter.Handle(r.webSocketConfiguration.AbsintheProtocol.HandlerPath, multiGraphHandler)
}
})
/**
* Routes
*/
// We mount the playground once here when it does not conflict with the websocket handler.
// If it does conflict, the playground is mounted while building the individual muxes.
if s.playgroundHandler != nil && s.graphqlPath != s.playgroundConfig.Path {
httpRouter.Get(r.playgroundConfig.Path, s.playgroundHandler(nil).ServeHTTP)
}
httpRouter.Get(s.healthCheckPath, r.healthcheck.Liveness())
httpRouter.Get(s.livenessCheckPath, r.healthcheck.Liveness())
httpRouter.Get(s.readinessCheckPath, r.healthcheck.Readiness())
s.mux = httpRouter
return s, nil
}
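// Illustrative route map after construction. The paths shown here are common
// defaults; the actual values come from the router configuration:
//
//	POST/GET /graphql      -> multiGraphHandler (feature-flag aware)
//	GET      /             -> playground (when enabled and not conflicting with /graphql)
//	GET      /health       -> liveness
//	GET      /health/live  -> liveness
//	GET      /health/ready -> readiness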
func (s *graphServer) buildMultiGraphHandler(ctx context.Context, baseMux *chi.Mux, featureFlagConfigs map[string]*nodev1.FeatureFlagRouterExecutionConfig) (http.HandlerFunc, error) {
if len(featureFlagConfigs) == 0 {
return baseMux.ServeHTTP, nil
}
featureFlagToMux := make(map[string]*chi.Mux, len(featureFlagConfigs))
// Build all the muxes for the feature flags in serial to avoid any race conditions
for featureFlagName, executionConfig := range featureFlagConfigs {
gm, err := s.buildGraphMux(ctx,
featureFlagName,
executionConfig.GetVersion(),
executionConfig.GetEngineConfig(),
executionConfig.Subgraphs,
)
if err != nil {
return nil, fmt.Errorf("failed to build mux for feature flag '%s': %w", featureFlagName, err)
}
featureFlagToMux[featureFlagName] = gm.mux
}
return func(w http.ResponseWriter, r *http.Request) {
// Extract the feature flag and run the corresponding mux
// 1. From the request header
// 2. From the cookie
ff := strings.TrimSpace(r.Header.Get(featureFlagHeader))
if ff == "" {
cookie, err := r.Cookie(featureFlagCookie)
if err == nil && cookie != nil {
ff = strings.TrimSpace(cookie.Value)
}
}
if mux, ok := featureFlagToMux[ff]; ok {
w.Header().Set(featureFlagHeader, ff)
mux.ServeHTTP(w, r)
return
}
// Fall back to the base composition
baseMux.ServeHTTP(w, r)
}, nil
}
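// Illustrative only: a client opts into a feature-flag composition via the header
// or cookie checked above ("my-flag" and the URL are hypothetical):
//
//	curl -H "X-Feature-Flag: my-flag" -d '{"query":"{__typename}"}' https://router.example.com/graphql
//	curl -b "feature_flag=my-flag"    -d '{"query":"{__typename}"}' https://router.example.com/graphql
//
// A request with an unknown or empty flag falls through to the base composition;
// a matched flag is echoed back in the response's X-Feature-Flag header.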
// setupEngineStatistics creates the engine statistics for the server.
// It creates the OTLP and Prometheus metrics for the engine statistics.
func (s *graphServer) setupEngineStatistics() (err error) {
// We only include the base router config version in the attributes for the engine statistics.
// Same approach is used for the runtime metrics.
baseAttributes := append([]attribute.KeyValue{
otel.WgRouterConfigVersion.String(s.baseRouterConfigVersion),
}, s.baseOtelAttributes...)
s.otlpEngineMetrics, err = rmetric.NewEngineMetrics(
s.logger,
baseAttributes,
s.otlpMeterProvider,
s.engineStats,
&s.metricConfig.OpenTelemetry.EngineStats,
)
if err != nil {
return err
}
s.prometheusEngineMetrics, err = rmetric.NewEngineMetrics(
s.logger,
baseAttributes,
s.promMeterProvider,
s.engineStats,
&s.metricConfig.Prometheus.EngineStats,
)
if err != nil {
return err
}
return nil
}
type graphMux struct {
mux *chi.Mux
planCache *ristretto.Cache[uint64, *planWithMetaData]
persistedOperationCache *ristretto.Cache[uint64, NormalizationCacheEntry]
normalizationCache *ristretto.Cache[uint64, NormalizationCacheEntry]
complexityCalculationCache *ristretto.Cache[uint64, ComplexityCacheEntry]
validationCache *ristretto.Cache[uint64, bool]
operationHashCache *ristretto.Cache[uint64, string]
accessLogsFileLogger *logging.BufferedLogger
metricStore rmetric.Store
prometheusCacheMetrics *rmetric.CacheMetrics
otelCacheMetrics *rmetric.CacheMetrics
}
// buildOperationCaches creates the caches for the graph mux.
// The caches are created based on the engine configuration.
func (s *graphMux) buildOperationCaches(srv *graphServer) (computeSha256 bool, err error) {
// We create a new execution plan cache for each operation planner, coupled to the
// specific engine configuration. This is necessary because otherwise we would return invalid plans.
//
// Generating an execution plan can be quite expensive, so we cache it: we hash the
// input and store the generated plan, and the next time we see the same input we
// return the cached plan. The engine normalizes the operation before hashing, so
// different inputs that would generate the same execution plan share one cache entry.
if srv.engineExecutionConfiguration.ExecutionPlanCacheSize > 0 {
planCacheConfig := &ristretto.Config[uint64, *planWithMetaData]{
Metrics: srv.metricConfig.OpenTelemetry.GraphqlCache || srv.metricConfig.Prometheus.GraphqlCache,
MaxCost: srv.engineExecutionConfiguration.ExecutionPlanCacheSize,
NumCounters: srv.engineExecutionConfiguration.ExecutionPlanCacheSize * 10,
IgnoreInternalCost: true,
BufferItems: 64,
}
s.planCache, err = ristretto.NewCache[uint64, *planWithMetaData](planCacheConfig)
if err != nil {
return computeSha256, fmt.Errorf("failed to create planner cache: %w", err)
}
}
if srv.engineExecutionConfiguration.EnablePersistedOperationsCache || srv.automaticPersistedQueriesConfig.Enabled {
cacheSize := int64(1024)
persistedOperationCacheConfig := &ristretto.Config[uint64, NormalizationCacheEntry]{
MaxCost: cacheSize,
NumCounters: cacheSize * 10,
IgnoreInternalCost: true,
BufferItems: 64,
Metrics: true,
}
s.persistedOperationCache, _ = ristretto.NewCache[uint64, NormalizationCacheEntry](persistedOperationCacheConfig)
}
if srv.engineExecutionConfiguration.EnableNormalizationCache && srv.engineExecutionConfiguration.NormalizationCacheSize > 0 {
normalizationCacheConfig := &ristretto.Config[uint64, NormalizationCacheEntry]{
Metrics: srv.metricConfig.OpenTelemetry.GraphqlCache || srv.metricConfig.Prometheus.GraphqlCache,
MaxCost: srv.engineExecutionConfiguration.NormalizationCacheSize,
NumCounters: srv.engineExecutionConfiguration.NormalizationCacheSize * 10,
IgnoreInternalCost: true,
BufferItems: 64,
}
s.normalizationCache, err = ristretto.NewCache[uint64, NormalizationCacheEntry](normalizationCacheConfig)
if err != nil {
return computeSha256, fmt.Errorf("failed to create normalization cache: %w", err)
}
}
if srv.engineExecutionConfiguration.EnableValidationCache && srv.engineExecutionConfiguration.ValidationCacheSize > 0 {
validationCacheConfig := &ristretto.Config[uint64, bool]{
Metrics: srv.metricConfig.OpenTelemetry.GraphqlCache || srv.metricConfig.Prometheus.GraphqlCache,
MaxCost: srv.engineExecutionConfiguration.ValidationCacheSize,
NumCounters: srv.engineExecutionConfiguration.ValidationCacheSize * 10,
IgnoreInternalCost: true,
BufferItems: 64,
}
s.validationCache, err = ristretto.NewCache[uint64, bool](validationCacheConfig)
if err != nil {
return computeSha256, fmt.Errorf("failed to create validation cache: %w", err)
}
}
if srv.securityConfiguration.ComplexityCalculationCache != nil && srv.securityConfiguration.ComplexityCalculationCache.Enabled && srv.securityConfiguration.ComplexityCalculationCache.CacheSize > 0 {
complexityCalculationCacheConfig := &ristretto.Config[uint64, ComplexityCacheEntry]{
Metrics: srv.metricConfig.OpenTelemetry.GraphqlCache || srv.metricConfig.Prometheus.GraphqlCache,
MaxCost: srv.securityConfiguration.ComplexityCalculationCache.CacheSize,
NumCounters: srv.securityConfiguration.ComplexityCalculationCache.CacheSize * 10,
IgnoreInternalCost: true,
BufferItems: 64,
}
s.complexityCalculationCache, err = ristretto.NewCache[uint64, ComplexityCacheEntry](complexityCalculationCacheConfig)
if err != nil {
return computeSha256, fmt.Errorf("failed to create query depth cache: %w", err)
}
}
// Currently, we only support custom attributes from the context for OTLP metrics
if len(srv.metricConfig.Attributes) > 0 {
for _, customAttribute := range srv.metricConfig.Attributes {
if customAttribute.ValueFrom != nil && customAttribute.ValueFrom.ContextField == ContextFieldOperationSha256 {
computeSha256 = true
break
}
}
} else if srv.accessLogsConfig != nil {
for _, customAttribute := range append(srv.accessLogsConfig.Attributes, srv.accessLogsConfig.SubgraphAttributes...) {
if customAttribute.ValueFrom != nil && customAttribute.ValueFrom.ContextField == ContextFieldOperationSha256 {
computeSha256 = true
break
}
}
} else if srv.persistedOperationsConfig.Safelist.Enabled || srv.persistedOperationsConfig.LogUnknown {
// In these cases, we want to compute the sha256 of every operation in order to check
// that the operation is present in the persisted operation cache
computeSha256 = true
}
if computeSha256 {
operationHashCacheConfig := &ristretto.Config[uint64, string]{
MaxCost: srv.engineExecutionConfiguration.OperationHashCacheSize,
NumCounters: srv.engineExecutionConfiguration.OperationHashCacheSize * 10,
IgnoreInternalCost: true,
BufferItems: 64,
Metrics: srv.metricConfig.OpenTelemetry.GraphqlCache || srv.metricConfig.Prometheus.GraphqlCache,
}
s.operationHashCache, err = ristretto.NewCache[uint64, string](operationHashCacheConfig)
if err != nil {
return computeSha256, fmt.Errorf("failed to create operation hash cache: %w", err)
}
}
return computeSha256, nil
}
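// Sizing note: with IgnoreInternalCost set, MaxCost effectively acts as an entry
// count, and NumCounters is sized at 10x the expected entries so that ristretto's
// TinyLFU admission policy has enough frequency counters (per ristretto's own
// guidance). A minimal standalone cache with the same shape, for illustration:
//
//	cache, err := ristretto.NewCache[uint64, string](&ristretto.Config[uint64, string]{
//		MaxCost:            1024,      // max entries, since internal cost is ignored
//		NumCounters:        1024 * 10, // ~10x entries, recommended by ristretto
//		IgnoreInternalCost: true,
//		BufferItems:        64, // recommended default
//	})
//	if err != nil {
//		return err
//	}
//	defer cache.Close()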
// configureCacheMetrics sets up the cache metrics for this mux if enabled in the config.
func (s *graphMux) configureCacheMetrics(srv *graphServer, baseOtelAttributes []attribute.KeyValue) error {
if srv.metricConfig.OpenTelemetry.GraphqlCache {
cacheMetrics, err := rmetric.NewCacheMetrics(
srv.logger,
baseOtelAttributes,
srv.otlpMeterProvider)
if err != nil {
return fmt.Errorf("failed to create cache metrics for OTLP: %w", err)
}
s.otelCacheMetrics = cacheMetrics
}
if srv.metricConfig.Prometheus.GraphqlCache {
cacheMetrics, err := rmetric.NewCacheMetrics(
srv.logger,
baseOtelAttributes,
srv.promMeterProvider)
if err != nil {
return fmt.Errorf("failed to create cache metrics for Prometheus: %w", err)
}
s.prometheusCacheMetrics = cacheMetrics
}
var metricInfos []rmetric.CacheMetricInfo
if s.planCache != nil {
metricInfos = append(metricInfos, rmetric.NewCacheMetricInfo("plan", srv.engineExecutionConfiguration.ExecutionPlanCacheSize, s.planCache.Metrics))
}
if s.normalizationCache != nil {
metricInfos = append(metricInfos, rmetric.NewCacheMetricInfo("query_normalization", srv.engineExecutionConfiguration.NormalizationCacheSize, s.normalizationCache.Metrics))
}
if s.persistedOperationCache != nil {
metricInfos = append(metricInfos, rmetric.NewCacheMetricInfo("persisted_query_normalization", 1024, s.persistedOperationCache.Metrics))
}
if s.validationCache != nil {
metricInfos = append(metricInfos, rmetric.NewCacheMetricInfo("validation", srv.engineExecutionConfiguration.ValidationCacheSize, s.validationCache.Metrics))
}
if s.operationHashCache != nil {
metricInfos = append(metricInfos, rmetric.NewCacheMetricInfo("query_hash", srv.engineExecutionConfiguration.OperationHashCacheSize, s.operationHashCache.Metrics))
}
if s.otelCacheMetrics != nil {
if err := s.otelCacheMetrics.RegisterObservers(metricInfos); err != nil {
return fmt.Errorf("failed to register observer for OTLP cache metrics: %w", err)
}
}
if s.prometheusCacheMetrics != nil {
if err := s.prometheusCacheMetrics.RegisterObservers(metricInfos); err != nil {
return fmt.Errorf("failed to register observer for Prometheus cache metrics: %w", err)
}
}
return nil
}
func (s *graphMux) Shutdown(ctx context.Context) error {
var err error
if s.planCache != nil {
s.planCache.Close()
}
if s.persistedOperationCache != nil {
s.persistedOperationCache.Close()
}
if s.normalizationCache != nil {
s.normalizationCache.Close()
}
if s.validationCache != nil {
s.validationCache.Close()
}
if s.complexityCalculationCache != nil {
s.complexityCalculationCache.Close()
}
if s.operationHashCache != nil {
s.operationHashCache.Close()
}
if s.accessLogsFileLogger != nil {
if aErr := s.accessLogsFileLogger.Close(); aErr != nil {
err = errors.Join(err, aErr)
}
}
if s.otelCacheMetrics != nil {
if aErr := s.otelCacheMetrics.Shutdown(); aErr != nil {
err = errors.Join(err, aErr)
}
}
if s.prometheusCacheMetrics != nil {
if aErr := s.prometheusCacheMetrics.Shutdown(); aErr != nil {
err = errors.Join(err, aErr)
}
}
if s.metricStore != nil {
if aErr := s.metricStore.Shutdown(ctx); aErr != nil {
err = errors.Join(err, aErr)
}
}
return err
}
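// Note on error aggregation: errors.Join (Go 1.20+) discards nil operands and
// wraps the rest, so the returned error is nil only when every close succeeded,
// and errors.Is still matches each joined error. For illustration:
//
//	joined := errors.Join(nil, errA, errB) // nil is dropped, errA and errB are wrapped
//	errors.Is(joined, errA)                // true
//	errors.Is(joined, errB)                // true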
// buildGraphMux creates a new graph mux with the given feature flags and engine configuration.
// It also creates a new execution plan cache for the mux. The mux is not mounted on the server.
// The mux is appended internally to the graph server's list of muxes to clean up later when the server is swapped.
func (s *graphServer) buildGraphMux(ctx context.Context,
featureFlagName string,
routerConfigVersion string,
engineConfig *nodev1.EngineConfiguration,
configSubgraphs []*nodev1.Subgraph,
) (*graphMux, error) {
gm := &graphMux{
metricStore: rmetric.NewNoopMetrics(),
}
httpRouter := chi.NewRouter()
baseOtelAttributes := append([]attribute.KeyValue{otel.WgRouterConfigVersion.String(routerConfigVersion)}, s.baseOtelAttributes...)
if featureFlagName != "" {
baseOtelAttributes = append(baseOtelAttributes, otel.WgFeatureFlag.String(featureFlagName))
}
metricsEnabled := s.metricConfig.IsEnabled()
// We only enable the attribute mapper if we are not using the default cloud exporter
enableAttributeMapper := !(s.metricConfig.IsUsingCloudExporter || rmetric.IsDefaultCloudExporterConfigured(s.metricConfig.OpenTelemetry.Exporters))
// We might want to remap or exclude known attributes based on the configuration for metrics
mapper := newAttributeMapper(enableAttributeMapper, s.metricConfig.Attributes)
attExpressions, attErr := newAttributeExpressions(s.metricConfig.Attributes)
if attErr != nil {
return nil, attErr
}
baseMetricAttributes := mapper.mapAttributes(baseOtelAttributes)
var telemetryAttExpressions *attributeExpressions
if len(s.telemetryAttributes) > 0 {
var telemetryAttErr error
telemetryAttExpressions, telemetryAttErr = newAttributeExpressions(s.telemetryAttributes)
if telemetryAttErr != nil {
return nil, telemetryAttErr
}
}
// The Prometheus metric store relies on the OTLP metric store
if metricsEnabled {
m, err := rmetric.NewStore(
rmetric.WithPromMeterProvider(s.promMeterProvider),
rmetric.WithOtlpMeterProvider(s.otlpMeterProvider),
rmetric.WithBaseAttributes(baseMetricAttributes),
rmetric.WithLogger(s.logger),
rmetric.WithProcessStartTime(s.processStartTime),
rmetric.WithCardinalityLimit(rmetric.DefaultCardinalityLimit),
)
if err != nil {
return nil, fmt.Errorf("failed to create metric handler: %w", err)
}
gm.metricStore = m
}
subgraphs, err := configureSubgraphOverwrites(
engineConfig,
configSubgraphs,
s.overrideRoutingURLConfiguration,
s.overrides,
)
if err != nil {
return nil, err
}
computeSha256, err := gm.buildOperationCaches(s)
if err != nil {
return nil, err
}
if err = gm.configureCacheMetrics(s, baseMetricAttributes); err != nil {
return nil, err
}
metrics := NewRouterMetrics(&routerMetricsConfig{
metrics: gm.metricStore,
gqlMetricsExporter: s.gqlMetricsExporter,
exportEnabled: s.graphqlMetricsConfig.Enabled,
routerConfigVersion: routerConfigVersion,
logger: s.logger,
})
baseLogFields := []zapcore.Field{
zap.String("config_version", routerConfigVersion),
}
if featureFlagName != "" {
baseLogFields = append(baseLogFields, zap.String("feature_flag", featureFlagName))
}
// Currently, we only support custom attributes from the context for OTLP metrics
b := buildAttributesMap(s.metricConfig.Attributes)
// Enrich the request context with the subgraph information which is required for custom modules and tracing
subgraphResolver := NewSubgraphResolver(subgraphs)
httpRouter.Use(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestLogger := s.logger.With(logging.WithRequestID(middleware.GetReqID(r.Context())))
r = r.WithContext(withSubgraphResolver(r.Context(), subgraphResolver))
reqContext := buildRequestContext(requestContextOptions{
operationContext: nil,
requestLogger: requestLogger,
metricSetAttributes: b,
metricsEnabled: metricsEnabled,
traceEnabled: s.traceConfig.Enabled,
mapper: mapper,
metricAttributeExpressions: attExpressions,
telemetryAttributeExpressions: telemetryAttExpressions,
w: w,
r: r,
})
r = r.WithContext(withRequestContext(r.Context(), reqContext))
// For debugging purposes, we can verify which config version the request was served with
if s.setConfigVersionHeader {
w.Header().Set("X-Router-Config-Version", routerConfigVersion)
}
h.ServeHTTP(w, r)
})
})
var recoverOpts []recoveryhandler.Option
// If we have no access logger configured, we log the panic in the recovery handler to avoid losing the panic information
if s.accessLogsConfig == nil {
recoverOpts = append(recoverOpts, recoveryhandler.WithLogHandler(func(w http.ResponseWriter, r *http.Request, err any) {
reqContext := getRequestContext(r.Context())
if reqContext != nil {
reqContext.logger.Error("[Recovery from panic]",
zap.Any("error", err),
)
}
}))
}
recoveryHandler := recoveryhandler.New(recoverOpts...)
httpRouter.Use(recoveryHandler)
// Set up any router onRequest middlewares so that they can be used to manipulate
// other downstream internal middlewares such as tracing or authentication
httpRouter.Use(s.routerOnRequestHandlers...)
/**
* Initialize base attributes from headers and other sources
*/
var commonAttrRequestMapper func(r *http.Request) []attribute.KeyValue
if len(s.telemetryAttributes) > 0 {
// Common attributes across traces and metrics
commonAttrRequestMapper = buildHeaderAttributesMapper(s.telemetryAttributes)
}
var metricAttrRequestMapper func(r *http.Request) []attribute.KeyValue
// Metric attributes are only used for OTLP metrics and Prometheus metrics
if s.metricConfig.IsEnabled() {
metricAttrRequestMapper = buildHeaderAttributesMapper(s.metricConfig.Attributes)
}
httpRouter.Use(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqContext := getRequestContext(r.Context())
reqContext.telemetry.addCommonTraceAttribute(baseOtelAttributes...)
reqContext.telemetry.addCommonTraceAttribute(otel.WgRouterConfigVersion.String(routerConfigVersion))
if commonAttrRequestMapper != nil {
reqContext.telemetry.addCommonAttribute(commonAttrRequestMapper(r)...)
}
if metricAttrRequestMapper != nil {
reqContext.telemetry.addMetricAttribute(metricAttrRequestMapper(r)...)
}
h.ServeHTTP(w, r)
})
})
if s.traceConfig.Enabled {
spanStartOptions := []oteltrace.SpanStartOption{
oteltrace.WithAttributes(
otel.RouterServerAttribute,
otel.WgRouterRootSpan.Bool(true),
),
}
if s.traceConfig.WithNewRoot {
spanStartOptions = append(spanStartOptions, oteltrace.WithNewRoot())
}
middlewareOptions := []otelhttp.Option{
otelhttp.WithSpanOptions(spanStartOptions...),
otelhttp.WithFilter(rtrace.CommonRequestFilter),
otelhttp.WithFilter(rtrace.PrefixRequestFilter(
[]string{s.healthCheckPath, s.readinessCheckPath, s.livenessCheckPath}),
),
// Disable otelhttp's built-in metrics by passing a meter provider with no readers
otelhttp.WithMeterProvider(sdkmetric.NewMeterProvider()),
otelhttp.WithSpanNameFormatter(SpanNameFormatter),
otelhttp.WithTracerProvider(s.tracerProvider),
}
if s.compositePropagator != nil {
middlewareOptions = append(middlewareOptions, otelhttp.WithPropagators(s.compositePropagator))
}
traceHandler := rtrace.NewMiddleware(
rtrace.WithTracePreHandler(
func(r *http.Request, w http.ResponseWriter) {
reqContext := getRequestContext(r.Context())
traceID := rtrace.GetTraceID(r.Context())
requestLogger := reqContext.Logger().With(logging.WithTraceID(traceID))
reqContext.logger = requestLogger
span := oteltrace.SpanFromContext(r.Context())
span.SetAttributes(reqContext.telemetry.traceAttrs...)
// Set the trace ID in the response header
if s.traceConfig.ResponseTraceHeader.Enabled {
w.Header().Set(s.traceConfig.ResponseTraceHeader.HeaderName, traceID)
}
}),
rtrace.WithOtelHttp(middlewareOptions...),
)
httpRouter.Use(traceHandler.Handler)
}
var subgraphAccessLogger *requestlogger.SubgraphAccessLogger
if s.accessLogsConfig != nil && s.accessLogsConfig.Logger != nil {
exprAttributes, err := requestlogger.GetAccessLogConfigExpressions(s.accessLogsConfig.Attributes)
if err != nil {
return nil, fmt.Errorf("failed building router access log expressions: %w", err)
}
s.accessLogsConfig.Attributes = requestlogger.CleanupExpressionAttributes(s.accessLogsConfig.Attributes)
requestLoggerOpts := []requestlogger.Option{
requestlogger.WithDefaultOptions(),
requestlogger.WithNoTimeField(),
requestlogger.WithFields(baseLogFields...),
requestlogger.WithAttributes(s.accessLogsConfig.Attributes),
requestlogger.WithExprAttributes(exprAttributes),
requestlogger.WithFieldsHandler(RouterAccessLogsFieldHandler),
}
var ipAnonConfig *requestlogger.IPAnonymizationConfig
if s.ipAnonymization.Enabled {
ipAnonConfig = &requestlogger.IPAnonymizationConfig{
Enabled: s.ipAnonymization.Enabled,
Method: requestlogger.IPAnonymizationMethod(s.ipAnonymization.Method),
}
requestLoggerOpts = append(requestLoggerOpts, requestlogger.WithAnonymization(ipAnonConfig))
}
requestLogger := requestlogger.New(
s.accessLogsConfig.Logger,
requestLoggerOpts...,
)
httpRouter.Use(requestLogger)
if s.accessLogsConfig.SubgraphEnabled {
s.accessLogsConfig.SubgraphAttributes = requestlogger.CleanupExpressionAttributes(s.accessLogsConfig.SubgraphAttributes)
subgraphAccessLogger = requestlogger.NewSubgraphAccessLogger(
s.accessLogsConfig.Logger,
requestlogger.SubgraphOptions{
IPAnonymizationConfig: ipAnonConfig,
FieldsHandler: SubgraphAccessLogsFieldHandler,
Fields: baseLogFields,
Attributes: s.accessLogsConfig.SubgraphAttributes,
})
}
}
routerEngineConfig := &RouterEngineConfiguration{
Execution: s.engineExecutionConfiguration,
Headers: s.headerRules,
Events: s.eventsConfig,
SubgraphErrorPropagation: s.subgraphErrorPropagation,
}
err = s.buildPubSubConfiguration(ctx, engineConfig, routerEngineConfig)
if err != nil {
return nil, fmt.Errorf("failed to build pubsub configuration: %w", err)
}
ecb := &ExecutorConfigurationBuilder{
introspection: s.introspection,
baseURL: s.baseURL,
transport: s.executionTransport,
logger: s.logger,
trackUsageInfo: s.graphqlMetricsConfig.Enabled,
transportOptions: &TransportOptions{
Proxy: s.executionTransportProxy,
SubgraphTransportOptions: s.subgraphTransportOptions,
PreHandlers: s.preOriginHandlers,
PostHandlers: s.postOriginHandlers,
MetricStore: gm.metricStore,
RetryOptions: retrytransport.RetryOptions{
Enabled: s.retryOptions.Enabled,
MaxRetryCount: s.retryOptions.MaxRetryCount,
MaxDuration: s.retryOptions.MaxDuration,
Interval: s.retryOptions.Interval,
ShouldRetry: func(err error, req *http.Request, resp *http.Response) bool {
return retrytransport.IsRetryableError(err, resp) && !isMutationRequest(req.Context())
},
},
TracerProvider: s.tracerProvider,
TracePropagators: s.compositePropagator,
LocalhostFallbackInsideDocker: s.localhostFallbackInsideDocker,
Logger: s.logger,
},
}
executor, err := ecb.Build(
ctx,
&ExecutorBuildOptions{
EngineConfig: engineConfig,
Subgraphs: configSubgraphs,
RouterEngineConfig: routerEngineConfig,
PubSubProviders: s.pubSubProviders,
Reporter: s.engineStats,
ApolloCompatibilityFlags: s.apolloCompatibilityFlags,
ApolloRouterCompatibilityFlags: s.apolloRouterCompatibilityFlags,
HeartbeatInterval: s.multipartHeartbeatInterval,
},
)
if err != nil {
return nil, fmt.Errorf("failed to build plan configuration: %w", err)
}
operationProcessor := NewOperationProcessor(OperationProcessorOptions{
Executor: executor,
MaxOperationSizeInBytes: int64(s.routerTrafficConfig.MaxRequestBodyBytes),
PersistedOperationClient: s.persistedOperationClient,
AutomaticPersistedOperationCacheTtl: s.automaticPersistedQueriesConfig.Cache.TTL,
EnablePersistedOperationsCache: s.engineExecutionConfiguration.EnablePersistedOperationsCache,
PersistedOpsNormalizationCache: gm.persistedOperationCache,
NormalizationCache: gm.normalizationCache,
ValidationCache: gm.validationCache,
QueryDepthCache: gm.complexityCalculationCache,
OperationHashCache: gm.operationHashCache,
ParseKitPoolSize: s.engineExecutionConfiguration.ParseKitPoolSize,
IntrospectionEnabled: s.Config.introspection,
ApolloCompatibilityFlags: s.apolloCompatibilityFlags,
ApolloRouterCompatibilityFlags: s.apolloRouterCompatibilityFlags,
})
operationPlanner := NewOperationPlanner(executor, gm.planCache)
if s.Config.cacheWarmup != nil && s.Config.cacheWarmup.Enabled {
if s.graphApiToken == "" {
return nil, fmt.Errorf("graph token is required for cache warmup in order to communicate with the CDN")
}
processor := NewCacheWarmupPlanningProcessor(&CacheWarmupPlanningProcessorOptions{
OperationProcessor: operationProcessor,
OperationPlanner: operationPlanner,
ComplexityLimits: s.securityConfiguration.ComplexityLimits,
RouterSchema: executor.RouterSchema,
TrackSchemaUsage: s.graphqlMetricsConfig.Enabled,
DisableVariablesRemapping: s.engineExecutionConfiguration.DisableVariablesRemapping,
})
warmupConfig := &CacheWarmupConfig{
Log: s.logger,
Processor: processor,
Workers: s.Config.cacheWarmup.Workers,
ItemsPerSecond: s.Config.cacheWarmup.ItemsPerSecond,
Timeout: s.Config.cacheWarmup.Timeout,
}
warmupConfig.AfterOperation = func(item *CacheWarmupOperationPlanResult) {
gm.metricStore.MeasureOperationPlanningTime(ctx,
item.PlanningTime,
nil,
otelmetric.WithAttributes(
append([]attribute.KeyValue{
otel.WgOperationName.String(item.OperationName),
otel.WgClientName.String(item.ClientName),
otel.WgClientVersion.String(item.ClientVersion),
otel.WgFeatureFlag.String(featureFlagName),
otel.WgOperationHash.String(item.OperationHash),
otel.WgOperationType.String(item.OperationType),
otel.WgEnginePlanCacheHit.Bool(false),
}, baseMetricAttributes...)...,
),
)
}
if s.Config.cacheWarmup.Source.Filesystem != nil {
warmupConfig.Source = NewFileSystemSource(&FileSystemSourceConfig{
RootPath: s.Config.cacheWarmup.Source.Filesystem.Path,
})
} else {
cdnSource, err := NewCDNSource(s.Config.cdnConfig.URL, s.graphApiToken, s.logger)
if err != nil {
return nil, fmt.Errorf("failed to create cdn source: %w", err)
}
warmupConfig.Source = cdnSource
}
err = WarmupCaches(ctx, warmupConfig)
if err != nil {
// We don't want to fail the server if the cache warmup fails
s.logger.Error("Failed to warmup caches. It will retry after server restart or graph execution config update", zap.Error(err))
}
}
authorizerOptions := &CosmoAuthorizerOptions{
FieldConfigurations: engineConfig.FieldConfigurations,
RejectOperationIfUnauthorized: false,
}
if s.Config.authorization != nil {
authorizerOptions.RejectOperationIfUnauthorized = s.authorization.RejectOperationIfUnauthorized
}
handlerOpts := HandlerOptions{
Executor: executor,
Log: s.logger,
EnableExecutionPlanCacheResponseHeader: s.engineExecutionConfiguration.EnableExecutionPlanCacheResponseHeader,
EnablePersistedOperationCacheResponseHeader: s.engineExecutionConfiguration.Debug.EnablePersistedOperationsCacheResponseHeader,
EnableNormalizationCacheResponseHeader: s.engineExecutionConfiguration.Debug.EnableNormalizationCacheResponseHeader,
EnableResponseHeaderPropagation: s.headerRules != nil,
EngineStats: s.engineStats,
TracerProvider: s.tracerProvider,
Authorizer: NewCosmoAuthorizer(authorizerOptions),
SubgraphErrorPropagation: s.subgraphErrorPropagation,
EngineLoaderHooks: NewEngineRequestHooks(gm.metricStore, subgraphAccessLogger, s.tracerProvider),
}
if s.redisClient != nil {
handlerOpts.RateLimitConfig = s.rateLimit
handlerOpts.RateLimiter, err = NewCosmoRateLimiter(&CosmoRateLimiterOptions{
RedisClient: s.redisClient,
Debug: s.rateLimit.Debug,
RejectStatusCode: s.rateLimit.SimpleStrategy.RejectStatusCode,
KeySuffixExpression: s.rateLimit.KeySuffixExpression,
})
if err != nil {
return nil, fmt.Errorf("failed to create rate limiter: %w", err)
}
}
if s.apolloCompatibilityFlags.SubscriptionMultipartPrintBoundary.Enabled {
handlerOpts.ApolloSubscriptionMultipartPrintBoundary = s.apolloCompatibilityFlags.SubscriptionMultipartPrintBoundary.Enabled
}
graphqlHandler := NewGraphQLHandler(handlerOpts)
executor.Resolver.SetAsyncErrorWriter(graphqlHandler)
operationBlocker, err := NewOperationBlocker(&OperationBlockerOptions{
BlockMutations: BlockMutationOptions{
Enabled: s.securityConfiguration.BlockMutations.Enabled,
Condition: s.securityConfiguration.BlockMutations.Condition,
},
BlockSubscriptions: BlockSubscriptionOptions{
Enabled: s.securityConfiguration.BlockSubscriptions.Enabled,
Condition: s.securityConfiguration.BlockSubscriptions.Condition,
},
BlockNonPersisted: BlockNonPersistedOptions{
Enabled: s.securityConfiguration.BlockNonPersistedOperations.Enabled,
Condition: s.securityConfiguration.BlockNonPersistedOperations.Condition,
},
SafelistEnabled: s.persistedOperationsConfig.Safelist.Enabled,
LogUnknownOperationsEnabled: s.persistedOperationsConfig.LogUnknown,
})
if err != nil {
return nil, fmt.Errorf("failed to create operation blocker: %w", err)
}
graphqlPreHandler := NewPreHandler(&PreHandlerOptions{
Logger: s.logger,
Executor: executor,
Metrics: metrics,
OperationProcessor: operationProcessor,
Planner: operationPlanner,
AccessController: s.accessController,
OperationBlocker: operationBlocker,
RouterPublicKey: s.publicKey,
EnableRequestTracing: s.engineExecutionConfiguration.EnableRequestTracing,
DevelopmentMode: s.developmentMode,
TracerProvider: s.tracerProvider,
FlushTelemetryAfterResponse: s.awsLambda,
TraceExportVariables: s.traceConfig.ExportGraphQLVariables.Enabled,
FileUploadEnabled: s.fileUploadConfig.Enabled,
MaxUploadFiles: s.fileUploadConfig.MaxFiles,
MaxUploadFileSize: int(s.fileUploadConfig.MaxFileSizeBytes),
ComplexityLimits: s.securityConfiguration.ComplexityLimits,
AlwaysIncludeQueryPlan: s.engineExecutionConfiguration.Debug.AlwaysIncludeQueryPlan,
AlwaysSkipLoader: s.engineExecutionConfiguration.Debug.AlwaysSkipLoader,
QueryPlansEnabled: s.Config.queryPlansEnabled,
QueryPlansLoggingEnabled: s.engineExecutionConfiguration.Debug.PrintQueryPlans,
TrackSchemaUsageInfo: s.graphqlMetricsConfig.Enabled,
ClientHeader: s.clientHeader,
ComputeOperationSha256: computeSha256,
ApolloCompatibilityFlags: &s.apolloCompatibilityFlags,
DisableVariablesRemapping: s.engineExecutionConfiguration.DisableVariablesRemapping,
})
if s.webSocketConfiguration != nil && s.webSocketConfiguration.Enabled {
wsMiddleware := NewWebsocketMiddleware(ctx, WebsocketMiddlewareOptions{
OperationProcessor: operationProcessor,
OperationBlocker: operationBlocker,
Planner: operationPlanner,
GraphQLHandler: graphqlHandler,
PreHandler: graphqlPreHandler,
Metrics: metrics,
AccessController: s.accessController,
Logger: s.logger,
Stats: s.engineStats,
ReadTimeout: s.engineExecutionConfiguration.WebSocketClientReadTimeout,
EnableNetPoll: s.engineExecutionConfiguration.EnableNetPoll,
NetPollTimeout: s.engineExecutionConfiguration.WebSocketClientPollTimeout,
NetPollConnBufferSize: s.engineExecutionConfiguration.WebSocketClientConnBufferSize,
WebSocketConfiguration: s.webSocketConfiguration,
ClientHeader: s.clientHeader,
Attributes: baseOtelAttributes,
DisableVariablesRemapping: s.engineExecutionConfiguration.DisableVariablesRemapping,
ApolloCompatibilityFlags: s.apolloCompatibilityFlags,
})
// When the playground path is equal to the graphql path, we need to handle
// WebSocket upgrades and HTML requests on the same route.
if s.playgroundConfig.Enabled && s.graphqlPath == s.playgroundConfig.Path {
httpRouter.Use(s.playgroundHandler, wsMiddleware)
} else {
httpRouter.Use(wsMiddleware)
}
}
httpRouter.Use(
// Responsible for handling regular GraphQL requests over HTTP, not WebSockets
graphqlPreHandler.Handler,
// Must be mounted after the websocket middleware so that we only count non-hijacked requests; hijacked connections such as WebSockets are excluded
func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestContext := getRequestContext(r.Context())
// We don't want to count any type of subscriptions e.g. SSE as in-flight requests because they are long-lived
if requestContext != nil && requestContext.operation != nil && requestContext.operation.opType != OperationTypeSubscription {
s.inFlightRequests.Add(1)
// Counting like this is safe because, according to the Go http.Handler documentation,
// the request is guaranteed to be finished when ServeHTTP returns
defer s.inFlightRequests.Sub(1)
}
handler.ServeHTTP(w, r)
})
})
// Mount built global and custom modules
// Needs to be mounted after the pre-handler to ensure that the request was parsed and authorized
httpRouter.Use(s.routerMiddlewares...)
// GraphQL over POST
httpRouter.Post(s.graphqlPath, graphqlHandler.ServeHTTP)
// GraphQL over GET
httpRouter.Get(s.graphqlPath, graphqlHandler.ServeHTTP)
gm.mux = httpRouter
s.graphMuxListLock.Lock()
defer s.graphMuxListLock.Unlock()
s.graphMuxList = append(s.graphMuxList, gm)
return gm, nil
}
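// Recap of the middleware chain assembled above (descriptive, in mount order):
//
//	1. request context enrichment (subgraph resolver, logger, telemetry options)
//	2. recovery handler
//	3. router onRequest handlers (custom modules)
//	4. telemetry attribute enrichment (common + metric attributes)
//	5. tracing middleware (if tracing is enabled)
//	6. access logger (if configured)
//	7. websocket middleware (if websockets are enabled)
//	8. GraphQL pre-handler (parsing, normalization, planning, authorization)
//	9. in-flight request counting (non-subscription operations only)
//	10. custom router middlewares, then the GraphQL POST/GET handlers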
func (s *graphServer) buildPubSubConfiguration(ctx context.Context, engineConfig *nodev1.EngineConfiguration, routerEngineCfg *RouterEngineConfiguration) error {
datasourceConfigurations := engineConfig.GetDatasourceConfigurations()
for _, datasourceConfiguration := range datasourceConfigurations {
if datasourceConfiguration.CustomEvents == nil {
continue
}
for _, eventConfiguration := range datasourceConfiguration.GetCustomEvents().GetNats() {
providerID := eventConfiguration.EngineEventConfiguration.GetProviderId()
// if this source name's provider has already been initiated, do not try to initiate again
_, ok := s.pubSubProviders.nats[providerID]
if ok {
continue
}
for _, eventSource := range routerEngineCfg.Events.Providers.Nats {
if eventSource.ID == eventConfiguration.EngineEventConfiguration.GetProviderId() {
options, err := buildNatsOptions(eventSource, s.logger)
if err != nil {
return fmt.Errorf("failed to build options for Nats provider with ID \"%s\": %w", providerID, err)
}
natsConnection, err := nats.Connect(eventSource.URL, options...)
if err != nil {
return fmt.Errorf("failed to create connection for Nats provider with ID \"%s\": %w", providerID, err)
}
js, err := jetstream.New(natsConnection)
if err != nil {
return err
}
s.pubSubProviders.nats[providerID] = pubsubNats.NewConnector(s.logger, natsConnection, js, s.hostName, s.routerListenAddr).New(ctx)
break
}
}
_, ok = s.pubSubProviders.nats[providerID]
if !ok {
return fmt.Errorf("failed to find Nats provider with ID \"%s\". Ensure the provider definition is part of the config", providerID)
}
}
for _, eventConfiguration := range datasourceConfiguration.GetCustomEvents().GetKafka() {
providerID := eventConfiguration.EngineEventConfiguration.GetProviderId()
// if this source name's provider has already been initiated, do not try to initiate again
_, ok := s.pubSubProviders.kafka[providerID]
if ok {
continue
}
for _, eventSource := range routerEngineCfg.Events.Providers.Kafka {
if eventSource.ID == providerID {
options, err := buildKafkaOptions(eventSource)
if err != nil {
return fmt.Errorf("failed to build options for Kafka provider with ID \"%s\": %w", providerID, err)
}
ps, err := kafka.NewConnector(s.logger, options)
if err != nil {
return fmt.Errorf("failed to create connection for Kafka provider with ID \"%s\": %w", providerID, err)
}
s.pubSubProviders.kafka[providerID] = ps.New(ctx)
break
}
}
_, ok = s.pubSubProviders.kafka[providerID]
if !ok {
return fmt.Errorf("failed to find Kafka provider with ID \"%s\". Ensure the provider definition is part of the config", providerID)
}
}
}
return nil
}
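// Illustrative events configuration that the provider lookup above expects. The
// YAML shape follows the router configuration schema; the IDs, URL, and broker
// below are hypothetical:
//
//	events:
//	  providers:
//	    nats:
//	      - id: my-nats
//	        url: nats://localhost:4222
//	    kafka:
//	      - id: my-kafka
//	        brokers:
//	          - localhost:9092
//
// A provider referenced by a datasource's event configuration but missing from
// this list causes buildPubSubConfiguration to fail with a descriptive error.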
// wait waits for all in-flight requests to finish. Similar to http.Server.Shutdown, we poll
// in backoff intervals with jitter to make the shutdown process more efficient.
func (s *graphServer) wait(ctx context.Context) error {
b := backoff.New(500*time.Millisecond, time.Millisecond)
defer b.Reset()
timer := time.NewTimer(b.Duration())
defer timer.Stop()
for {
if s.inFlightRequests.Load() == 0 {
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
timer.Reset(b.Duration())
}
}
}
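// Illustrative usage (hypothetical caller): pass a context with a deadline so the
// wait is bounded; wait returns ctx.Err() if requests are still in flight when the
// context is done:
//
//	waitCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
//	defer cancel()
//	if err := s.wait(waitCtx); err != nil {
//		// some requests were still in flight at the deadline
//	}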
// Shutdown gracefully shuts down the server and waits for all in-flight requests to finish.
// After all requests are done, it shuts down the metric stores and runtime metrics.
// Shutdown cancels the context only after all non-hijacked requests have been handled;
// hijacked connections such as WebSockets are then cleaned up by the context cancellation.
func (s *graphServer) Shutdown(ctx context.Context) error {
// Cancel the context after the graceful shutdown is done
// to clean up resources like websocket connections, pools, etc.
defer s.cancelFunc()
s.logger.Debug("Shutdown of graph server initiated. Waiting for in-flight requests to finish.",
zap.String("config_version", s.baseRouterConfigVersion),
)
var finalErr error
// Wait for all in-flight requests to finish.
// In the worst case, we wait until the context is done or all requests have timed out.
if err := s.wait(ctx); err != nil {
s.logger.Error("Failed to wait for in-flight requests to finish", zap.Error(err))
finalErr = errors.Join(finalErr, err)
}
s.logger.Debug("Shutdown of graph server resources",
zap.String("grace_period", s.routerGracePeriod.String()),
zap.String("config_version", s.baseRouterConfigVersion),
)
// Ensure that we don't wait indefinitely for shutdown
if s.routerGracePeriod > 0 {
newCtx, cancel := context.WithTimeout(ctx, s.routerGracePeriod)
defer cancel()
ctx = newCtx
}
if s.runtimeMetrics != nil {
if err := s.runtimeMetrics.Shutdown(); err != nil {
s.logger.Error("Failed to shutdown runtime metrics", zap.Error(err))
finalErr = errors.Join(finalErr, err)
}
}
if s.otlpEngineMetrics != nil {
if err := s.otlpEngineMetrics.Shutdown(); err != nil {
s.logger.Error("Failed to shutdown OTLP engine metrics", zap.Error(err))
finalErr = errors.Join(finalErr, err)
}
}
if s.prometheusEngineMetrics != nil {
if err := s.prometheusEngineMetrics.Shutdown(); err != nil {
s.logger.Error("Failed to shutdown Prometheus engine metrics", zap.Error(err))
finalErr = errors.Join(finalErr, err)
}
}
if s.pubSubProviders != nil {
s.logger.Debug("Shutting down pubsub providers")
for _, pubSub := range s.pubSubProviders.nats {
if p, ok := pubSub.(pubsub.Lifecycle); ok {
if err := p.Shutdown(ctx); err != nil {
s.logger.Error("Failed to shutdown Nats pubsub provider", zap.Error(err))
finalErr = errors.Join(finalErr, err)
}
}
}
for _, pubSub := range s.pubSubProviders.kafka {
if p, ok := pubSub.(pubsub.Lifecycle); ok {
if err := p.Shutdown(ctx); err != nil {
s.logger.Error("Failed to shutdown Kafka pubsub provider", zap.Error(err))
finalErr = errors.Join(finalErr, err)
}
}
}
}
// Shut down all graph muxes to release resources,
// e.g. the planner cache
s.graphMuxListLock.Lock()
defer s.graphMuxListLock.Unlock()
for _, mux := range s.graphMuxList {
if err := mux.Shutdown(ctx); err != nil {
s.logger.Error("Failed to shutdown graph mux", zap.Error(err))
finalErr = errors.Join(finalErr, err)
}
}
return finalErr
}
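// Illustrative caller-side usage (hypothetical): bound the graceful shutdown with
// a deadline and log the aggregated error, if any:
//
//	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
//	defer cancel()
//	if err := graphServer.Shutdown(ctx); err != nil {
//		logger.Error("graph server shutdown finished with errors", zap.Error(err))
//	}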
func configureSubgraphOverwrites(
engineConfig *nodev1.EngineConfiguration,
configSubgraphs []*nodev1.Subgraph,
overrideRoutingURLConfig config.OverrideRoutingURLConfiguration,
overrides config.OverridesConfiguration,
) ([]Subgraph, error) {
var err error
subgraphs := make([]Subgraph, 0, len(configSubgraphs))
for _, sg := range configSubgraphs {
subgraph := Subgraph{
Id: sg.Id,
Name: sg.Name,
}
// Validate subgraph url. Note that it can be empty if the subgraph is virtual
subgraph.Url, err = url.Parse(sg.RoutingUrl)
if err != nil {
return nil, fmt.Errorf("failed to parse subgraph url '%s': %w", sg.RoutingUrl, err)
}
subgraph.UrlString = subgraph.Url.String()
overrideURL, ok := overrideRoutingURLConfig.Subgraphs[sg.Name]
overrideSubgraph, overrideSubgraphOk := overrides.Subgraphs[sg.Name]
var overrideSubscriptionURL string
var overrideSubscriptionProtocol *common.GraphQLSubscriptionProtocol
var overrideSubscriptionWebsocketSubprotocol *common.GraphQLWebsocketSubprotocol
if overrideSubgraphOk {
if overrideSubgraph.RoutingURL != "" {
overrideURL = overrideSubgraph.RoutingURL
}
if overrideSubgraph.SubscriptionURL != "" {
overrideSubscriptionURL = overrideSubgraph.SubscriptionURL
_, err := url.Parse(overrideSubscriptionURL)
if err != nil {
return nil, fmt.Errorf("failed to parse override url '%s': %w", overrideSubscriptionURL, err)
}
}
if overrideSubgraph.SubscriptionProtocol != "" {
switch overrideSubgraph.SubscriptionProtocol {
case "ws":
overrideSubscriptionProtocol = common.GraphQLSubscriptionProtocol_GRAPHQL_SUBSCRIPTION_PROTOCOL_WS.Enum()
case "sse":
overrideSubscriptionProtocol = common.GraphQLSubscriptionProtocol_GRAPHQL_SUBSCRIPTION_PROTOCOL_SSE.Enum()
case "sse_post":
overrideSubscriptionProtocol = common.GraphQLSubscriptionProtocol_GRAPHQL_SUBSCRIPTION_PROTOCOL_SSE_POST.Enum()
default:
return nil, fmt.Errorf("invalid subscription protocol '%s'", overrideSubgraph.SubscriptionProtocol)
}
}
if overrideSubgraph.SubscriptionWebsocketSubprotocol != "" {
switch overrideSubgraph.SubscriptionWebsocketSubprotocol {
case "graphql-ws":
overrideSubscriptionWebsocketSubprotocol = common.GraphQLWebsocketSubprotocol_GRAPHQL_WEBSOCKET_SUBPROTOCOL_WS.Enum()
case "graphql-transport-ws":
overrideSubscriptionWebsocketSubprotocol = common.GraphQLWebsocketSubprotocol_GRAPHQL_WEBSOCKET_SUBPROTOCOL_TRANSPORT_WS.Enum()
case "auto":
overrideSubscriptionWebsocketSubprotocol = common.GraphQLWebsocketSubprotocol_GRAPHQL_WEBSOCKET_SUBPROTOCOL_AUTO.Enum()
default:
return nil, fmt.Errorf("invalid subscription websocket subprotocol '%s'", overrideSubgraph.SubscriptionWebsocketSubprotocol)
}
}
}
// check if the subgraph is overridden
if ok || overrideSubgraphOk {
if overrideURL != "" {
subgraph.Url, err = url.Parse(overrideURL)
if err != nil {
return nil, fmt.Errorf("failed to parse override url '%s': %w", overrideURL, err)
}
subgraph.UrlString = subgraph.Url.String()
}
// Override datasource urls
for _, conf := range engineConfig.DatasourceConfigurations {
if conf.Id == sg.Id {
if overrideURL != "" {
conf.CustomGraphql.Fetch.Url.StaticVariableContent = overrideURL
sg.RoutingUrl = overrideURL
}
if overrideSubscriptionURL != "" {
conf.CustomGraphql.Subscription.Url.StaticVariableContent = overrideSubscriptionURL
}
if overrideSubscriptionProtocol != nil {
conf.CustomGraphql.Subscription.Protocol = overrideSubscriptionProtocol
}
if overrideSubscriptionWebsocketSubprotocol != nil {
conf.CustomGraphql.Subscription.WebsocketSubprotocol = overrideSubscriptionWebsocketSubprotocol
}
break
}
}
}
subgraphs = append(subgraphs, subgraph)
}
return subgraphs, nil
}
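// Illustrative overrides configuration consumed above. The YAML shape follows the
// router configuration schema; the subgraph name and URLs are hypothetical:
//
//	overrides:
//	  subgraphs:
//	    products:
//	      routing_url: http://localhost:4001/graphql
//	      subscription_url: http://localhost:4001/graphql/ws
//	      subscription_protocol: ws # ws | sse | sse_post
//	      subscription_websocket_subprotocol: graphql-transport-ws # graphql-ws | graphql-transport-ws | auto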