router/core/graphql_handler.go (378 lines of code) (raw):

package core import ( "bytes" "context" "encoding/json" "errors" "fmt" "io" "net/http" "strings" rErrors "github.com/wundergraph/cosmo/router/internal/errors" rotel "github.com/wundergraph/cosmo/router/pkg/otel" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/graphqlerrors" "go.opentelemetry.io/otel/trace" "go.uber.org/zap" "github.com/wundergraph/cosmo/router/pkg/config" "github.com/wundergraph/cosmo/router/pkg/statistics" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" ) var ( errCouldNotResolveResponse = errors.New("could not resolve response") errInternalServer = errors.New("internal server error") errCouldNotFlushResponse = errors.New("could not flush response") errOperationPlanUnsupported = errors.New("unsupported operation plan") ) const ( ExecutionPlanCacheHeader = "X-WG-Execution-Plan-Cache" PersistedOperationCacheHeader = "X-WG-Persisted-Operation-Cache" NormalizationCacheHeader = "X-WG-Normalization-Cache" ) type ReportError interface { error Report() *operationreport.Report } type reportError struct { report *operationreport.Report } func (e *reportError) Error() string { if len(e.report.InternalErrors) > 0 { return errors.Join(e.report.InternalErrors...).Error() } messages := make([]string, len(e.report.ExternalErrors)) for i, e := range e.report.ExternalErrors { messages[i] = e.Message } return strings.Join(messages, ", ") } func (e *reportError) Report() *operationreport.Report { return e.report } type HandlerOptions struct { Executor *Executor Log *zap.Logger EnableExecutionPlanCacheResponseHeader bool EnablePersistedOperationCacheResponseHeader bool EnableNormalizationCacheResponseHeader bool EnableResponseHeaderPropagation bool EngineStats statistics.EngineStatistics TracerProvider trace.TracerProvider Authorizer *CosmoAuthorizer RateLimiter *CosmoRateLimiter RateLimitConfig *config.RateLimitConfiguration SubgraphErrorPropagation config.SubgraphErrorPropagationConfiguration EngineLoaderHooks resolve.LoaderHooks ApolloSubscriptionMultipartPrintBoundary bool } func NewGraphQLHandler(opts HandlerOptions) *GraphQLHandler { graphQLHandler := &GraphQLHandler{ log: opts.Log, executor: opts.Executor, enableExecutionPlanCacheResponseHeader: opts.EnableExecutionPlanCacheResponseHeader, enablePersistedOperationCacheResponseHeader: opts.EnablePersistedOperationCacheResponseHeader, enableNormalizationCacheResponseHeader: opts.EnableNormalizationCacheResponseHeader, enableResponseHeaderPropagation: opts.EnableResponseHeaderPropagation, engineStats: opts.EngineStats, tracer: opts.TracerProvider.Tracer( "wundergraph/cosmo/router/graphql_handler", trace.WithInstrumentationVersion("0.0.1"), ), authorizer: opts.Authorizer, rateLimiter: opts.RateLimiter, rateLimitConfig: opts.RateLimitConfig, subgraphErrorPropagation: opts.SubgraphErrorPropagation, engineLoaderHooks: opts.EngineLoaderHooks, apolloSubscriptionMultipartPrintBoundary: opts.ApolloSubscriptionMultipartPrintBoundary, } return graphQLHandler } // Error and Status Code handling // // When a server receives a well-formed GraphQL-over-HTTP request, it must return a // well‐formed GraphQL response. The server's response describes the result of validating // and executing the requested operation if successful, and describes any errors encountered // during the request. This means working errors should be returned as part of the response body. // That also implies parsing or validation errors. They should be returned as part of the response body. // Only in cases where the request is malformed or invalid GraphQL should the server return an HTTP 4xx or 5xx error code. // https://github.com/graphql/graphql-over-http/blob/main/spec/GraphQLOverHTTP.md#response type GraphQLHandler struct { log *zap.Logger executor *Executor engineStats statistics.EngineStatistics tracer trace.Tracer authorizer *CosmoAuthorizer rateLimiter *CosmoRateLimiter rateLimitConfig *config.RateLimitConfiguration subgraphErrorPropagation config.SubgraphErrorPropagationConfiguration engineLoaderHooks resolve.LoaderHooks enableExecutionPlanCacheResponseHeader bool enablePersistedOperationCacheResponseHeader bool enableNormalizationCacheResponseHeader bool enableResponseHeaderPropagation bool apolloSubscriptionMultipartPrintBoundary bool } func (h *GraphQLHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { requestContext := getRequestContext(r.Context()) executionContext, graphqlExecutionSpan := h.tracer.Start(r.Context(), "Operation - Execute", trace.WithSpanKind(trace.SpanKindInternal), trace.WithAttributes(requestContext.telemetry.traceAttrs...), ) defer graphqlExecutionSpan.End() ctx := &resolve.Context{ Variables: requestContext.operation.variables, RemapVariables: requestContext.operation.remapVariables, Files: requestContext.operation.files, Request: resolve.Request{ Header: r.Header, }, RenameTypeNames: h.executor.RenameTypeNames, TracingOptions: requestContext.operation.traceOptions, InitialPayload: requestContext.operation.initialPayload, Extensions: requestContext.operation.extensions, ExecutionOptions: requestContext.operation.executionOptions, } ctx = ctx.WithContext(executionContext) if h.authorizer != nil { ctx = WithAuthorizationExtension(ctx) ctx.SetAuthorizer(h.authorizer) } if h.engineLoaderHooks != nil { ctx.SetEngineLoaderHooks(h.engineLoaderHooks) } ctx = h.configureRateLimiting(ctx) switch p := requestContext.operation.preparedPlan.preparedPlan.(type) { case *plan.SynchronousResponsePlan: w.Header().Set("Content-Type", "application/json") h.setDebugCacheHeaders(w, requestContext.operation) if h.enableResponseHeaderPropagation { ctx = WithResponseHeaderPropagation(ctx) } defer propagateSubgraphErrors(ctx) resp, err := h.executor.Resolver.ResolveGraphQLResponse(ctx, p.Response, nil, HeaderPropagationWriter(w, ctx.Context())) requestContext.dataSourceNames = getSubgraphNames(p.Response.DataSources) if err != nil { trackFinalResponseError(ctx.Context(), err) h.WriteError(ctx, err, p.Response, w) return } graphqlExecutionSpan.SetAttributes(rotel.WgAcquireResolverWaitTimeMs.Int64(resp.ResolveAcquireWaitTime.Milliseconds())) case *plan.SubscriptionResponsePlan: var ( writer resolve.SubscriptionResponseWriter ok bool ) h.setDebugCacheHeaders(w, requestContext.operation) defer propagateSubgraphErrors(ctx) ctx, writer, ok = GetSubscriptionResponseWriter(ctx, r, w, h.apolloSubscriptionMultipartPrintBoundary) if !ok { requestContext.logger.Error("unable to get subscription response writer", zap.Error(errCouldNotFlushResponse)) trackFinalResponseError(r.Context(), errCouldNotFlushResponse) writeRequestErrors(r, w, http.StatusInternalServerError, graphqlerrors.RequestErrorsFromError(errCouldNotFlushResponse), requestContext.logger) return } if !ctx.ExecutionOptions.SkipLoader { h.engineStats.ConnectionsInc() defer h.engineStats.ConnectionsDec() } err := h.executor.Resolver.ResolveGraphQLSubscription(ctx, p.Response, writer) requestContext.dataSourceNames = getSubgraphNames(p.Response.Response.DataSources) if err != nil { if errors.Is(err, context.Canceled) { requestContext.logger.Debug("context canceled: unable to resolve subscription response", zap.Error(err)) trackFinalResponseError(r.Context(), err) return } else if errors.Is(err, ErrUnauthorized) { trackFinalResponseError(ctx.Context(), err) writeRequestErrors(r, w, http.StatusUnauthorized, graphqlerrors.RequestErrorsFromError(err), requestContext.logger) return } requestContext.logger.Error("unable to resolve subscription response", zap.Error(err)) trackFinalResponseError(ctx.Context(), err) writeRequestErrors(r, w, http.StatusInternalServerError, graphqlerrors.RequestErrorsFromError(errCouldNotResolveResponse), requestContext.logger) return } default: requestContext.logger.Error("unsupported plan kind") trackFinalResponseError(ctx.Context(), errOperationPlanUnsupported) writeRequestErrors(r, w, http.StatusInternalServerError, graphqlerrors.RequestErrorsFromError(errOperationPlanUnsupported), requestContext.logger) } } func (h *GraphQLHandler) configureRateLimiting(ctx *resolve.Context) *resolve.Context { if h.rateLimiter == nil { return ctx } if h.rateLimitConfig == nil { return ctx } if !h.rateLimitConfig.Enabled { return ctx } if h.rateLimitConfig.Strategy != "simple" { return ctx } ctx.SetRateLimiter(h.rateLimiter) ctx.RateLimitOptions = resolve.RateLimitOptions{ Enable: true, IncludeStatsInResponseExtension: !h.rateLimitConfig.SimpleStrategy.HideStatsFromResponseExtension, Rate: h.rateLimitConfig.SimpleStrategy.Rate, Burst: h.rateLimitConfig.SimpleStrategy.Burst, Period: h.rateLimitConfig.SimpleStrategy.Period, RateLimitKey: h.rateLimitConfig.Storage.KeyPrefix, RejectExceedingRequests: h.rateLimitConfig.SimpleStrategy.RejectExceedingRequests, ErrorExtensionCode: resolve.RateLimitErrorExtensionCode{ Enabled: h.rateLimitConfig.ErrorExtensionCode.Enabled, Code: h.rateLimitConfig.ErrorExtensionCode.Code, }, } return WithRateLimiterStats(ctx) } // WriteError writes the error to the response writer. This function must be concurrency-safe. // @TODO This function should be refactored to be a helper function for websocket and http error writing // In the websocket case, we call this function concurrently as part of the polling loop. This is error-prone. func (h *GraphQLHandler) WriteError(ctx *resolve.Context, err error, res *resolve.GraphQLResponse, w io.Writer) { reqContext := getRequestContext(ctx.Context()) if reqContext == nil { h.log.Error("unable to get request context") return } requestLogger := reqContext.logger httpWriter, isHttpResponseWriter := w.(http.ResponseWriter) response := GraphQLErrorResponse{ Errors: make([]graphqlError, 1), Data: nil, } switch getErrorType(err) { case errorTypeMergeResult: var errMerge resolve.ErrMergeResult if !errors.As(err, &errMerge) { response.Errors[0].Message = "Internal server error" return } response.Errors[0].Message = errMerge.Error() case errorTypeRateLimit: response.Errors[0].Message = "Rate limit exceeded" if h.rateLimitConfig.ErrorExtensionCode.Enabled { response.Errors[0].Extensions = &Extensions{ Code: h.rateLimitConfig.ErrorExtensionCode.Code, } } if !h.rateLimitConfig.SimpleStrategy.HideStatsFromResponseExtension { buf := bytes.NewBuffer(make([]byte, 0, 1024)) err = h.rateLimiter.RenderResponseExtension(ctx, buf) if err != nil { requestLogger.Error("unable to render rate limit stats", zap.Error(err)) if isHttpResponseWriter { httpWriter.WriteHeader(http.StatusInternalServerError) } return } response.Extensions = &Extensions{ RateLimit: buf.Bytes(), } } if isHttpResponseWriter { httpWriter.WriteHeader(h.rateLimiter.RejectStatusCode()) } case errorTypeUnauthorized: response.Errors[0].Message = "Unauthorized" if h.authorizer.HasResponseExtensionData(ctx) { buf := bytes.NewBuffer(make([]byte, 0, 1024)) err = h.authorizer.RenderResponseExtension(ctx, buf) if err != nil { requestLogger.Error("unable to render authorization extension", zap.Error(err)) if isHttpResponseWriter { httpWriter.WriteHeader(http.StatusInternalServerError) } return } response.Extensions = &Extensions{ Authorization: buf.Bytes(), } } if isHttpResponseWriter { httpWriter.WriteHeader(http.StatusOK) // Always return 200 OK when we return a well-formed response } case errorTypeContextCanceled: response.Errors[0].Message = "Client disconnected" if isHttpResponseWriter { httpWriter.WriteHeader(http.StatusRequestTimeout) } case errorTypeContextTimeout: response.Errors[0].Message = "Server timeout" if isHttpResponseWriter { httpWriter.WriteHeader(http.StatusRequestTimeout) } case errorTypeUnknown: response.Errors[0].Message = "Internal server error" if isHttpResponseWriter { httpWriter.WriteHeader(http.StatusInternalServerError) } case errorTypeUpgradeFailed: var upgradeErr *graphql_datasource.UpgradeRequestError if h.subgraphErrorPropagation.PropagateStatusCodes && errors.As(err, &upgradeErr) && upgradeErr.StatusCode != 0 { response.Errors[0].Extensions = &Extensions{ StatusCode: upgradeErr.StatusCode, } if subgraph := reqContext.subgraphResolver.BySubgraphURL(upgradeErr.URL); subgraph != nil { response.Errors[0].Message = fmt.Sprintf("Subscription Upgrade request failed for Subgraph '%s'.", subgraph.Name) } else { response.Errors[0].Message = "Subscription Upgrade request failed" } } else { response.Errors[0].Message = "Subscription Upgrade request failed" } if isHttpResponseWriter { httpWriter.WriteHeader(http.StatusOK) } case errorTypeEDFS: response.Errors[0].Message = fmt.Sprintf("EDFS error: %s", err.Error()) if isHttpResponseWriter { httpWriter.WriteHeader(http.StatusInternalServerError) } case errorTypeInvalidWsSubprotocol: response.Errors[0].Message = fmt.Sprintf("Invalid Subprotocol error: %s or configure the subprotocol to be used using `wgc subgraph update` command.", err.Error()) if isHttpResponseWriter { httpWriter.WriteHeader(http.StatusInternalServerError) } case errorTypeEDFSInvalidMessage: response.Errors[0].Message = "Invalid message received" if isHttpResponseWriter { httpWriter.WriteHeader(http.StatusInternalServerError) } } if ctx.TracingOptions.Enable && ctx.TracingOptions.IncludeTraceOutputInResponseExtensions { traceNode := resolve.GetTrace(ctx.Context(), res.Fetches) if response.Extensions == nil { response.Extensions = &Extensions{} } response.Extensions.Trace, err = json.Marshal(traceNode) if err != nil { requestLogger.Error("Unable to marshal trace node", zap.Error(err)) } } err = json.NewEncoder(w).Encode(response) if err != nil { if rErrors.IsBrokenPipe(err) { requestLogger.Warn("Broken pipe, unable to write error response", zap.Error(err)) } else { requestLogger.Error("Unable to write error response", zap.Error(err)) } } if wsRw, ok := w.(*websocketResponseWriter); ok { _ = wsRw.Flush() } } func (h *GraphQLHandler) setDebugCacheHeaders(w http.ResponseWriter, opCtx *operationContext) { if h.enableNormalizationCacheResponseHeader { if opCtx.normalizationCacheHit { w.Header().Set(NormalizationCacheHeader, "HIT") } else { w.Header().Set(NormalizationCacheHeader, "MISS") } } if h.enablePersistedOperationCacheResponseHeader { if opCtx.persistedOperationCacheHit { w.Header().Set(PersistedOperationCacheHeader, "HIT") } else { w.Header().Set(PersistedOperationCacheHeader, "MISS") } } if h.enableExecutionPlanCacheResponseHeader { if opCtx.planCacheHit { w.Header().Set(ExecutionPlanCacheHeader, "HIT") } else { w.Header().Set(ExecutionPlanCacheHeader, "MISS") } } }