router/core/graphql_prehandler.go (775 lines of code) (raw):

package core

import (
	"bytes"
	"context"
	"crypto/ecdsa"
	"encoding/json"
	"errors"
	"fmt"
	"maps"
	"net/http"
	"slices"
	"strconv"
	"strings"
	"sync"
	"time"

	"go.opentelemetry.io/otel/codes"

	"github.com/go-chi/chi/v5/middleware"
	"github.com/golang-jwt/jwt/v5"
	"go.opentelemetry.io/otel/attribute"
	otelmetric "go.opentelemetry.io/otel/metric"
	sdktrace "go.opentelemetry.io/otel/sdk/trace"
	"go.opentelemetry.io/otel/trace"
	"go.uber.org/zap"

	"github.com/wundergraph/astjson"

	"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
	"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan"
	"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
	"github.com/wundergraph/graphql-go-tools/v2/pkg/graphqlerrors"

	"github.com/wundergraph/cosmo/router/internal/expr"
	"github.com/wundergraph/cosmo/router/internal/persistedoperation"
	"github.com/wundergraph/cosmo/router/pkg/art"
	"github.com/wundergraph/cosmo/router/pkg/config"
	"github.com/wundergraph/cosmo/router/pkg/otel"
	rtrace "github.com/wundergraph/cosmo/router/pkg/trace"
)

// PreHandlerOptions bundles every dependency and feature toggle needed to
// construct a PreHandler via NewPreHandler. Fields map 1:1 onto the
// corresponding PreHandler fields.
type PreHandlerOptions struct {
	Logger                      *zap.Logger
	Executor                    *Executor
	Metrics                     RouterMetrics
	OperationProcessor          *OperationProcessor
	Planner                     *OperationPlanner
	AccessController            *AccessController
	OperationBlocker            *OperationBlocker
	RouterPublicKey             *ecdsa.PublicKey
	TracerProvider              *sdktrace.TracerProvider
	ComplexityLimits            *config.ComplexityLimits
	MaxUploadFiles              int
	MaxUploadFileSize           int
	FlushTelemetryAfterResponse bool
	FileUploadEnabled           bool
	TraceExportVariables        bool
	DevelopmentMode             bool
	EnableRequestTracing        bool
	AlwaysIncludeQueryPlan      bool
	AlwaysSkipLoader            bool
	QueryPlansEnabled           bool
	QueryPlansLoggingEnabled    bool
	TrackSchemaUsageInfo        bool
	ClientHeader                config.ClientHeader
	ComputeOperationSha256      bool
	ApolloCompatibilityFlags    *config.ApolloCompatibilityFlags
	DisableVariablesRemapping   bool
}

// PreHandler is the HTTP middleware that runs in front of the GraphQL
// execution handler. It reads the request body (including multipart file
// uploads), authenticates the request, then parses, normalizes, validates
// and plans the operation before delegating to the next handler.
type PreHandler struct {
	log                         *zap.Logger
	executor                    *Executor
	metrics                     RouterMetrics
	operationProcessor          *OperationProcessor
	planner                     *OperationPlanner
	accessController            *AccessController
	operationBlocker            *OperationBlocker
	developmentMode             bool
	alwaysIncludeQueryPlan      bool
	alwaysSkipLoader            bool
	queryPlansEnabled           bool // queryPlansEnabled is a flag to enable query plans output in the extensions
	queryPlansLoggingEnabled    bool // queryPlansLoggingEnabled is a flag to enable logging of query plans
	routerPublicKey             *ecdsa.PublicKey
	enableRequestTracing        bool
	tracerProvider              *sdktrace.TracerProvider
	flushTelemetryAfterResponse bool
	tracer                      trace.Tracer
	traceExportVariables        bool
	fileUploadEnabled           bool
	maxUploadFiles              int
	maxUploadFileSize           int
	complexityLimits            *config.ComplexityLimits
	trackSchemaUsageInfo        bool
	clientHeader                config.ClientHeader
	computeOperationSha256      bool
	apolloCompatibilityFlags    *config.ApolloCompatibilityFlags
	// variableParsePool is a pool of astjson parsers reused across requests
	// for parsing the request variables JSON.
	variableParsePool         astjson.ParserPool
	disableVariablesRemapping bool
}

// httpOperation carries all per-request state that handleOperation needs:
// the raw body, any multipart file uploads, and the telemetry handles
// (root span, metrics, ART trace timings) created by Handler.
type httpOperation struct {
	requestContext   *requestContext
	body             []byte
	files            []*httpclient.FileUpload
	requestLogger    *zap.Logger
	routerSpan       trace.Span
	operationMetrics *OperationMetrics
	traceTimings     *art.TraceTimings
}

// NewPreHandler constructs a PreHandler from the given options and creates
// the dedicated tracer used for the pre-handler's internal spans.
func NewPreHandler(opts *PreHandlerOptions) *PreHandler {
	return &PreHandler{
		log:                         opts.Logger,
		executor:                    opts.Executor,
		metrics:                     opts.Metrics,
		operationProcessor:          opts.OperationProcessor,
		planner:                     opts.Planner,
		accessController:            opts.AccessController,
		operationBlocker:            opts.OperationBlocker,
		routerPublicKey:             opts.RouterPublicKey,
		developmentMode:             opts.DevelopmentMode,
		enableRequestTracing:        opts.EnableRequestTracing,
		flushTelemetryAfterResponse: opts.FlushTelemetryAfterResponse,
		tracerProvider:              opts.TracerProvider,
		traceExportVariables:        opts.TraceExportVariables,
		tracer: opts.TracerProvider.Tracer(
			"wundergraph/cosmo/router/pre_handler",
			trace.WithInstrumentationVersion("0.0.1"),
		),
		fileUploadEnabled:         opts.FileUploadEnabled,
		maxUploadFiles:            opts.MaxUploadFiles,
		maxUploadFileSize:         opts.MaxUploadFileSize,
		complexityLimits:          opts.ComplexityLimits,
		alwaysIncludeQueryPlan:    opts.AlwaysIncludeQueryPlan,
		alwaysSkipLoader:          opts.AlwaysSkipLoader,
		queryPlansEnabled:         opts.QueryPlansEnabled,
		queryPlansLoggingEnabled:  opts.QueryPlansLoggingEnabled,
		trackSchemaUsageInfo:      opts.TrackSchemaUsageInfo,
		clientHeader:              opts.ClientHeader,
		computeOperationSha256:    opts.ComputeOperationSha256,
		apolloCompatibilityFlags:  opts.ApolloCompatibilityFlags,
		disableVariablesRemapping: opts.DisableVariablesRemapping,
	}
}

// getBodyReadBuffer returns a fresh read buffer sized from the request's
// Content-Length: defaults to 4KB when the length is unknown/non-positive and
// is capped at the configured maximum operation size.
func (h *PreHandler) getBodyReadBuffer(preferredSize int64) *bytes.Buffer {
	if preferredSize <= 0 {
		preferredSize = 1024 * 4 // 4KB
	} else if preferredSize > h.operationProcessor.maxOperationSizeInBytes {
		preferredSize = h.operationProcessor.maxOperationSizeInBytes
	}
	return bytes.NewBuffer(make([]byte, 0, preferredSize))
}

// Error and Status Code handling
//
// When a server receives a well-formed GraphQL-over-HTTP request, it must return a
// well‐formed GraphQL response. The server's response describes the result of validating
// and executing the requested operation if successful, and describes any errors encountered
// during the request. This means working errors should be returned as part of the response body.
// That also implies parsing or validation errors. They should be returned as part of the response body.
// Only in cases where the request is malformed or invalid GraphQL should the server return an HTTP 4xx or 5xx error code.
// https://github.com/graphql/graphql-over-http/blob/main/spec/GraphQLOverHTTP.md#response

// Handler returns the pre-handler middleware. It sets up telemetry for the
// request, reads the body (regular or multipart), authenticates the caller,
// runs the full operation pipeline via handleOperation and finally invokes
// next to execute the planned operation.
func (h *PreHandler) Handler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var (
			// In GraphQL the statusCode does not always express the error state of the request.
			// statusCode/writtenBytes are captured from the wrapped response writer after
			// next.ServeHTTP and consumed by the deferred metrics.Finish call below.
			writtenBytes int
			statusCode   = http.StatusOK
			traceTimings *art.TraceTimings
		)

		requestContext := getRequestContext(r.Context())
		requestLogger := requestContext.logger
		routerSpan := trace.SpanFromContext(r.Context())

		clientInfo := NewClientInfoFromRequest(r, h.clientHeader)

		// Client identity and protocol are attached to every span/metric of this request.
		requestContext.telemetry.addCommonAttribute(
			otel.WgClientName.String(clientInfo.Name),
			otel.WgClientVersion.String(clientInfo.Version),
			otel.WgOperationProtocol.String(OperationProtocolHTTP.String()),
		)

		// Attribute slices are pooled; acquire, use, then release below.
		startAttrs := *requestContext.telemetry.AcquireAttributes()
		startAttrs = append(startAttrs, requestContext.telemetry.metricAttrs...)

		metrics := h.metrics.StartOperation(
			requestLogger,
			r.ContentLength,
			requestContext.telemetry.metricSliceAttrs,
			otelmetric.WithAttributeSet(attribute.NewSet(startAttrs...)),
		)

		requestContext.telemetry.ReleaseAttributes(&startAttrs)

		routerSpan.SetAttributes(requestContext.telemetry.traceAttrs...)

		// Expression-derived trace attributes (pre-authentication pass).
		// Resolution failures are logged but do not fail the request.
		if requestContext.telemetry.traceAttributeExpressions != nil {
			traceMetrics, err := requestContext.telemetry.traceAttributeExpressions.expressionsAttributes(requestContext)
			if err != nil {
				requestLogger.Error("failed to resolve trace attribute", zap.Error(err))
			}

			requestContext.telemetry.addCommonAttribute(
				traceMetrics...,
			)

			routerSpan.SetAttributes(traceMetrics...)
		}

		// Expression-derived metric attributes (pre-authentication pass).
		if requestContext.telemetry.metricAttributeExpressions != nil {
			metricAttrs, err := requestContext.telemetry.metricAttributeExpressions.expressionsAttributes(requestContext)
			if err != nil {
				requestLogger.Error("failed to resolve metric attribute", zap.Error(err))
			}

			requestContext.telemetry.addMetricAttribute(
				metricAttrs...,
			)
		}

		requestContext.operation = &operationContext{
			clientInfo: clientInfo,
		}

		// Finish the operation metrics after the response has been written,
		// regardless of which early-return path was taken.
		defer func() {
			requestContext.telemetry.AddCustomMetricStringSliceAttr(ContextFieldGraphQLErrorServices, requestContext.graphQLErrorServices)
			requestContext.telemetry.AddCustomMetricStringSliceAttr(ContextFieldOperationServices, requestContext.dataSourceNames)
			requestContext.telemetry.AddCustomMetricStringSliceAttr(ContextFieldGraphQLErrorCodes, requestContext.graphQLErrorCodes)

			metrics.Finish(
				requestContext,
				statusCode,
				writtenBytes,
				h.flushTelemetryAfterResponse,
			)

			if h.flushTelemetryAfterResponse {
				h.flushMetrics(r.Context(), requestLogger)
			}
		}()

		executionOptions, traceOptions, err := h.parseRequestOptions(r, clientInfo, requestLogger)
		if err != nil {
			requestContext.SetError(err)
			writeRequestErrors(r, w, http.StatusBadRequest, graphqlerrors.RequestErrorsFromError(err), requestLogger)
			return
		}

		requestContext.operation.protocol = OperationProtocolHTTP
		requestContext.operation.executionOptions = executionOptions
		requestContext.operation.traceOptions = traceOptions

		if traceOptions.Enable {
			r = r.WithContext(resolve.SetTraceStart(r.Context(), traceOptions.EnablePredictableDebugTimings))
			traceTimings = art.NewTraceTimings(r.Context())
		}

		var body []byte
		var files []*httpclient.FileUpload

		// Multipart requests carry the operation plus file uploads (GraphQL multipart spec).
		if strings.Contains(r.Header.Get("Content-Type"), "multipart/form-data") {
			if !h.fileUploadEnabled {
				requestContext.SetError(&httpGraphqlError{
					message:    "file upload disabled",
					statusCode: http.StatusOK,
				})
				writeOperationError(r, w, requestLogger, requestContext.error)
				return
			}

			_, readMultiPartSpan := h.tracer.Start(r.Context(), "HTTP - Read Multipart",
				trace.WithSpanKind(trace.SpanKindInternal),
				trace.WithAttributes(requestContext.telemetry.traceAttrs...),
			)

			multipartParser := NewMultipartParser(h.operationProcessor, h.maxUploadFiles, h.maxUploadFileSize)

			var err error
			body, files, err = multipartParser.Parse(r, h.getBodyReadBuffer(r.ContentLength))
			if err != nil {
				requestContext.SetError(err)
				writeOperationError(r, w, requestLogger, requestContext.error)
				readMultiPartSpan.End()
				return
			}

			readMultiPartSpan.SetAttributes(
				otel.HTTPRequestUploadFileCount.Int(len(files)),
			)

			readMultiPartSpan.End()

			// Cleanup all files. Needs to be called in the pre_handler function to ensure that
			// defer is called after the response is written
			defer func() {
				if err := multipartParser.RemoveAll(); err != nil {
					requestLogger.Error("Failed to remove files after multipart request", zap.Error(err))
				}
			}()
		} else if r.Method == http.MethodPost {
			_, readOperationBodySpan := h.tracer.Start(r.Context(), "HTTP - Read Body",
				trace.WithSpanKind(trace.SpanKindInternal),
				trace.WithAttributes(requestContext.telemetry.traceAttrs...),
			)

			var err error
			body, err = h.operationProcessor.ReadBody(r.Body, h.getBodyReadBuffer(r.ContentLength))
			if err != nil {
				requestContext.SetError(err)

				// Don't produce errors logs here because it can only be client side errors
				// e.g. too large body, slow client, aborted connection etc.
				// The error is logged as debug log in the writeOperationError function
				writeOperationError(r, w, requestLogger, err)
				readOperationBodySpan.End()
				return
			}

			readOperationBodySpan.End()
		}

		// Pooled variables parser, returned to the pool when the request ends.
		variablesParser := h.variableParsePool.Get()
		defer h.variableParsePool.Put(variablesParser)

		// If we have authenticators, we try to authenticate the request
		if h.accessController != nil {
			_, authenticateSpan := h.tracer.Start(r.Context(), "Authenticate",
				trace.WithSpanKind(trace.SpanKindServer),
				trace.WithAttributes(requestContext.telemetry.traceAttrs...),
			)

			validatedReq, err := h.accessController.Access(w, r)
			if err != nil {
				requestContext.SetError(err)
				requestLogger.Error("Failed to authenticate request", zap.Error(err))

				// Mark the root span of the router as failed, so we can easily identify failed requests
				rtrace.AttachErrToSpan(routerSpan, err)
				rtrace.AttachErrToSpan(authenticateSpan, err)

				authenticateSpan.End()

				writeOperationError(r, w, requestLogger, &httpGraphqlError{
					message:    err.Error(),
					statusCode: http.StatusUnauthorized,
				})
				return
			}
			authenticateSpan.End()

			// Access may return a request with an enriched (authenticated) context.
			r = validatedReq

			requestContext.expressionContext.Request.Auth = expr.LoadAuth(r.Context())
		}

		// Second expression pass: attributes that may depend on auth state.
		if requestContext.telemetry.traceAttributeExpressions != nil {
			traceMetrics, err := requestContext.telemetry.traceAttributeExpressions.expressionsAttributesWithAuth(requestContext)
			if err != nil {
				requestLogger.Error("failed to resolve trace attribute", zap.Error(err))
			}

			requestContext.telemetry.addCommonAttribute(
				traceMetrics...,
			)

			routerSpan.SetAttributes(traceMetrics...)
		}

		if requestContext.telemetry.metricAttributeExpressions != nil {
			metricAttrs, err := requestContext.telemetry.metricAttributeExpressions.expressionsAttributesWithAuth(requestContext)
			if err != nil {
				requestLogger.Error("failed to resolve metric attribute", zap.Error(err))
			}

			requestContext.telemetry.addMetricAttribute(
				metricAttrs...,
			)
		}

		// Parse -> normalize -> validate -> plan the operation.
		err = h.handleOperation(r, variablesParser, &httpOperation{
			requestContext:   requestContext,
			requestLogger:    requestLogger,
			routerSpan:       routerSpan,
			operationMetrics: metrics,
			traceTimings:     traceTimings,
			files:            files,
			body:             body,
		})
		if err != nil {
			requestContext.SetError(err)

			// Mark the root span of the router as failed, so we can easily identify failed requests
			rtrace.AttachErrToSpan(routerSpan, err)

			writeOperationError(r, w, requestLogger, err)
			return
		}

		art.SetRequestTracingStats(r.Context(), traceOptions, traceTimings)

		if traceOptions.Enable {
			reqData := &resolve.RequestData{
				Method:  r.Method,
				URL:     r.URL.String(),
				Headers: r.Header,
				Body: resolve.BodyData{
					Query:         requestContext.operation.rawContent,
					OperationName: requestContext.operation.name,
					Variables:     json.RawMessage(requestContext.operation.variables.String()),
				},
			}
			r = r.WithContext(resolve.SetRequest(r.Context(), reqData))
		}

		// Wrap the writer so status code and written bytes can be read back afterwards.
		ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)

		// The request context needs to be updated with the latest request to ensure that the context is up to date
		requestContext.request = r
		requestContext.responseWriter = ww

		// Call the final handler that resolves the operation
		// and enrich the context to make it available in the request context as well for metrics etc.
		next.ServeHTTP(ww, r)

		statusCode = ww.Status()
		writtenBytes = ww.BytesWritten()

		// Mark the root span of the router as failed, so we can easily identify failed requests
		if requestContext.error != nil {
			rtrace.AttachErrToSpan(trace.SpanFromContext(r.Context()), requestContext.error)
		}
	})
}

// shouldComputeOperationSha256 reports whether the operation's sha256 hash
// must be computed up front: either it is globally enabled, or safelisting /
// unknown-operation logging needs a hash and the request did not already
// supply one via the persisted-query extension.
func (h *PreHandler) shouldComputeOperationSha256(operationKit *OperationKit) bool {
	if h.computeOperationSha256 {
		return true
	}
	hasPersistedHash := operationKit.parsedOperation.GraphQLRequestExtensions.PersistedQuery != nil && operationKit.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash != ""
	// If it already has a persisted hash attached to the request, then there is no need for us to compute it anew
	// Otherwise, we only want to compute the hash (an expensive operation) if we're safelisting or logging unknown persisted operations
	return !hasPersistedHash && (h.operationBlocker.SafelistEnabled || h.operationBlocker.LogUnknownOperationsEnabled)
}

// shouldFetchPersistedOperation determines if we should fetch a persisted operation. The most intuitive case is if the
// operation is a persisted operation. However, we also want to fetch persisted operations if we're enabling safelisting
// and if we're logging unknown operations.
This is because we want to check if the operation is already persisted in the cache func (h *PreHandler) shouldFetchPersistedOperation(operationKit *OperationKit) bool { return operationKit.parsedOperation.IsPersistedOperation || h.operationBlocker.SafelistEnabled || h.operationBlocker.LogUnknownOperationsEnabled } func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson.Parser, httpOperation *httpOperation) error { operationKit, err := h.operationProcessor.NewKit() if err != nil { return err } defer func() { // the kit must be freed before we're doing io operations // the kit is bound to the number of CPUs, and we must not hold onto it while doing IO operations // it needs to be called inside a defer to ensure it is called in panic situations as well if operationKit != nil { operationKit.Free() } }() requestContext := httpOperation.requestContext // Handle the case when operation information are provided as GET parameters if req.Method == http.MethodGet { if err := operationKit.UnmarshalOperationFromURL(req.URL); err != nil { return &httpGraphqlError{ message: fmt.Sprintf("error parsing request query params: %s", err), statusCode: http.StatusBadRequest, } } } else if req.Method == http.MethodPost { if err := operationKit.UnmarshalOperationFromBody(httpOperation.body); err != nil { return &httpGraphqlError{ message: "error parsing request body", statusCode: http.StatusBadRequest, } } // If we have files, we need to set them on the parsed operation if len(httpOperation.files) > 0 { requestContext.operation.files = httpOperation.files } } // Compute the operation sha256 hash as soon as possible for observability reasons if h.shouldComputeOperationSha256(operationKit) { if err := operationKit.ComputeOperationSha256(); err != nil { return &httpGraphqlError{ message: fmt.Sprintf("error hashing operation: %s", err), statusCode: http.StatusInternalServerError, } } requestContext.operation.sha256Hash = operationKit.parsedOperation.Sha256Hash 
requestContext.telemetry.addCustomMetricStringAttr(ContextFieldOperationSha256, requestContext.operation.sha256Hash) if h.operationBlocker.SafelistEnabled || h.operationBlocker.LogUnknownOperationsEnabled { // Set the request hash to the parsed hash, to see if it matches a persisted operation operationKit.parsedOperation.GraphQLRequestExtensions.PersistedQuery = &GraphQLRequestExtensionsPersistedQuery{ Sha256Hash: operationKit.parsedOperation.Sha256Hash, } } } requestContext.operation.extensions = operationKit.parsedOperation.Request.Extensions requestContext.operation.variables, err = variablesParser.ParseBytes(operationKit.parsedOperation.Request.Variables) if err != nil { return &httpGraphqlError{ message: fmt.Sprintf("error parsing variables: %s", err), statusCode: http.StatusBadRequest, } } var ( skipParse bool isApq bool ) if h.shouldFetchPersistedOperation(operationKit) { ctx, span := h.tracer.Start(req.Context(), "Load Persisted Operation", trace.WithSpanKind(trace.SpanKindClient), trace.WithAttributes(requestContext.telemetry.traceAttrs...), ) skipParse, isApq, err = operationKit.FetchPersistedOperation(ctx, requestContext.operation.clientInfo) span.SetAttributes(otel.WgEnginePersistedOperationCacheHit.Bool(operationKit.parsedOperation.PersistedOperationCacheHit)) if err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) var poNotFoundErr *persistedoperation.PersistentOperationNotFoundError if h.operationBlocker.LogUnknownOperationsEnabled && errors.As(err, &poNotFoundErr) { requestContext.logger.Warn("Unknown persisted operation found", zap.String("query", operationKit.parsedOperation.Request.Query), zap.String("sha256Hash", poNotFoundErr.Sha256Hash)) if h.operationBlocker.SafelistEnabled { span.End() return err } } else { span.End() return err } } span.End() requestContext.operation.persistedOperationCacheHit = operationKit.parsedOperation.PersistedOperationCacheHit } // If the persistent operation is already in the cache, we skip 
the parse step // because the operation was already parsed. This is a performance optimization, and we // can do it because we know that the persisted operation is immutable (identified by the hash) if !skipParse { _, engineParseSpan := h.tracer.Start(req.Context(), "Operation - Parse", trace.WithSpanKind(trace.SpanKindInternal), trace.WithAttributes(requestContext.telemetry.traceAttrs...), ) httpOperation.traceTimings.StartParse() startParsing := time.Now() err = operationKit.Parse() if err != nil { rtrace.AttachErrToSpan(engineParseSpan, err) requestContext.operation.parsingTime = time.Since(startParsing) if !requestContext.operation.traceOptions.ExcludeParseStats { httpOperation.traceTimings.EndParse() } engineParseSpan.End() return err } requestContext.operation.parsingTime = time.Since(startParsing) if !requestContext.operation.traceOptions.ExcludeParseStats { httpOperation.traceTimings.EndParse() } engineParseSpan.End() } requestContext.operation.name = operationKit.parsedOperation.Request.OperationName requestContext.operation.opType = operationKit.parsedOperation.Type attributesAfterParse := []attribute.KeyValue{ otel.WgOperationName.String(operationKit.parsedOperation.Request.OperationName), otel.WgOperationType.String(operationKit.parsedOperation.Type), } requestContext.telemetry.addCommonAttribute(attributesAfterParse...) // Set the router span name after we have the operation name httpOperation.routerSpan.SetName(GetSpanName(operationKit.parsedOperation.Request.OperationName, operationKit.parsedOperation.Type)) if req.Method == http.MethodGet && operationKit.parsedOperation.Type == "mutation" { return &httpGraphqlError{ message: "Mutations can only be sent over HTTP POST", statusCode: http.StatusMethodNotAllowed, } } // Set the operation name and type to the operation metrics and the router span as early as possible httpOperation.routerSpan.SetAttributes(attributesAfterParse...) 
if err := h.operationBlocker.OperationIsBlocked(requestContext.logger, requestContext.expressionContext, operationKit.parsedOperation); err != nil { return &httpGraphqlError{ message: err.Error(), statusCode: http.StatusOK, } } if operationKit.parsedOperation.GraphQLRequestExtensions.PersistedQuery != nil && operationKit.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash != "" { requestContext.operation.persistedID = operationKit.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash persistedIDAttribute := otel.WgOperationPersistedID.String(operationKit.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash) requestContext.telemetry.addCommonAttribute(persistedIDAttribute) httpOperation.routerSpan.SetAttributes(persistedIDAttribute) } /** * Normalize the operation */ if !requestContext.operation.traceOptions.ExcludeNormalizeStats { httpOperation.traceTimings.StartNormalize() } startNormalization := time.Now() _, engineNormalizeSpan := h.tracer.Start(req.Context(), "Operation - Normalize", trace.WithSpanKind(trace.SpanKindInternal), trace.WithAttributes(requestContext.telemetry.traceAttrs...), ) cached, err := operationKit.NormalizeOperation(requestContext.operation.clientInfo.Name, isApq) if err != nil { rtrace.AttachErrToSpan(engineNormalizeSpan, err) requestContext.operation.normalizationTime = time.Since(startNormalization) if !requestContext.operation.traceOptions.ExcludeNormalizeStats { httpOperation.traceTimings.EndNormalize() } engineNormalizeSpan.End() return err } // Set the cache hit attribute on the span engineNormalizeSpan.SetAttributes(otel.WgNormalizationCacheHit.Bool(cached)) requestContext.operation.normalizationCacheHit = operationKit.parsedOperation.NormalizationCacheHit /** * Normalize the variables */ // Normalize the variables returns list of uploads mapping if there are any of them present in a query // type UploadPathMapping struct { // VariableName string - is a variable name holding the direct or 
nested value of type Upload, example "f" // OriginalUploadPath string - is a path relative to variables which have an Upload type, example "variables.f" // NewUploadPath string - if variable was used in the inline object like this `arg: {f: $f}` this field will hold the new extracted path, example "variables.a.f", if it is an empty, there was no change in the path // } uploadsMapping, err := operationKit.NormalizeVariables() if err != nil { rtrace.AttachErrToSpan(engineNormalizeSpan, err) requestContext.operation.normalizationTime = time.Since(startNormalization) if !requestContext.operation.traceOptions.ExcludeNormalizeStats { httpOperation.traceTimings.EndNormalize() } engineNormalizeSpan.End() return err } // update file uploads path if they were used in nested field in the extracted variables for mapping := range slices.Values(uploadsMapping) { // if the NewUploadPath is empty it means that there was no change in the path - e.g. upload was directly passed to the argument // e.g. field(fileArgument: $file) will result in []UploadPathMapping{ {VariableName: "file", OriginalUploadPath: "variables.file", NewUploadPath: ""} } if mapping.NewUploadPath == "" { continue } // look for the corresponding file which was used in the nested argument // we are matching original upload path passed via uploads map with the mapping items idx := slices.IndexFunc(requestContext.operation.files, func(file *httpclient.FileUpload) bool { return file.VariablePath() == mapping.OriginalUploadPath }) if idx == -1 { continue } // if NewUploadPath is not empty the file argument was used in the nested object, and we need to update the path // e.g. 
field(arg: {file: $file}) normalized to field(arg: $a) will result in []UploadPathMapping{ {VariableName: "file", OriginalUploadPath: "variables.file", NewUploadPath: "variables.a.file"} } // so "variables.file" should be updated to "variables.a.file" requestContext.operation.files[idx].SetVariablePath(uploadsMapping[idx].NewUploadPath) } // RemapVariables is updating and sort variables name to be able to have them in a predictable order // after remapping requestContext.operation.remapVariables map will contain new names as a keys and old names as a values - to be able to extract the old values // because it does not rename variables in a variables json err = operationKit.RemapVariables(h.disableVariablesRemapping) if err != nil { rtrace.AttachErrToSpan(engineNormalizeSpan, err) requestContext.operation.normalizationTime = time.Since(startNormalization) if !requestContext.operation.traceOptions.ExcludeNormalizeStats { httpOperation.traceTimings.EndNormalize() } engineNormalizeSpan.End() return err } requestContext.operation.hash = operationKit.parsedOperation.ID requestContext.operation.internalHash = operationKit.parsedOperation.InternalID requestContext.operation.remapVariables = operationKit.parsedOperation.RemapVariables if !h.disableVariablesRemapping && len(uploadsMapping) > 0 { // after variables remapping we need to update the file uploads path because variables relative path has changed // but files still references the old uploads locations // key `to` is a new variable name // value `from` is an old variable name // we are looping through remapped variables to find a match between old variable name and variable which was holding an upload for to, from := range maps.All(requestContext.operation.remapVariables) { // loop over upload mappings to find a match between variable name and upload variable name for uploadMapping := range slices.Values(uploadsMapping) { if uploadMapping.VariableName != from { continue } uploadPath := uploadMapping.NewUploadPath // 
if NewUploadPath is empty it means that there was no change in the path - e.g. upload was directly passed to the argument if uploadPath == "" { uploadPath = uploadMapping.OriginalUploadPath } // next step is to compare file upload path with the original upload path from the upload mappings for file := range slices.Values(requestContext.operation.files) { if file.VariablePath() != uploadPath { continue } // trim old variable name prefix oldUploadPathPrefix := fmt.Sprintf("variables.%s.", from) relativeUploadPath := strings.TrimPrefix(uploadPath, oldUploadPathPrefix) // set new variable name prefix updatedPath := fmt.Sprintf("variables.%s.%s", to, relativeUploadPath) file.SetVariablePath(updatedPath) } } } } operationHashString := strconv.FormatUint(operationKit.parsedOperation.ID, 10) operationHashAttribute := otel.WgOperationHash.String(operationHashString) requestContext.telemetry.addCommonAttribute(operationHashAttribute) httpOperation.routerSpan.SetAttributes(operationHashAttribute) requestContext.operation.rawContent = operationKit.parsedOperation.Request.Query requestContext.operation.content = operationKit.parsedOperation.NormalizedRepresentation requestContext.operation.variables, err = variablesParser.ParseBytes(operationKit.parsedOperation.Request.Variables) if err != nil { rtrace.AttachErrToSpan(engineNormalizeSpan, err) if !requestContext.operation.traceOptions.ExcludeNormalizeStats { httpOperation.traceTimings.EndNormalize() } engineNormalizeSpan.End() return err } requestContext.operation.normalizationTime = time.Since(startNormalization) if !requestContext.operation.traceOptions.ExcludeNormalizeStats { httpOperation.traceTimings.EndNormalize() } engineNormalizeSpan.End() if operationKit.parsedOperation.IsPersistedOperation { engineNormalizeSpan.SetAttributes(otel.WgEnginePersistedOperationCacheHit.Bool(operationKit.parsedOperation.PersistedOperationCacheHit)) } if h.traceExportVariables { // At this stage the variables are normalized 
httpOperation.routerSpan.SetAttributes(otel.WgOperationVariables.String(string(operationKit.parsedOperation.Request.Variables))) } // Set the normalized operation only on the root span operationContentAttribute := otel.WgOperationContent.String(operationKit.parsedOperation.NormalizedRepresentation) httpOperation.routerSpan.SetAttributes(operationContentAttribute) /** * Validate the operation */ if !requestContext.operation.traceOptions.ExcludeValidateStats { httpOperation.traceTimings.StartValidate() } startValidation := time.Now() _, engineValidateSpan := h.tracer.Start(req.Context(), "Operation - Validate", trace.WithSpanKind(trace.SpanKindInternal), trace.WithAttributes(requestContext.telemetry.traceAttrs...), ) validationCached, err := operationKit.Validate(requestContext.operation.executionOptions.SkipLoader, requestContext.operation.remapVariables, h.apolloCompatibilityFlags) if err != nil { rtrace.AttachErrToSpan(engineValidateSpan, err) requestContext.operation.validationTime = time.Since(startValidation) if !requestContext.operation.traceOptions.ExcludeValidateStats { httpOperation.traceTimings.EndValidate() } engineValidateSpan.End() return err } engineValidateSpan.SetAttributes(otel.WgValidationCacheHit.Bool(validationCached)) if requestContext.operation.executionOptions.SkipLoader { // In case we're skipping the loader, which means that we won't execute the operation // we skip the validation of variables as we're not using them // this allows us to generate query plans without having to provide variables engineValidateSpan.SetAttributes(otel.WgVariablesValidationSkipped.Bool(true)) } // Validate that the planned query doesn't exceed the maximum query depth configured // This check runs if they've configured a max query depth, and it can optionally be turned off for persisted operations if h.complexityLimits != nil { cacheHit, complexityCalcs, queryDepthErr := operationKit.ValidateQueryComplexity(h.complexityLimits, operationKit.kit.doc, 
h.executor.RouterSchema, operationKit.parsedOperation.IsPersistedOperation) engineValidateSpan.SetAttributes(otel.WgQueryDepth.Int(complexityCalcs.Depth)) engineValidateSpan.SetAttributes(otel.WgQueryTotalFields.Int(complexityCalcs.TotalFields)) engineValidateSpan.SetAttributes(otel.WgQueryRootFields.Int(complexityCalcs.RootFields)) engineValidateSpan.SetAttributes(otel.WgQueryRootFieldAliases.Int(complexityCalcs.RootFieldAliases)) engineValidateSpan.SetAttributes(otel.WgQueryDepthCacheHit.Bool(cacheHit)) if queryDepthErr != nil { rtrace.AttachErrToSpan(engineValidateSpan, err) requestContext.operation.validationTime = time.Since(startValidation) httpOperation.traceTimings.EndValidate() engineValidateSpan.End() return queryDepthErr } } requestContext.operation.validationTime = time.Since(startValidation) httpOperation.traceTimings.EndValidate() engineValidateSpan.End() /** * Plan the operation */ // If the request has a query parameter wg_trace=true we skip the cache // and always plan the operation // this allows us to "write" to the plan if !requestContext.operation.traceOptions.ExcludePlannerStats { httpOperation.traceTimings.StartPlanning() } startPlanning := time.Now() _, enginePlanSpan := h.tracer.Start(req.Context(), "Operation - Plan", trace.WithSpanKind(trace.SpanKindInternal), trace.WithAttributes(otel.WgEngineRequestTracingEnabled.Bool(requestContext.operation.traceOptions.Enable)), trace.WithAttributes(requestContext.telemetry.traceAttrs...), ) planOptions := PlanOptions{ ClientInfo: requestContext.operation.clientInfo, TraceOptions: requestContext.operation.traceOptions, ExecutionOptions: requestContext.operation.executionOptions, TrackSchemaUsageInfo: h.trackSchemaUsageInfo, } err = h.planner.plan(requestContext.operation, planOptions) if err != nil { httpOperation.requestLogger.Error("failed to plan operation", zap.Error(err)) rtrace.AttachErrToSpan(enginePlanSpan, err) if !requestContext.operation.traceOptions.ExcludePlannerStats { 
httpOperation.traceTimings.EndPlanning() } enginePlanSpan.End() return err } enginePlanSpan.SetAttributes(otel.WgEnginePlanCacheHit.Bool(requestContext.operation.planCacheHit)) requestContext.operation.planningTime = time.Since(startPlanning) httpOperation.traceTimings.EndPlanning() enginePlanSpan.End() planningAttrs := *requestContext.telemetry.AcquireAttributes() planningAttrs = append(planningAttrs, otel.WgEnginePlanCacheHit.Bool(requestContext.operation.planCacheHit)) planningAttrs = append(planningAttrs, requestContext.telemetry.metricAttrs...) httpOperation.operationMetrics.routerMetrics.MetricStore().MeasureOperationPlanningTime( req.Context(), requestContext.operation.planningTime, requestContext.telemetry.metricSliceAttrs, otelmetric.WithAttributeSet(attribute.NewSet(planningAttrs...)), ) requestContext.telemetry.ReleaseAttributes(&planningAttrs) // we could log the query plan only if query plans are calculated if (h.queryPlansEnabled && requestContext.operation.executionOptions.IncludeQueryPlanInResponse) || h.alwaysIncludeQueryPlan { switch p := requestContext.operation.preparedPlan.preparedPlan.(type) { case *plan.SynchronousResponsePlan: p.Response.Fetches.NormalizedQuery = operationKit.parsedOperation.NormalizedRepresentation } if h.queryPlansLoggingEnabled { switch p := requestContext.operation.preparedPlan.preparedPlan.(type) { case *plan.SynchronousResponsePlan: printedPlan := p.Response.Fetches.QueryPlan().PrettyPrint() if h.developmentMode { h.log.Sugar().Debugf("Query Plan:\n%s", printedPlan) } else { h.log.Debug("Query Plan", zap.String("query_plan", printedPlan)) } } } } return nil } // flushMetrics flushes all metrics to the respective exporters // only used for serverless router build func (h *PreHandler) flushMetrics(ctx context.Context, requestLogger *zap.Logger) { requestLogger.Debug("Flushing metrics ...") now := time.Now() wg := &sync.WaitGroup{} wg.Add(1) go func() { defer wg.Done() if err := h.metrics.MetricStore().Flush(ctx); err != 
nil { requestLogger.Error("Failed to flush OTEL metrics", zap.Error(err)) } }() wg.Add(1) go func() { defer wg.Done() if err := h.tracerProvider.ForceFlush(ctx); err != nil { requestLogger.Error("Failed to flush OTEL tracer", zap.Error(err)) } }() wg.Wait() requestLogger.Debug("Metrics flushed", zap.Duration("duration", time.Since(now))) } func (h *PreHandler) parseRequestOptions(r *http.Request, clientInfo *ClientInfo, requestLogger *zap.Logger) (resolve.ExecutionOptions, resolve.TraceOptions, error) { ex, tr, err := h.internalParseRequestOptions(r, clientInfo, requestLogger) if err != nil { return ex, tr, err } if h.alwaysIncludeQueryPlan { ex.IncludeQueryPlanInResponse = true } if h.alwaysSkipLoader { ex.SkipLoader = true } if !h.queryPlansEnabled { ex.IncludeQueryPlanInResponse = false } return ex, tr, nil } func (h *PreHandler) internalParseRequestOptions(r *http.Request, clientInfo *ClientInfo, requestLogger *zap.Logger) (resolve.ExecutionOptions, resolve.TraceOptions, error) { // Determine if we should enable request tracing / query plans at all if h.enableRequestTracing { // In dev mode we always allow to enable tracing / query plans if h.developmentMode { return h.parseRequestExecutionOptions(r), h.parseRequestTraceOptions(r), nil } // If the client has a valid request token, and we have a public key from the controlplane if clientInfo.WGRequestToken != "" && h.routerPublicKey != nil { _, err := jwt.Parse(clientInfo.WGRequestToken, func(token *jwt.Token) (interface{}, error) { return h.routerPublicKey, nil }, jwt.WithValidMethods([]string{jwt.SigningMethodES256.Name})) if err != nil { requestLogger.Error(fmt.Sprintf("failed to parse request token: %s", err.Error())) return resolve.ExecutionOptions{}, resolve.TraceOptions{}, err } return h.parseRequestExecutionOptions(r), h.parseRequestTraceOptions(r), nil } } // Disable tracing / query plans for all other cases traceOptions := resolve.TraceOptions{} traceOptions.DisableAll() return 
resolve.ExecutionOptions{ SkipLoader: false, IncludeQueryPlanInResponse: false, }, traceOptions, nil } func (h *PreHandler) parseRequestExecutionOptions(r *http.Request) resolve.ExecutionOptions { options := resolve.ExecutionOptions{ SkipLoader: false, IncludeQueryPlanInResponse: false, } if r.Header.Get("X-WG-Skip-Loader") != "" { options.SkipLoader = true } if r.URL.Query().Has("wg_skip_loader") { options.SkipLoader = true } if r.Header.Get("X-WG-Include-Query-Plan") != "" { options.IncludeQueryPlanInResponse = true } if r.URL.Query().Has("wg_include_query_plan") { options.IncludeQueryPlanInResponse = true } return options }