router/core/router.go (1,657 lines of code) (raw):
package core
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"sync"
"time"
rd "github.com/wundergraph/cosmo/router/internal/persistedoperation/operationstorage/redis"
"connectrpc.com/connect"
"github.com/mitchellh/mapstructure"
"github.com/nats-io/nuid"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/propagation"
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.uber.org/atomic"
"go.uber.org/zap"
"github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/graphqlmetrics/v1/graphqlmetricsv1connect"
nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1"
"github.com/wundergraph/cosmo/router/internal/debug"
"github.com/wundergraph/cosmo/router/internal/docker"
"github.com/wundergraph/cosmo/router/internal/graphiql"
"github.com/wundergraph/cosmo/router/internal/graphqlmetrics"
"github.com/wundergraph/cosmo/router/internal/persistedoperation"
"github.com/wundergraph/cosmo/router/internal/persistedoperation/apq"
"github.com/wundergraph/cosmo/router/internal/persistedoperation/operationstorage/cdn"
"github.com/wundergraph/cosmo/router/internal/persistedoperation/operationstorage/s3"
"github.com/wundergraph/cosmo/router/internal/retrytransport"
"github.com/wundergraph/cosmo/router/internal/stringsx"
"github.com/wundergraph/cosmo/router/pkg/config"
"github.com/wundergraph/cosmo/router/pkg/controlplane/configpoller"
"github.com/wundergraph/cosmo/router/pkg/controlplane/selfregister"
"github.com/wundergraph/cosmo/router/pkg/cors"
"github.com/wundergraph/cosmo/router/pkg/execution_config"
"github.com/wundergraph/cosmo/router/pkg/health"
rmetric "github.com/wundergraph/cosmo/router/pkg/metric"
"github.com/wundergraph/cosmo/router/pkg/otel/otelconfig"
"github.com/wundergraph/cosmo/router/pkg/statistics"
rtrace "github.com/wundergraph/cosmo/router/pkg/trace"
"github.com/wundergraph/cosmo/router/pkg/watcher"
"github.com/wundergraph/graphql-go-tools/v2/pkg/netpoll"
)
type IPAnonymizationMethod string
const (
Hash IPAnonymizationMethod = "hash"
Redact IPAnonymizationMethod = "redact"
)
var CompressibleContentTypes = []string{
"text/html",
"text/css",
"text/plain",
"text/javascript",
"application/javascript",
"application/x-javascript",
"application/json",
"application/atom+xml",
"application/rss+xml",
"image/svg+xml",
"application/graphql",
"application/graphql-response+json",
"application/graphql+json",
}
type (
// Router is the main application instance.
Router struct {
Config
httpServer *server
modules []Module
EngineStats statistics.EngineStatistics
playgroundHandler func(http.Handler) http.Handler
proxy ProxyFunc
}
TransportRequestOptions struct {
RequestTimeout time.Duration
ResponseHeaderTimeout time.Duration
ExpectContinueTimeout time.Duration
KeepAliveIdleTimeout time.Duration
DialTimeout time.Duration
TLSHandshakeTimeout time.Duration
KeepAliveProbeInterval time.Duration
MaxConnsPerHost int
MaxIdleConns int
MaxIdleConnsPerHost int
}
SubgraphTransportOptions struct {
*TransportRequestOptions
SubgraphMap map[string]*TransportRequestOptions
}
GraphQLMetricsConfig struct {
Enabled bool
CollectorEndpoint string
}
IPAnonymizationConfig struct {
Enabled bool
Method IPAnonymizationMethod
}
TlsClientAuthConfig struct {
Required bool
CertFile string
}
TlsConfig struct {
Enabled bool
CertFile string
KeyFile string
ClientAuth *TlsClientAuthConfig
}
RouterConfigPollerConfig struct {
config.ExecutionConfig
PollInterval time.Duration
PollJitter time.Duration
GraphSignKey string
}
ExecutionConfig struct {
Watch bool
WatchInterval time.Duration
Path string
}
AccessLogsConfig struct {
Attributes []config.CustomAttribute
Logger *zap.Logger
SubgraphEnabled bool
SubgraphAttributes []config.CustomAttribute
}
// Config defines the configuration options for the Router.
Config struct {
clusterName string
instanceID string
logger *zap.Logger
traceConfig *rtrace.Config
metricConfig *rmetric.Config
tracerProvider *sdktrace.TracerProvider
otlpMeterProvider *sdkmetric.MeterProvider
promMeterProvider *sdkmetric.MeterProvider
gqlMetricsExporter *graphqlmetrics.Exporter
corsOptions *cors.Config
setConfigVersionHeader bool
routerGracePeriod time.Duration
staticExecutionConfig *nodev1.RouterConfig
awsLambda bool
shutdown atomic.Bool
bootstrapped atomic.Bool
ipAnonymization *IPAnonymizationConfig
listenAddr string
baseURL string
graphqlWebURL string
playgroundPath string
graphqlPath string
playground bool
introspection bool
queryPlansEnabled bool
graphApiToken string
healthCheckPath string
readinessCheckPath string
livenessCheckPath string
playgroundConfig config.PlaygroundConfig
cacheControlPolicy config.CacheControlPolicy
routerConfigPollerConfig *RouterConfigPollerConfig
cdnConfig config.CDNConfiguration
persistedOperationClient persistedoperation.SaveClient
persistedOperationsConfig config.PersistedOperationsConfig
automaticPersistedQueriesConfig config.AutomaticPersistedQueriesConfig
apolloCompatibilityFlags config.ApolloCompatibilityFlags
apolloRouterCompatibilityFlags config.ApolloRouterCompatibilityFlags
storageProviders config.StorageProviders
eventsConfig config.EventsConfiguration
prometheusServer *http.Server
modulesConfig map[string]interface{}
executionConfig *ExecutionConfig
routerOnRequestHandlers []func(http.Handler) http.Handler
routerMiddlewares []func(http.Handler) http.Handler
preOriginHandlers []TransportPreHandler
postOriginHandlers []TransportPostHandler
headerRules *config.HeaderRules
subgraphTransportOptions *SubgraphTransportOptions
graphqlMetricsConfig *GraphQLMetricsConfig
routerTrafficConfig *config.RouterTrafficConfiguration
fileUploadConfig *config.FileUpload
accessController *AccessController
retryOptions retrytransport.RetryOptions
redisClient rd.RDCloser
processStartTime time.Time
developmentMode bool
healthcheck health.Checker
accessLogsConfig *AccessLogsConfig
// If connecting to localhost inside Docker fails, fallback to the docker internal address for the host
localhostFallbackInsideDocker bool
tlsServerConfig *tls.Config
tlsConfig *TlsConfig
telemetryAttributes []config.CustomAttribute
tracePropagators []propagation.TextMapPropagator
compositePropagator propagation.TextMapPropagator
// Poller
configPoller configpoller.ConfigPoller
selfRegister selfregister.SelfRegister
registrationInfo *nodev1.RegistrationInfo
securityConfiguration config.SecurityConfiguration
customModules []Module
engineExecutionConfiguration config.EngineExecutionConfiguration
// should be removed once the users have migrated to the new overrides config
overrideRoutingURLConfiguration config.OverrideRoutingURLConfiguration
// the new overrides config
overrides config.OverridesConfiguration
authorization *config.AuthorizationConfiguration
rateLimit *config.RateLimitConfiguration
webSocketConfiguration *config.WebSocketConfiguration
subgraphErrorPropagation config.SubgraphErrorPropagationConfiguration
clientHeader config.ClientHeader
cacheWarmup *config.CacheWarmupConfiguration
multipartHeartbeatInterval time.Duration
hostName string
}
// Option defines the method to customize server.
Option func(svr *Router)
)
// NewRouter creates a new Router instance. Router.Start() must be called to start the server.
// Alternatively, use Router.NewServer() to create a new server instance without starting it.
func NewRouter(opts ...Option) (*Router, error) {
r := &Router{
EngineStats: statistics.NewNoopEngineStats(),
}
for _, opt := range opts {
opt(r)
}
if r.logger == nil {
r.logger = zap.NewNop()
}
// Default value for graphql path
if r.graphqlPath == "" {
r.graphqlPath = "/graphql"
}
if r.graphqlWebURL == "" {
r.graphqlWebURL = r.graphqlPath
}
// this is set via the deprecated method
if !r.playground {
r.playgroundConfig.Enabled = r.playground
r.logger.Warn("The playground_enabled option is deprecated. Use the playground.enabled option in the config instead.")
}
if r.playgroundPath != "" && r.playgroundPath != "/" {
r.playgroundConfig.Path = r.playgroundPath
r.logger.Warn("The playground_path option is deprecated. Use the playground.path option in the config instead.")
}
if r.playgroundConfig.Path == "" {
r.playgroundConfig.Path = "/"
}
if r.instanceID == "" {
r.instanceID = nuid.Next()
}
r.processStartTime = time.Now()
// Create noop tracer and meter to avoid nil pointer panics and to avoid checking for nil everywhere
r.tracerProvider = sdktrace.NewTracerProvider(sdktrace.WithSampler(sdktrace.NeverSample()))
r.otlpMeterProvider = sdkmetric.NewMeterProvider()
r.promMeterProvider = sdkmetric.NewMeterProvider()
// Default values for trace and metric config
if r.traceConfig == nil {
r.traceConfig = rtrace.DefaultConfig(Version)
}
if r.metricConfig == nil {
r.metricConfig = rmetric.DefaultConfig(Version)
}
if r.corsOptions == nil {
r.corsOptions = CorsDefaultOptions()
}
if r.subgraphTransportOptions == nil {
r.subgraphTransportOptions = DefaultSubgraphTransportOptions()
}
if r.graphqlMetricsConfig == nil {
r.graphqlMetricsConfig = DefaultGraphQLMetricsConfig()
}
if r.routerTrafficConfig == nil {
r.routerTrafficConfig = DefaultRouterTrafficConfig()
}
if r.fileUploadConfig == nil {
r.fileUploadConfig = DefaultFileUploadConfig()
}
if r.accessController != nil {
if len(r.accessController.authenticators) == 0 && r.accessController.authenticationRequired {
r.logger.Warn("authentication is required but no authenticators are configured")
}
}
if r.ipAnonymization == nil {
r.ipAnonymization = &IPAnonymizationConfig{
Enabled: true,
Method: Redact,
}
}
// Default values for health check paths
if r.healthCheckPath == "" {
r.healthCheckPath = "/health"
}
if r.readinessCheckPath == "" {
r.readinessCheckPath = "/health/ready"
}
if r.livenessCheckPath == "" {
r.livenessCheckPath = "/health/live"
}
r.headerRules = AddCacheControlPolicyToRules(r.headerRules, r.cacheControlPolicy)
hr, err := NewHeaderPropagation(r.headerRules)
if err != nil {
return nil, err
}
if hr.HasRequestRules() {
r.preOriginHandlers = append(r.preOriginHandlers, hr.OnOriginRequest)
}
if hr.HasResponseRules() {
r.postOriginHandlers = append(r.postOriginHandlers, hr.OnOriginResponse)
}
defaultHeaders := []string{
// Common headers
"authorization",
"origin",
"content-length",
"content-type",
// Semi standard client info headers
"graphql-client-name",
"graphql-client-version",
// Apollo client info headers
"apollographql-client-name",
"apollographql-client-version",
// Required for WunderGraph ART
"x-wg-trace",
"x-wg-disable-tracing",
"x-wg-token",
"x-wg-skip-loader",
"x-wg-include-query-plan",
// Required for Trace Context propagation
"traceparent",
"tracestate",
// Required for feature flags
"x-feature-flag",
}
if r.clientHeader.Name != "" {
defaultHeaders = append(defaultHeaders, r.clientHeader.Name)
}
if r.clientHeader.Version != "" {
defaultHeaders = append(defaultHeaders, r.clientHeader.Version)
}
defaultMethods := []string{
"HEAD", "GET", "POST",
}
r.corsOptions.AllowHeaders = stringsx.RemoveDuplicates(append(r.corsOptions.AllowHeaders, defaultHeaders...))
r.corsOptions.AllowMethods = stringsx.RemoveDuplicates(append(r.corsOptions.AllowMethods, defaultMethods...))
if r.tlsConfig != nil && r.tlsConfig.Enabled {
r.baseURL = fmt.Sprintf("https://%s", r.listenAddr)
} else {
r.baseURL = fmt.Sprintf("http://%s", r.listenAddr)
}
if r.tlsConfig != nil && r.tlsConfig.Enabled {
if r.tlsConfig.CertFile == "" {
return nil, errors.New("tls cert file not provided")
}
if r.tlsConfig.KeyFile == "" {
return nil, errors.New("tls key file not provided")
}
var caCertPool *x509.CertPool
clientAuthMode := tls.NoClientCert
if r.tlsConfig.ClientAuth != nil && r.tlsConfig.ClientAuth.CertFile != "" {
caCert, err := os.ReadFile(r.tlsConfig.ClientAuth.CertFile)
if err != nil {
return nil, fmt.Errorf("failed to read cert file: %w", err)
}
// Create a CA an empty cert pool and add the CA cert to it to serve as authority to validate client certs
caPool := x509.NewCertPool()
if ok := caPool.AppendCertsFromPEM(caCert); !ok {
return nil, errors.New("failed to append cert to pool")
}
caCertPool = caPool
if r.tlsConfig.ClientAuth.Required {
clientAuthMode = tls.RequireAndVerifyClientCert
} else {
clientAuthMode = tls.VerifyClientCertIfGiven
}
r.logger.Debug("Client auth enabled", zap.String("mode", clientAuthMode.String()))
}
// Load the server cert and private key
cer, err := tls.LoadX509KeyPair(r.tlsConfig.CertFile, r.tlsConfig.KeyFile)
if err != nil {
return nil, fmt.Errorf("failed to load tls cert and key: %w", err)
}
r.tlsServerConfig = &tls.Config{
ClientCAs: caCertPool,
Certificates: []tls.Certificate{cer},
ClientAuth: clientAuthMode,
}
}
if r.traceConfig.Enabled {
if len(r.traceConfig.Propagators) > 0 {
propagators, err := rtrace.BuildPropagators(r.traceConfig.Propagators...)
if err != nil {
r.logger.Error("creating propagators", zap.Error(err))
return nil, err
}
r.tracePropagators = propagators
}
// Add default tracing exporter if needed
if len(r.traceConfig.Exporters) == 0 && r.traceConfig.TestMemoryExporter == nil {
if endpoint := otelconfig.DefaultEndpoint(); endpoint != "" {
r.logger.Debug("Using default trace exporter", zap.String("endpoint", endpoint))
r.traceConfig.Exporters = append(r.traceConfig.Exporters, &rtrace.ExporterConfig{
Endpoint: endpoint,
Exporter: otelconfig.ExporterOLTPHTTP,
HTTPPath: "/v1/traces",
Headers: otelconfig.DefaultEndpointHeaders(r.graphApiToken),
})
}
}
}
// Add default metric exporter if none are configured
if r.metricConfig.OpenTelemetry.Enabled && len(r.metricConfig.OpenTelemetry.Exporters) == 0 && r.metricConfig.OpenTelemetry.TestReader == nil {
if endpoint := otelconfig.DefaultEndpoint(); endpoint != "" {
r.logger.Debug("Using default metrics exporter", zap.String("endpoint", endpoint))
r.metricConfig.OpenTelemetry.Exporters = append(r.metricConfig.OpenTelemetry.Exporters, &rmetric.OpenTelemetryExporter{
Endpoint: endpoint,
Exporter: otelconfig.ExporterOLTPHTTP,
HTTPPath: "/v1/metrics",
Headers: otelconfig.DefaultEndpointHeaders(r.graphApiToken),
})
}
}
var disabledFeatures []string
// The user might want to start the server with a static config
// Disable all features that requires a valid graph token and inform the user
if r.graphApiToken == "" {
r.graphqlMetricsConfig.Enabled = false
disabledFeatures = append(disabledFeatures, "Schema Usage Tracking", "Persistent operations")
if !r.developmentMode {
disabledFeatures = append(disabledFeatures, "Advanced Request Tracing")
}
if r.traceConfig.Enabled {
defaultExporter := rtrace.DefaultExporter(r.traceConfig)
if defaultExporter != nil {
disabledFeatures = append(disabledFeatures, "Cosmo Cloud Tracing")
defaultExporter.Disabled = true
}
}
if r.metricConfig.OpenTelemetry.Enabled {
defaultExporter := rmetric.GetDefaultExporter(r.metricConfig)
if defaultExporter != nil {
disabledFeatures = append(disabledFeatures, "Cosmo Cloud Metrics")
defaultExporter.Disabled = true
}
}
r.logger.Warn("No graph token provided. The following Cosmo Cloud features are disabled. Not recommended for Production.",
zap.Strings("features", disabledFeatures),
)
}
if r.persistedOperationsConfig.Safelist.Enabled && r.automaticPersistedQueriesConfig.Enabled {
return nil, errors.New("automatic persisted queries and safelist cannot be enabled at the same time (as APQ would permit queries that are not in the safelist)")
}
if r.securityConfiguration.DepthLimit != nil {
r.logger.Warn("The security configuration field 'depth_limit' is deprecated, and will be removed. Use 'security.complexity_limits.depth' instead.")
if r.securityConfiguration.ComplexityCalculationCache == nil {
r.securityConfiguration.ComplexityCalculationCache = &config.ComplexityCalculationCache{
Enabled: true,
CacheSize: r.securityConfiguration.DepthLimit.CacheSize,
}
}
if r.securityConfiguration.ComplexityLimits == nil {
r.securityConfiguration.ComplexityLimits = &config.ComplexityLimits{}
}
if r.securityConfiguration.ComplexityLimits.Depth == nil {
r.securityConfiguration.ComplexityLimits.Depth = &config.ComplexityLimit{
Enabled: r.securityConfiguration.DepthLimit.Enabled,
Limit: r.securityConfiguration.DepthLimit.Limit,
IgnorePersistedOperations: r.securityConfiguration.DepthLimit.IgnorePersistedOperations,
}
} else {
r.logger.Warn("Ignoring deprecated security configuration field 'depth_limit', in favor of the `security_complexity_limits.depth` configuration")
}
}
if r.developmentMode {
r.logger.Warn("Development mode enabled. This should only be used for testing purposes")
}
if r.healthcheck == nil {
r.healthcheck = health.New(&health.Options{
Logger: r.logger,
})
}
for _, source := range r.eventsConfig.Providers.Nats {
r.logger.Info("Nats Event source enabled", zap.String("provider_id", source.ID))
}
for _, source := range r.eventsConfig.Providers.Kafka {
r.logger.Info("Kafka Event source enabled", zap.String("provider_id", source.ID), zap.Strings("brokers", source.Brokers))
}
if !r.engineExecutionConfiguration.EnableNetPoll {
r.logger.Warn("Net poller is disabled by configuration. Falling back to less efficient connection handling method.")
} else if err := netpoll.Supported(); err != nil {
// Disable netPoll if it's not supported. This flag is used everywhere to decide whether to use netPoll or not.
r.engineExecutionConfiguration.EnableNetPoll = false
if errors.Is(err, netpoll.ErrUnsupported) {
r.logger.Warn(
"Net poller is only available on Linux and MacOS. Falling back to less efficient connection handling method.",
zap.Error(err),
)
} else {
r.logger.Warn(
"Net poller is not functional by the environment. Ensure that the system supports epoll/kqueue and that necessary syscall permissions are granted. Falling back to less efficient connection handling method.",
zap.Error(err),
)
}
}
if r.hostName == "" {
r.hostName, err = os.Hostname()
if err != nil {
r.logger.Warn("Failed to get hostname", zap.Error(err))
}
}
return r, nil
}
// newGraphServer creates a new server.
func (r *Router) newServer(ctx context.Context, cfg *nodev1.RouterConfig) error {
server, err := newGraphServer(ctx, r, cfg, r.proxy)
if err != nil {
r.logger.Error("Failed to create graph server. Keeping the old server", zap.Error(err))
return err
}
r.httpServer.SwapGraphServer(ctx, server)
return nil
}
func (r *Router) listenAndServe() error {
go func() {
// Mark the server as not ready when the server is stopped
defer r.httpServer.healthcheck.SetReady(false)
// This is a blocking call
if err := r.httpServer.listenAndServe(); err != nil {
r.logger.Error("Failed to start new server", zap.Error(err))
}
}()
return nil
}
func (r *Router) initModules(ctx context.Context) error {
moduleList := make([]ModuleInfo, 0, len(modules)+len(r.customModules))
for _, module := range modules {
moduleList = append(moduleList, module)
}
for _, module := range r.customModules {
moduleList = append(moduleList, module.Module())
}
moduleList = sortModules(moduleList)
for _, moduleInfo := range moduleList {
now := time.Now()
moduleInstance := moduleInfo.New()
mc := &ModuleContext{
Context: ctx,
Module: moduleInstance,
Logger: r.logger.With(zap.String("module", string(moduleInfo.ID))),
}
moduleConfig, ok := r.modulesConfig[string(moduleInfo.ID)]
if ok {
if err := mapstructure.Decode(moduleConfig, &moduleInstance); err != nil {
return fmt.Errorf("failed to decode module config from module %s: %w", moduleInfo.ID, err)
}
} else {
r.logger.Debug("No config found for module", zap.String("id", string(moduleInfo.ID)))
}
if fn, ok := moduleInstance.(Provisioner); ok {
if err := fn.Provision(mc); err != nil {
return fmt.Errorf("failed to provision module '%s': %w", moduleInfo.ID, err)
}
}
if fn, ok := moduleInstance.(RouterMiddlewareHandler); ok {
r.routerMiddlewares = append(r.routerMiddlewares, func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
reqContext := getRequestContext(request.Context())
// Ensure we work with latest request in the chain to work with the right context
reqContext.request = request
fn.Middleware(reqContext, handler)
})
})
}
if fn, ok := moduleInstance.(RouterOnRequestHandler); ok {
r.routerOnRequestHandlers = append(r.routerOnRequestHandlers, func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(_ http.ResponseWriter, request *http.Request) {
reqContext := getRequestContext(request.Context())
// Ensure we work with latest request in the chain to work with the right context
reqContext.request = request
fn.RouterOnRequest(reqContext, handler)
})
})
}
if handler, ok := moduleInstance.(EnginePreOriginHandler); ok {
r.preOriginHandlers = append(r.preOriginHandlers, handler.OnOriginRequest)
}
if handler, ok := moduleInstance.(EnginePostOriginHandler); ok {
r.postOriginHandlers = append(r.postOriginHandlers, handler.OnOriginResponse)
}
if handler, ok := moduleInstance.(TracePropagationProvider); ok {
modulePropagators := handler.TracePropagators()
if len(modulePropagators) > 0 {
r.tracePropagators = append(r.tracePropagators, modulePropagators...)
}
}
r.modules = append(r.modules, moduleInstance)
r.logger.Info("Module registered",
zap.String("id", string(moduleInfo.ID)),
zap.String("duration", time.Since(now).String()),
)
}
return nil
}
func (r *Router) BaseURL() string {
return r.baseURL
}
// NewServer prepares a new server instance but does not start it. The method should only be used when you want to bootstrap
// the server manually otherwise you can use Router.Start(). You're responsible for setting health checks status to ready with Server.HealthChecks().
// The server can be shutdown with Router.Shutdown(). Use core.WithExecutionConfig to pass the initial config otherwise the Router will
// try to fetch the config from the control plane. You can swap the router config by using Router.newGraphServer().
func (r *Router) NewServer(ctx context.Context) (Server, error) {
if r.shutdown.Load() {
return nil, fmt.Errorf("router is shutdown. Create a new instance with router.NewRouter()")
}
if err := r.bootstrap(ctx); err != nil {
return nil, fmt.Errorf("failed to bootstrap application: %w", err)
}
r.httpServer = newServer(&httpServerOptions{
addr: r.listenAddr,
logger: r.logger,
tlsConfig: r.tlsConfig,
tlsServerConfig: r.tlsServerConfig,
healthcheck: r.healthcheck,
baseURL: r.baseURL,
maxHeaderBytes: int(r.routerTrafficConfig.MaxHeaderBytes.Uint64()),
livenessCheckPath: r.livenessCheckPath,
readinessCheckPath: r.readinessCheckPath,
healthCheckPath: r.healthCheckPath,
})
// Start the server with the static config without polling
if r.staticExecutionConfig != nil {
r.logger.Info("Static execution config provided. Polling is disabled. Updating execution config is only possible by providing a config.")
return r.httpServer, r.newServer(ctx, r.staticExecutionConfig)
}
// when no static config is provided and no poller is configured, we can't start the server
if r.configPoller == nil {
return nil, fmt.Errorf("config fetcher not provided. Please provide a static execution config instead")
}
cfg, err := r.configPoller.GetRouterConfig(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get initial execution config: %w", err)
}
if err := r.newServer(ctx, cfg.Config); err != nil {
r.logger.Error("Failed to start server with initial config", zap.Error(err))
return nil, err
}
return r.httpServer, nil
}
// bootstrap initializes the Router. It is called by Start() and NewServer().
// It should only be called once for a Router instance.
func (r *Router) bootstrap(ctx context.Context) error {
if !r.bootstrapped.CompareAndSwap(false, true) {
return fmt.Errorf("router is already bootstrapped")
}
cosmoCloudTracingEnabled := r.traceConfig.Enabled && rtrace.DefaultExporter(r.traceConfig) != nil
artInProductionEnabled := r.engineExecutionConfiguration.EnableRequestTracing && !r.developmentMode
needsRegistration := cosmoCloudTracingEnabled || artInProductionEnabled
if needsRegistration && r.selfRegister != nil {
r.logger.Info("Registering router with control plane because you opted in to send telemetry to Cosmo Cloud or advanced request tracing (ART) in production")
ri, registerErr := r.selfRegister.Register(ctx)
if registerErr != nil {
r.logger.Warn("Failed to register router on the control plane. If this warning persists, please contact support.")
} else {
r.registrationInfo = ri
// Only ensure sampling rate if the user exports traces to Cosmo Cloud
if cosmoCloudTracingEnabled {
if r.traceConfig.Sampler > float64(r.registrationInfo.AccountLimits.TraceSamplingRate) {
r.logger.Warn("Trace sampling rate is higher than account limit. Using account limit instead. Please contact support to increase your account limit.",
zap.Float64("limit", r.traceConfig.Sampler),
zap.String("account_limit", fmt.Sprintf("%.2f", r.registrationInfo.AccountLimits.TraceSamplingRate)),
)
r.traceConfig.Sampler = float64(r.registrationInfo.AccountLimits.TraceSamplingRate)
}
}
}
}
if r.traceConfig.Enabled {
tp, err := rtrace.NewTracerProvider(ctx, &rtrace.ProviderConfig{
Logger: r.logger,
Config: r.traceConfig,
ServiceInstanceID: r.instanceID,
IPAnonymization: &rtrace.IPAnonymizationConfig{
Enabled: r.ipAnonymization.Enabled,
Method: rtrace.IPAnonymizationMethod(r.ipAnonymization.Method),
},
MemoryExporter: r.traceConfig.TestMemoryExporter,
})
if err != nil {
return fmt.Errorf("failed to start trace agent: %w", err)
}
r.tracerProvider = tp
}
// Prometheus metrics rely on OTLP metrics
if r.metricConfig.IsEnabled() {
if r.metricConfig.Prometheus.Enabled {
mp, registry, err := rmetric.NewPrometheusMeterProvider(ctx, r.metricConfig, r.instanceID)
if err != nil {
return fmt.Errorf("failed to create Prometheus exporter: %w", err)
}
r.promMeterProvider = mp
r.prometheusServer = rmetric.NewPrometheusServer(r.logger, r.metricConfig.Prometheus.ListenAddr, r.metricConfig.Prometheus.Path, registry)
go func() {
if err := r.prometheusServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
r.logger.Error("Failed to start Prometheus server", zap.Error(err))
}
}()
}
if r.metricConfig.OpenTelemetry.Enabled {
mp, err := rmetric.NewOtlpMeterProvider(ctx, r.logger, r.metricConfig, r.instanceID)
if err != nil {
return fmt.Errorf("failed to start trace agent: %w", err)
}
r.otlpMeterProvider = mp
}
}
if r.graphqlMetricsConfig.Enabled {
client := graphqlmetricsv1connect.NewGraphQLMetricsServiceClient(
http.DefaultClient,
r.graphqlMetricsConfig.CollectorEndpoint,
connect.WithSendGzip(),
)
ge, err := graphqlmetrics.NewExporter(
r.logger,
client,
r.graphApiToken,
graphqlmetrics.NewDefaultExporterSettings(),
)
if err != nil {
return fmt.Errorf("failed to validate graphql metrics exporter: %w", err)
}
r.gqlMetricsExporter = ge
r.logger.Info("GraphQL schema coverage metrics enabled")
}
if r.Config.rateLimit != nil && r.Config.rateLimit.Enabled {
var err error
r.redisClient, err = rd.NewRedisCloser(&rd.RedisCloserOptions{
URLs: r.Config.rateLimit.Storage.URLs,
ClusterEnabled: r.Config.rateLimit.Storage.ClusterEnabled,
Logger: r.logger,
})
if err != nil {
return fmt.Errorf("failed to create redis client: %w", err)
}
}
if r.metricConfig.OpenTelemetry.EngineStats.Enabled() || r.metricConfig.Prometheus.EngineStats.Enabled() || r.engineExecutionConfiguration.Debug.ReportWebSocketConnections {
r.EngineStats = statistics.NewEngineStats(ctx, r.logger, r.engineExecutionConfiguration.Debug.ReportWebSocketConnections)
}
if r.engineExecutionConfiguration.Debug.ReportMemoryUsage {
debug.ReportMemoryUsage(ctx, r.logger)
}
if r.playgroundConfig.Enabled {
playgroundUrl, err := url.JoinPath(r.baseURL, r.playgroundConfig.Path)
if err != nil {
return fmt.Errorf("failed to join playground url: %w", err)
}
r.logger.Info("Serving GraphQL playground", zap.String("url", playgroundUrl))
r.playgroundHandler = graphiql.NewPlayground(&graphiql.PlaygroundOptions{
Html: graphiql.PlaygroundHTML(),
GraphqlURL: r.graphqlWebURL,
PlaygroundPath: r.playgroundPath,
ConcurrencyLimit: int64(r.playgroundConfig.ConcurrencyLimit),
})
}
if r.executionConfig != nil && r.executionConfig.Path != "" {
executionConfig, err := execution_config.FromFile(r.executionConfig.Path)
if err != nil {
return fmt.Errorf("failed to read execution config: %w", err)
}
r.staticExecutionConfig = executionConfig
}
if err := r.buildClients(); err != nil {
return err
}
// Modules are only initialized once and not on every config change
if err := r.initModules(ctx); err != nil {
return fmt.Errorf("failed to init user modules: %w", err)
}
if r.traceConfig.Enabled && len(r.tracePropagators) > 0 {
r.compositePropagator = propagation.NewCompositeTextMapPropagator(r.tracePropagators...)
// Don't set it globally when we use the router in tests.
// In practice, setting it globally only makes sense for module development.
if r.traceConfig.TestMemoryExporter == nil {
otel.SetTextMapPropagator(r.compositePropagator)
}
}
return nil
}
// buildClients initializes the storage clients for persisted operations and router config.
func (r *Router) buildClients() error {
s3Providers := map[string]config.S3StorageProvider{}
cdnProviders := map[string]config.BaseStorageProvider{}
redisProviders := map[string]config.RedisStorageProvider{}
for _, provider := range r.storageProviders.S3 {
if _, ok := s3Providers[provider.ID]; ok {
return fmt.Errorf("duplicate s3 storage provider with id '%s'", provider.ID)
}
s3Providers[provider.ID] = provider
}
for _, provider := range r.storageProviders.CDN {
if _, ok := cdnProviders[provider.ID]; ok {
return fmt.Errorf("duplicate cdn storage provider with id '%s'", provider.ID)
}
cdnProviders[provider.ID] = provider
}
for _, provider := range r.storageProviders.Redis {
if _, ok := redisProviders[provider.ID]; ok {
return fmt.Errorf("duplicate Redis storage provider with id '%s'", provider.ID)
}
redisProviders[provider.ID] = provider
}
var pClient persistedoperation.Client
if provider, ok := cdnProviders[r.persistedOperationsConfig.Storage.ProviderID]; ok {
if r.graphApiToken == "" {
return errors.New("graph token is required to fetch persisted operations from CDN")
}
c, err := cdn.NewClient(provider.URL, r.graphApiToken, cdn.Options{
Logger: r.logger,
})
if err != nil {
return err
}
pClient = c
r.logger.Info("Use CDN as storage provider for persisted operations",
zap.String("provider_id", provider.ID),
)
} else if provider, ok := s3Providers[r.persistedOperationsConfig.Storage.ProviderID]; ok {
c, err := s3.NewClient(provider.Endpoint, &s3.Options{
AccessKeyID: provider.AccessKey,
SecretAccessKey: provider.SecretKey,
Region: provider.Region,
UseSSL: provider.Secure,
BucketName: provider.Bucket,
ObjectPathPrefix: r.persistedOperationsConfig.Storage.ObjectPrefix,
TraceProvider: r.tracerProvider,
})
if err != nil {
return err
}
pClient = c
r.logger.Info("Use S3 as storage provider for persisted operations",
zap.String("provider_id", provider.ID),
)
} else if r.graphApiToken != "" {
if r.persistedOperationsConfig.Storage.ProviderID != "" {
return fmt.Errorf("unknown storage provider id '%s' for persisted operations", r.persistedOperationsConfig.Storage.ProviderID)
}
c, err := cdn.NewClient(r.cdnConfig.URL, r.graphApiToken, cdn.Options{
Logger: r.logger,
})
if err != nil {
return err
}
pClient = c
r.logger.Debug("Default to Cosmo CDN as persisted operations provider",
zap.String("url", r.cdnConfig.URL),
)
}
var kvClient apq.KVClient
if provider, ok := redisProviders[r.automaticPersistedQueriesConfig.Storage.ProviderID]; ok {
c, err := apq.NewRedisClient(&apq.RedisOptions{
Logger: r.logger,
StorageConfig: &provider,
Prefix: r.automaticPersistedQueriesConfig.Storage.ObjectPrefix,
})
if err != nil {
return err
}
kvClient = c
r.logger.Info("Use redis as storage provider for automatic persisted operations",
zap.String("provider_id", provider.ID),
)
}
var apqClient apq.Client
if r.automaticPersistedQueriesConfig.Enabled {
var err error
apqClient, err = apq.NewClient(&apq.Options{
Logger: r.logger,
ApqConfig: &r.automaticPersistedQueriesConfig,
KVClient: kvClient,
})
if err != nil {
return err
}
}
if pClient != nil || apqClient != nil {
// For backwards compatibility with cdn config field
cacheSize := r.persistedOperationsConfig.Cache.Size.Uint64()
if cacheSize <= 0 {
cacheSize = r.cdnConfig.CacheSize.Uint64()
}
c, err := persistedoperation.NewClient(&persistedoperation.Options{
CacheSize: cacheSize,
Logger: r.logger,
ProviderClient: pClient,
ApqClient: apqClient,
})
if err != nil {
return err
}
r.persistedOperationClient = c
}
configPoller, err := InitializeConfigPoller(r, cdnProviders, s3Providers)
if err != nil {
return err
}
if configPoller != nil {
r.configPoller = *configPoller
}
return nil
}
// Start starts the router. It does block until the router has been initialized. After that the server is listening
// on a separate goroutine. The server can be shutdown with Router.Shutdown(). Not safe for concurrent use.
// During initialization, the router will register itself with the control plane and poll the config from the CDN
// if the user opted in to connect to Cosmo Cloud.
func (r *Router) Start(ctx context.Context) error {
if r.shutdown.Load() {
return fmt.Errorf("router is shutdown. Create a new instance with router.NewRouter()")
}
if err := r.bootstrap(ctx); err != nil {
return fmt.Errorf("failed to bootstrap router: %w", err)
}
r.httpServer = newServer(&httpServerOptions{
addr: r.listenAddr,
logger: r.logger,
tlsConfig: r.tlsConfig,
tlsServerConfig: r.tlsServerConfig,
healthcheck: r.healthcheck,
baseURL: r.baseURL,
maxHeaderBytes: int(r.routerTrafficConfig.MaxHeaderBytes.Uint64()),
livenessCheckPath: r.livenessCheckPath,
readinessCheckPath: r.readinessCheckPath,
healthCheckPath: r.healthCheckPath,
})
// Start the server with the static config without polling
if r.staticExecutionConfig != nil {
if err := r.listenAndServe(); err != nil {
return err
}
if err := r.newServer(ctx, r.staticExecutionConfig); err != nil {
return err
}
defer func() {
r.httpServer.healthcheck.SetReady(true)
r.logger.Info("Server initialized and ready to serve requests",
zap.String("listen_addr", r.listenAddr),
zap.Bool("playground", r.playgroundConfig.Enabled),
zap.Bool("introspection", r.introspection),
zap.String("config_version", r.staticExecutionConfig.Version),
)
}()
if r.executionConfig != nil && r.executionConfig.Watch {
w, err := watcher.New(watcher.Options{
Logger: r.logger.With(zap.String("watcher_label", "execution_config")),
Path: r.executionConfig.Path,
Interval: r.executionConfig.WatchInterval,
Callback: func() {
if r.shutdown.Load() {
r.logger.Warn("Router is in shutdown state. Skipping config update")
return
}
data, err := os.ReadFile(r.executionConfig.Path)
if err != nil {
r.logger.Error("Failed to read config file", zap.Error(err))
return
}
r.logger.Info("Config file changed. Updating server with new config", zap.String("path", r.executionConfig.Path))
cfg, err := execution_config.UnmarshalConfig(data)
if err != nil {
r.logger.Error("Failed to unmarshal config file", zap.Error(err))
return
}
if err := r.newServer(ctx, cfg); err != nil {
r.logger.Error("Failed to update server with new config", zap.Error(err))
return
}
},
})
if err != nil {
return fmt.Errorf("failed to create watcher: %w", err)
}
go func() {
if err := w(ctx); err != nil {
r.logger.Error("Error watching execution config", zap.Error(err))
return
}
}()
r.logger.Info("Watching config file for changes. Router will hot-reload automatically without downtime",
zap.String("path", r.executionConfig.Path),
)
return nil
}
r.logger.Info("Static execution config provided. Polling and watching is disabled. Updating execution config is only possible by restarting the router")
return nil
}
// when no static config is provided and no poller is configured, we can't start the server
if r.configPoller == nil {
return fmt.Errorf("execution config fetcher not provided. Please provide a static execution config instead")
}
cfg, err := r.configPoller.GetRouterConfig(ctx)
if err != nil {
return fmt.Errorf("failed to get initial execution config: %w", err)
}
if err := r.listenAndServe(); err != nil {
r.logger.Error("Failed to start server with initial config", zap.Error(err))
return err
}
if err := r.newServer(ctx, cfg.Config); err != nil {
return err
}
if r.playgroundConfig.Enabled {
graphqlEndpointURL, err := url.JoinPath(r.baseURL, r.graphqlPath)
if err != nil {
return fmt.Errorf("failed to join graphql endpoint url: %w", err)
}
r.logger.Info("GraphQL endpoint",
zap.String("method", http.MethodPost),
zap.String("url", graphqlEndpointURL),
)
}
/**
* Server logging after features has been initialized / disabled
*/
if r.localhostFallbackInsideDocker && docker.Inside() {
r.logger.Info("localhost fallback enabled, connections that fail to connect to localhost will be retried using host.docker.internal")
}
if r.developmentMode && r.engineExecutionConfiguration.EnableRequestTracing && r.graphApiToken == "" {
r.logger.Warn("Advanced Request Tracing (ART) is enabled in development mode but requires a graph token to work in production. For more information see https://cosmo-docs.wundergraph.com/router/advanced-request-tracing-art")
}
if r.redisClient != nil {
r.logger.Info("Rate limiting enabled",
zap.Int("rate", r.rateLimit.SimpleStrategy.Rate),
zap.Int("burst", r.rateLimit.SimpleStrategy.Burst),
zap.Duration("duration", r.Config.rateLimit.SimpleStrategy.Period),
zap.Bool("rejectExceeding", r.Config.rateLimit.SimpleStrategy.RejectExceedingRequests),
)
}
r.configPoller.Subscribe(ctx, func(newConfig *nodev1.RouterConfig, oldVersion string) error {
if r.shutdown.Load() {
r.logger.Warn("Router is in shutdown state. Skipping config update")
return nil
}
if err := r.newServer(ctx, newConfig); err != nil {
return err
}
return nil
})
// Mark the server as ready
r.httpServer.healthcheck.SetReady(true)
r.logger.Info("Server initialized and ready to serve requests",
zap.String("listen_addr", r.listenAddr),
zap.Bool("playground", r.playgroundConfig.Enabled),
zap.Bool("introspection", r.introspection),
zap.String("config_version", cfg.Config.GetVersion()),
)
return nil
}
// Shutdown gracefully shuts down the router. It blocks until the server is shutdown.
// If the router is already shutdown, the method returns immediately without error.
func (r *Router) Shutdown(ctx context.Context) (err error) {
if !r.shutdown.CompareAndSwap(false, true) {
return nil
}
// Respect grace period
if r.routerGracePeriod > 0 {
ctxWithTimer, cancel := context.WithTimeout(ctx, r.routerGracePeriod)
defer cancel()
ctx = ctxWithTimer
}
if r.configPoller != nil {
if subErr := r.configPoller.Stop(ctx); subErr != nil {
err = errors.Join(err, fmt.Errorf("failed to stop config poller: %w", subErr))
}
}
if r.httpServer != nil {
if subErr := r.httpServer.Shutdown(ctx); subErr != nil {
if errors.Is(err, context.DeadlineExceeded) {
r.logger.Warn(
"Shutdown deadline exceeded. Router took too long to shutdown. Consider increasing the grace period",
zap.Duration("grace_period", r.routerGracePeriod),
)
}
err = errors.Join(err, fmt.Errorf("failed to shutdown router: %w", subErr))
}
}
var wg sync.WaitGroup
if r.prometheusServer != nil {
wg.Add(1)
go func() {
defer wg.Done()
if subErr := r.prometheusServer.Close(); subErr != nil {
err = errors.Join(err, fmt.Errorf("failed to shutdown prometheus server: %w", subErr))
}
}()
}
if r.tracerProvider != nil {
wg.Add(1)
go func() {
defer wg.Done()
if subErr := r.tracerProvider.Shutdown(ctx); subErr != nil {
err = errors.Join(err, fmt.Errorf("failed to shutdown tracer: %w", subErr))
}
}()
}
if r.gqlMetricsExporter != nil {
wg.Add(1)
go func() {
defer wg.Done()
if subErr := r.gqlMetricsExporter.Shutdown(ctx); subErr != nil {
err = errors.Join(err, fmt.Errorf("failed to shutdown graphql metrics exporter: %w", subErr))
}
}()
}
if r.promMeterProvider != nil {
wg.Add(1)
go func() {
defer wg.Done()
if subErr := r.promMeterProvider.Shutdown(ctx); subErr != nil {
err = errors.Join(err, fmt.Errorf("failed to shutdown prometheus meter provider: %w", subErr))
}
}()
}
if r.otlpMeterProvider != nil {
wg.Add(1)
go func() {
defer wg.Done()
if subErr := r.otlpMeterProvider.Shutdown(ctx); subErr != nil {
err = errors.Join(err, fmt.Errorf("failed to shutdown OTLP meter provider: %w", subErr))
}
}()
}
if r.redisClient != nil {
wg.Add(1)
go func() {
defer wg.Done()
if closeErr := r.redisClient.Close(); closeErr != nil {
err = errors.Join(err, fmt.Errorf("failed to close redis client: %w", closeErr))
}
}()
}
wg.Add(1)
go func() {
defer wg.Done()
for _, module := range r.modules {
if cleaner, ok := module.(Cleaner); ok {
if subErr := cleaner.Cleanup(); subErr != nil {
err = errors.Join(err, fmt.Errorf("failed to clean module %s: %w", module.Module().ID, subErr))
}
}
}
}()
// Shutdown the CDN operation client and free up resources
if r.persistedOperationClient != nil {
r.persistedOperationClient.Close()
}
wg.Wait()
return err
}
func WithListenerAddr(addr string) Option {
return func(r *Router) {
r.listenAddr = addr
}
}
func WithLogger(logger *zap.Logger) Option {
return func(r *Router) {
r.logger = logger
}
}
func WithPlayground(enable bool) Option {
return func(r *Router) {
r.playground = enable
}
}
func WithIntrospection(enable bool) Option {
return func(r *Router) {
r.introspection = enable
}
}
func WithQueryPlans(enabled bool) Option {
return func(r *Router) {
r.queryPlansEnabled = enabled
}
}
func WithTracing(cfg *rtrace.Config) Option {
return func(r *Router) {
r.traceConfig = cfg
}
}
func WithCors(corsOpts *cors.Config) Option {
return func(r *Router) {
r.corsOptions = corsOpts
}
}
// WithMultipartHeartbeatInterval sets the interval for the engine to send heartbeats for multipart subscriptions.
func WithMultipartHeartbeatInterval(interval time.Duration) Option {
return func(r *Router) {
r.multipartHeartbeatInterval = interval
}
}
// WithGraphQLPath sets the path where the GraphQL endpoint is served.
func WithGraphQLPath(p string) Option {
return func(r *Router) {
r.graphqlPath = p
}
}
// WithGraphQLWebURL sets the URL to the GraphQL endpoint used by the GraphQL Playground.
// This is useful when the path differs from the actual GraphQL endpoint e.g. when the router is behind a reverse proxy.
// If not set, the GraphQL Playground uses the same URL as the GraphQL endpoint.
func WithGraphQLWebURL(p string) Option {
return func(r *Router) {
r.graphqlWebURL = p
}
}
// WithPlaygroundPath sets the path where the GraphQL Playground is served.
func WithPlaygroundPath(p string) Option {
return func(r *Router) {
r.playgroundPath = p
}
}
// WithPlaygroundPath sets the path where the GraphQL Playground is served.
func WithPlaygroundConfig(c config.PlaygroundConfig) Option {
return func(r *Router) {
r.playgroundConfig = c
}
}
// WithConfigPoller sets the poller client to fetch the router config. If not set, WithConfigPollerConfig should be set.
func WithConfigPoller(cf configpoller.ConfigPoller) Option {
return func(r *Router) {
r.configPoller = cf
}
}
// WithSelfRegistration sets the self registration client to register the router with the control plane.
func WithSelfRegistration(sr selfregister.SelfRegister) Option {
return func(r *Router) {
r.selfRegister = sr
}
}
// WithGracePeriod sets the grace period for the router to shutdown.
func WithGracePeriod(timeout time.Duration) Option {
return func(r *Router) {
r.routerGracePeriod = timeout
}
}
// WithMetrics sets the metrics configuration for the router.
func WithMetrics(cfg *rmetric.Config) Option {
return func(r *Router) {
r.metricConfig = cfg
}
}
// CorsDefaultOptions returns the default CORS options for the rs/cors package.
func CorsDefaultOptions() *cors.Config {
return &cors.Config{
Enabled: true,
AllowOrigins: []string{"*"},
AllowMethods: []string{
http.MethodHead,
http.MethodGet,
http.MethodPost,
},
AllowHeaders: []string{},
AllowCredentials: false,
}
}
func WithGraphApiToken(token string) Option {
return func(r *Router) {
r.graphApiToken = token
}
}
func WithModulesConfig(config map[string]interface{}) Option {
return func(r *Router) {
r.modulesConfig = config
}
}
func WithExecutionConfig(cfg *ExecutionConfig) Option {
return func(r *Router) {
r.executionConfig = cfg
}
}
// WithStaticExecutionConfig sets the static execution config. This disables polling and file watching.
func WithStaticExecutionConfig(cfg *nodev1.RouterConfig) Option {
return func(r *Router) {
r.staticExecutionConfig = cfg
}
}
// WithAwsLambdaRuntime enables the AWS Lambda behaviour.
// This flushes all telemetry data synchronously after the request is handled.
func WithAwsLambdaRuntime() Option {
return func(r *Router) {
r.awsLambda = true
}
}
func WithHealthCheckPath(path string) Option {
return func(r *Router) {
r.healthCheckPath = path
}
}
func WithHealthChecks(healthChecks health.Checker) Option {
return func(r *Router) {
r.healthcheck = healthChecks
}
}
func WithProxy(proxy ProxyFunc) Option {
return func(r *Router) {
r.proxy = proxy
}
}
func WithReadinessCheckPath(path string) Option {
return func(r *Router) {
r.readinessCheckPath = path
}
}
func WithLivenessCheckPath(path string) Option {
return func(r *Router) {
r.livenessCheckPath = path
}
}
// WithCDN sets the configuration for the CDN client
func WithCDN(cfg config.CDNConfiguration) Option {
return func(r *Router) {
r.cdnConfig = cfg
}
}
// WithEvents sets the configuration for the events client
func WithEvents(cfg config.EventsConfiguration) Option {
return func(r *Router) {
r.eventsConfig = cfg
}
}
func WithHeaderRules(headers config.HeaderRules) Option {
return func(r *Router) {
r.headerRules = &headers
}
}
func WithCacheControlPolicy(cfg config.CacheControlPolicy) Option {
return func(r *Router) {
r.cacheControlPolicy = cfg
}
}
func WithOverrideRoutingURL(overrideRoutingURL config.OverrideRoutingURLConfiguration) Option {
return func(r *Router) {
r.overrideRoutingURLConfiguration = overrideRoutingURL
}
}
func WithOverrides(overrides config.OverridesConfiguration) Option {
return func(r *Router) {
r.overrides = overrides
}
}
func WithSecurityConfig(cfg config.SecurityConfiguration) Option {
return func(r *Router) {
r.securityConfiguration = cfg
}
}
func WithEngineExecutionConfig(cfg config.EngineExecutionConfiguration) Option {
return func(r *Router) {
r.engineExecutionConfiguration = cfg
}
}
func WithCustomModules(modules ...Module) Option {
return func(r *Router) {
r.customModules = modules
}
}
func WithSubgraphTransportOptions(opts *SubgraphTransportOptions) Option {
return func(r *Router) {
r.subgraphTransportOptions = opts
}
}
func WithSubgraphRetryOptions(enabled bool, maxRetryCount int, retryMaxDuration, retryInterval time.Duration) Option {
return func(r *Router) {
r.retryOptions = retrytransport.RetryOptions{
Enabled: enabled,
MaxRetryCount: maxRetryCount,
MaxDuration: retryMaxDuration,
Interval: retryInterval,
}
}
}
func WithRouterTrafficConfig(cfg *config.RouterTrafficConfiguration) Option {
return func(r *Router) {
r.routerTrafficConfig = cfg
}
}
func WithFileUploadConfig(cfg *config.FileUpload) Option {
return func(r *Router) {
r.fileUploadConfig = cfg
}
}
func WithAccessController(controller *AccessController) Option {
return func(r *Router) {
r.accessController = controller
}
}
func WithAuthorizationConfig(cfg *config.AuthorizationConfiguration) Option {
return func(r *Router) {
r.Config.authorization = cfg
}
}
func WithRateLimitConfig(cfg *config.RateLimitConfiguration) Option {
return func(r *Router) {
r.Config.rateLimit = cfg
}
}
func WithLocalhostFallbackInsideDocker(fallback bool) Option {
return func(r *Router) {
r.localhostFallbackInsideDocker = fallback
}
}
func DefaultRouterTrafficConfig() *config.RouterTrafficConfiguration {
return &config.RouterTrafficConfiguration{
MaxRequestBodyBytes: 1000 * 1000 * 5, // 5 MB
}
}
func DefaultFileUploadConfig() *config.FileUpload {
return &config.FileUpload{
Enabled: true,
MaxFileSizeBytes: 1000 * 1000 * 50, // 50 MB,
MaxFiles: 10,
}
}
func NewTransportRequestOptions(cfg config.GlobalSubgraphRequestRule) *TransportRequestOptions {
defaults := DefaultTransportRequestOptions()
return &TransportRequestOptions{
RequestTimeout: or(cfg.RequestTimeout, defaults.RequestTimeout),
TLSHandshakeTimeout: or(cfg.TLSHandshakeTimeout, defaults.TLSHandshakeTimeout),
ResponseHeaderTimeout: or(cfg.ResponseHeaderTimeout, defaults.ResponseHeaderTimeout),
ExpectContinueTimeout: or(cfg.ExpectContinueTimeout, defaults.ExpectContinueTimeout),
KeepAliveProbeInterval: or(cfg.KeepAliveProbeInterval, defaults.KeepAliveProbeInterval),
KeepAliveIdleTimeout: or(cfg.KeepAliveIdleTimeout, defaults.KeepAliveIdleTimeout),
DialTimeout: or(cfg.DialTimeout, defaults.DialTimeout),
MaxConnsPerHost: or(cfg.MaxConnsPerHost, defaults.MaxConnsPerHost),
MaxIdleConns: or(cfg.MaxIdleConns, defaults.MaxIdleConns),
MaxIdleConnsPerHost: or(cfg.MaxIdleConnsPerHost, defaults.MaxIdleConnsPerHost),
}
}
func DefaultTransportRequestOptions() *TransportRequestOptions {
return &TransportRequestOptions{
RequestTimeout: 60 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: 0 * time.Second,
ExpectContinueTimeout: 0 * time.Second,
KeepAliveProbeInterval: 30 * time.Second,
KeepAliveIdleTimeout: 0 * time.Second,
DialTimeout: 30 * time.Second,
MaxConnsPerHost: 100,
MaxIdleConns: 1024,
MaxIdleConnsPerHost: 20,
}
}
func NewSubgraphTransportOptions(cfg config.TrafficShapingRules) *SubgraphTransportOptions {
base := &SubgraphTransportOptions{
TransportRequestOptions: NewTransportRequestOptions(cfg.All),
SubgraphMap: map[string]*TransportRequestOptions{},
}
for k, v := range cfg.Subgraphs {
base.SubgraphMap[k] = NewTransportRequestOptions(*v)
}
return base
}
func DefaultSubgraphTransportOptions() *SubgraphTransportOptions {
return &SubgraphTransportOptions{
TransportRequestOptions: DefaultTransportRequestOptions(),
SubgraphMap: map[string]*TransportRequestOptions{},
}
}
func DefaultGraphQLMetricsConfig() *GraphQLMetricsConfig {
return &GraphQLMetricsConfig{
Enabled: false,
CollectorEndpoint: "",
}
}
func WithGraphQLMetrics(cfg *GraphQLMetricsConfig) Option {
return func(r *Router) {
r.graphqlMetricsConfig = cfg
}
}
// WithDevelopmentMode enables development mode. This should only be used for testing purposes.
// Development mode allows e.g. to use ART (Advanced Request Tracing) without request signing.
func WithDevelopmentMode(enabled bool) Option {
return func(r *Router) {
r.developmentMode = enabled
}
}
func WithClusterName(name string) Option {
return func(r *Router) {
r.clusterName = name
}
}
func WithInstanceID(id string) Option {
return func(r *Router) {
r.instanceID = id
}
}
func WithConfigVersionHeader(include bool) Option {
return func(r *Router) {
r.setConfigVersionHeader = include
}
}
func WithAnonymization(ipConfig *IPAnonymizationConfig) Option {
return func(r *Router) {
r.ipAnonymization = ipConfig
}
}
func WithWebSocketConfiguration(cfg *config.WebSocketConfiguration) Option {
return func(r *Router) {
r.Config.webSocketConfiguration = cfg
}
}
func WithSubgraphErrorPropagation(cfg config.SubgraphErrorPropagationConfiguration) Option {
return func(r *Router) {
r.Config.subgraphErrorPropagation = cfg
}
}
func WithAccessLogs(cfg *AccessLogsConfig) Option {
return func(r *Router) {
r.accessLogsConfig = cfg
}
}
func WithTLSConfig(cfg *TlsConfig) Option {
return func(r *Router) {
r.tlsConfig = cfg
}
}
func WithTelemetryAttributes(attributes []config.CustomAttribute) Option {
return func(r *Router) {
r.telemetryAttributes = attributes
}
}
func WithConfigPollerConfig(cfg *RouterConfigPollerConfig) Option {
return func(r *Router) {
r.routerConfigPollerConfig = cfg
}
}
func WithPersistedOperationsConfig(cfg config.PersistedOperationsConfig) Option {
return func(r *Router) {
r.persistedOperationsConfig = cfg
}
}
func WithAutomatedPersistedQueriesConfig(cfg config.AutomaticPersistedQueriesConfig) Option {
return func(r *Router) {
r.automaticPersistedQueriesConfig = cfg
}
}
func WithApolloCompatibilityFlagsConfig(cfg config.ApolloCompatibilityFlags) Option {
return func(r *Router) {
if cfg.EnableAll {
cfg.ValueCompletion.Enabled = true
cfg.TruncateFloats.Enabled = true
cfg.SuppressFetchErrors.Enabled = true
cfg.ReplaceUndefinedOpFieldErrors.Enabled = true
cfg.ReplaceInvalidVarErrors.Enabled = true
cfg.ReplaceValidationErrorStatus.Enabled = true
cfg.SubscriptionMultipartPrintBoundary.Enabled = true
}
r.apolloCompatibilityFlags = cfg
}
}
func WithApolloRouterCompatibilityFlags(cfg config.ApolloRouterCompatibilityFlags) Option {
return func(r *Router) {
r.apolloRouterCompatibilityFlags = cfg
}
}
func WithStorageProviders(cfg config.StorageProviders) Option {
return func(r *Router) {
r.storageProviders = cfg
}
}
func WithClientHeader(cfg config.ClientHeader) Option {
return func(r *Router) {
r.clientHeader = cfg
}
}
func WithCacheWarmupConfig(cfg *config.CacheWarmupConfiguration) Option {
return func(r *Router) {
r.cacheWarmup = cfg
}
}
type ProxyFunc func(req *http.Request) (*url.URL, error)
func newHTTPTransport(opts *TransportRequestOptions, proxy ProxyFunc) *http.Transport {
dialer := &net.Dialer{
Timeout: opts.DialTimeout,
KeepAlive: opts.KeepAliveProbeInterval,
}
// Great source of inspiration: https://gitlab.com/gitlab-org/gitlab-pages
// A pages proxy in go that handles tls to upstreams, rate limiting, and more
return &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, network, addr)
},
// The defaults value 0 = unbounded.
// We set to some value to prevent resource exhaustion e.g max requests and ports.
MaxConnsPerHost: opts.MaxConnsPerHost,
// The defaults value 0 = unbounded. 100 is used by the default go transport.
// This value should be significant higher than MaxIdleConnsPerHost.
MaxIdleConns: opts.MaxIdleConns,
// The default value is 2. Such a low limit will open and close connections too often.
// Details: https://gitlab.com/gitlab-org/gitlab-pages/-/merge_requests/274
MaxIdleConnsPerHost: opts.MaxIdleConnsPerHost,
ForceAttemptHTTP2: true,
IdleConnTimeout: opts.KeepAliveIdleTimeout,
// Set more timeouts https://gitlab.com/gitlab-org/gitlab-pages/-/issues/495
TLSHandshakeTimeout: opts.TLSHandshakeTimeout,
ResponseHeaderTimeout: opts.ResponseHeaderTimeout,
ExpectContinueTimeout: opts.ExpectContinueTimeout,
// Will return nil when HTTP(S)_PROXY does not exist or is empty.
// This will prevent the transport from handling the proxy when it is not needed.
Proxy: proxy,
}
}
func TraceConfigFromTelemetry(cfg *config.Telemetry) *rtrace.Config {
var exporters []*rtrace.ExporterConfig
for _, exp := range cfg.Tracing.Exporters {
exporters = append(exporters, &rtrace.ExporterConfig{
Disabled: exp.Disabled,
Endpoint: exp.Endpoint,
Exporter: exp.Exporter,
BatchTimeout: exp.BatchTimeout,
ExportTimeout: exp.ExportTimeout,
Headers: exp.Headers,
HTTPPath: exp.HTTPPath,
})
}
var propagators []rtrace.Propagator
if cfg.Tracing.Propagation.TraceContext {
propagators = append(propagators, rtrace.PropagatorTraceContext)
}
if cfg.Tracing.Propagation.B3 {
propagators = append(propagators, rtrace.PropagatorB3)
}
if cfg.Tracing.Propagation.Jaeger {
propagators = append(propagators, rtrace.PropagatorJaeger)
}
if cfg.Tracing.Propagation.Datadog {
propagators = append(propagators, rtrace.PropagatorDatadog)
}
if cfg.Tracing.Propagation.Baggage {
propagators = append(propagators, rtrace.PropagatorBaggage)
}
return &rtrace.Config{
Enabled: cfg.Tracing.Enabled,
Name: cfg.ServiceName,
Version: Version,
Sampler: cfg.Tracing.SamplingRate,
ParentBasedSampler: cfg.Tracing.ParentBasedSampler,
WithNewRoot: cfg.Tracing.WithNewRoot,
Attributes: nil,
ExportGraphQLVariables: rtrace.ExportGraphQLVariables{
Enabled: cfg.Tracing.ExportGraphQLVariables,
},
ResourceAttributes: buildResourceAttributes(cfg.ResourceAttributes),
Exporters: exporters,
Propagators: propagators,
ResponseTraceHeader: cfg.Tracing.ResponseTraceHeader,
}
}
// buildAttributesMap returns a map of custom attributes to quickly check if a field is used in the custom attributes.
func buildAttributesMap(attributes []config.CustomAttribute) map[string]string {
result := make(map[string]string)
for _, attr := range attributes {
if attr.ValueFrom != nil && attr.ValueFrom.ContextField != "" {
result[attr.ValueFrom.ContextField] = attr.Key
}
}
return result
}
// buildHeaderAttributesMapper returns a function that maps custom attributes to the request headers.
func buildHeaderAttributesMapper(attributes []config.CustomAttribute) func(req *http.Request) []attribute.KeyValue {
if len(attributes) == 0 {
return nil
}
return func(req *http.Request) []attribute.KeyValue {
var result []attribute.KeyValue
for _, attr := range attributes {
if attr.ValueFrom != nil {
if req != nil && attr.ValueFrom.RequestHeader != "" {
hv := req.Header.Get(attr.ValueFrom.RequestHeader)
if hv != "" {
result = append(result, attribute.String(attr.Key, hv))
} else if attr.Default != "" {
result = append(result, attribute.String(attr.Key, attr.Default))
}
} else if attr.Default != "" {
result = append(result, attribute.String(attr.Key, attr.Default))
}
} else if attr.Default != "" {
result = append(result, attribute.String(attr.Key, attr.Default))
}
}
return result
}
}
func buildResourceAttributes(attributes []config.CustomStaticAttribute) []attribute.KeyValue {
var result []attribute.KeyValue
for _, attr := range attributes {
result = append(result, attribute.String(attr.Key, attr.Value))
}
r := attribute.NewSet(result...)
return r.ToSlice()
}
func MetricConfigFromTelemetry(cfg *config.Telemetry) *rmetric.Config {
var openTelemetryExporters []*rmetric.OpenTelemetryExporter
for _, exp := range cfg.Metrics.OTLP.Exporters {
openTelemetryExporters = append(openTelemetryExporters, &rmetric.OpenTelemetryExporter{
Disabled: exp.Disabled,
Endpoint: exp.Endpoint,
Exporter: exp.Exporter,
Headers: exp.Headers,
HTTPPath: exp.HTTPPath,
Temporality: exp.Temporality,
})
}
return &rmetric.Config{
Name: cfg.ServiceName,
Version: Version,
Attributes: cfg.Metrics.Attributes,
ResourceAttributes: buildResourceAttributes(cfg.ResourceAttributes),
OpenTelemetry: rmetric.OpenTelemetry{
Enabled: cfg.Metrics.OTLP.Enabled,
RouterRuntime: cfg.Metrics.OTLP.RouterRuntime,
GraphqlCache: cfg.Metrics.OTLP.GraphqlCache,
EngineStats: rmetric.EngineStatsConfig{
Subscription: cfg.Metrics.OTLP.EngineStats.Subscriptions,
},
Exporters: openTelemetryExporters,
ExcludeMetrics: cfg.Metrics.OTLP.ExcludeMetrics,
ExcludeMetricLabels: cfg.Metrics.OTLP.ExcludeMetricLabels,
},
Prometheus: rmetric.PrometheusConfig{
Enabled: cfg.Metrics.Prometheus.Enabled,
ListenAddr: cfg.Metrics.Prometheus.ListenAddr,
Path: cfg.Metrics.Prometheus.Path,
GraphqlCache: cfg.Metrics.Prometheus.GraphqlCache,
EngineStats: rmetric.EngineStatsConfig{
Subscription: cfg.Metrics.Prometheus.EngineStats.Subscriptions,
},
ExcludeMetrics: cfg.Metrics.Prometheus.ExcludeMetrics,
ExcludeMetricLabels: cfg.Metrics.Prometheus.ExcludeMetricLabels,
ExcludeScopeInfo: cfg.Metrics.Prometheus.ExcludeScopeInfo,
},
}
}
func or[T any](maybe *T, or T) T {
if maybe != nil {
return *maybe
}
return or
}