func()

in router/core/graphql_prehandler.go [168:437]


// Handler returns an HTTP middleware that prepares a GraphQL request before it
// is handed to next. For every request it:
//   - records client info and telemetry attributes and starts operation metrics,
//   - parses the request options (execution and tracing),
//   - reads the operation body (multipart upload or plain POST),
//   - authenticates the request when an access controller is configured,
//   - processes the operation via handleOperation, and
//   - invokes next with a wrapped response writer so that the status code and
//     the number of written bytes can be reported to the metrics on completion.
//
// Any failure along the way is stored on the request context, written to the
// client, and the remaining steps are skipped.
func (h *PreHandler) Handler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var (
			// writtenBytes and statusCode are filled in from the wrapped response writer
			// once next has served the request; the deferred metrics.Finish reports them.
			// NOTE(review): the early-return error paths below never update statusCode, so
			// metrics.Finish receives the http.StatusOK default for them — confirm that this
			// is intended.
			writtenBytes int
			statusCode   = http.StatusOK
			traceTimings *art.TraceTimings
		)

		requestContext := getRequestContext(r.Context())
		requestLogger := requestContext.logger

		routerSpan := trace.SpanFromContext(r.Context())

		clientInfo := NewClientInfoFromRequest(r, h.clientHeader)

		requestContext.telemetry.addCommonAttribute(
			otel.WgClientName.String(clientInfo.Name),
			otel.WgClientVersion.String(clientInfo.Version),
			otel.WgOperationProtocol.String(OperationProtocolHTTP.String()),
		)

		// Borrow an attribute slice only for the duration of StartOperation and hand it
		// back right afterwards.
		startAttrs := *requestContext.telemetry.AcquireAttributes()
		startAttrs = append(startAttrs, requestContext.telemetry.metricAttrs...)

		metrics := h.metrics.StartOperation(
			requestLogger,
			r.ContentLength,
			requestContext.telemetry.metricSliceAttrs,
			otelmetric.WithAttributeSet(attribute.NewSet(startAttrs...)),
		)

		requestContext.telemetry.ReleaseAttributes(&startAttrs)

		routerSpan.SetAttributes(requestContext.telemetry.traceAttrs...)

		// First pass over the attribute expressions, evaluated before authentication.
		// A second pass (expressionsAttributesWithAuth) runs after the auth step below.
		// A resolution error is only logged; whatever was returned is still applied.
		if requestContext.telemetry.traceAttributeExpressions != nil {
			traceMetrics, err := requestContext.telemetry.traceAttributeExpressions.expressionsAttributes(requestContext)
			if err != nil {
				requestLogger.Error("failed to resolve trace attribute", zap.Error(err))
			}
			requestContext.telemetry.addCommonAttribute(
				traceMetrics...,
			)
			routerSpan.SetAttributes(traceMetrics...)
		}

		if requestContext.telemetry.metricAttributeExpressions != nil {
			metricAttrs, err := requestContext.telemetry.metricAttributeExpressions.expressionsAttributes(requestContext)
			if err != nil {
				requestLogger.Error("failed to resolve metric attribute", zap.Error(err))
			}
			requestContext.telemetry.addMetricAttribute(
				metricAttrs...,
			)
		}

		requestContext.operation = &operationContext{
			clientInfo: clientInfo,
		}

		// Finish the operation metrics on every exit path of the handler, including the
		// early error returns below.
		defer func() {
			requestContext.telemetry.AddCustomMetricStringSliceAttr(ContextFieldGraphQLErrorServices, requestContext.graphQLErrorServices)
			requestContext.telemetry.AddCustomMetricStringSliceAttr(ContextFieldOperationServices, requestContext.dataSourceNames)
			requestContext.telemetry.AddCustomMetricStringSliceAttr(ContextFieldGraphQLErrorCodes, requestContext.graphQLErrorCodes)

			metrics.Finish(
				requestContext,
				statusCode,
				writtenBytes,
				h.flushTelemetryAfterResponse,
			)

			if h.flushTelemetryAfterResponse {
				h.flushMetrics(r.Context(), requestLogger)
			}
		}()

		executionOptions, traceOptions, err := h.parseRequestOptions(r, clientInfo, requestLogger)
		if err != nil {
			requestContext.SetError(err)
			writeRequestErrors(r, w, http.StatusBadRequest, graphqlerrors.RequestErrorsFromError(err), requestLogger)
			return
		}

		requestContext.operation.protocol = OperationProtocolHTTP
		requestContext.operation.executionOptions = executionOptions
		requestContext.operation.traceOptions = traceOptions

		if traceOptions.Enable {
			r = r.WithContext(resolve.SetTraceStart(r.Context(), traceOptions.EnablePredictableDebugTimings))
			traceTimings = art.NewTraceTimings(r.Context())
		}

		var body []byte
		var files []*httpclient.FileUpload

		// Media types are case-insensitive, so normalise the header value before matching.
		if strings.Contains(strings.ToLower(r.Header.Get("Content-Type")), "multipart/form-data") {
			if !h.fileUploadEnabled {
				requestContext.SetError(&httpGraphqlError{
					message:    "file upload disabled",
					statusCode: http.StatusOK,
				})
				writeOperationError(r, w, requestLogger, requestContext.error)
				return
			}

			_, readMultiPartSpan := h.tracer.Start(r.Context(), "HTTP - Read Multipart",
				trace.WithSpanKind(trace.SpanKindInternal),
				trace.WithAttributes(requestContext.telemetry.traceAttrs...),
			)

			multipartParser := NewMultipartParser(h.operationProcessor, h.maxUploadFiles, h.maxUploadFileSize)

			// Cleanup all files. Needs to be called in the pre_handler function to ensure that
			// defer is called after the response is written. It is registered before Parse so
			// that files the parser may already have created are also cleaned up when parsing
			// fails midway.
			// NOTE(review): assumes RemoveAll is safe to call after a failed Parse — confirm.
			defer func() {
				if err := multipartParser.RemoveAll(); err != nil {
					requestLogger.Error("Failed to remove files after multipart request", zap.Error(err))
				}
			}()

			var err error
			body, files, err = multipartParser.Parse(r, h.getBodyReadBuffer(r.ContentLength))
			if err != nil {
				requestContext.SetError(err)
				writeOperationError(r, w, requestLogger, requestContext.error)
				readMultiPartSpan.End()
				return
			}

			readMultiPartSpan.SetAttributes(
				otel.HTTPRequestUploadFileCount.Int(len(files)),
			)

			readMultiPartSpan.End()
		} else if r.Method == http.MethodPost {
			_, readOperationBodySpan := h.tracer.Start(r.Context(), "HTTP - Read Body",
				trace.WithSpanKind(trace.SpanKindInternal),
				trace.WithAttributes(requestContext.telemetry.traceAttrs...),
			)

			var err error
			body, err = h.operationProcessor.ReadBody(r.Body, h.getBodyReadBuffer(r.ContentLength))
			if err != nil {
				requestContext.SetError(err)

				// Don't produce errors logs here because it can only be client side errors
				// e.g. too large body, slow client, aborted connection etc.
				// The error is logged as debug log in the writeOperationError function

				writeOperationError(r, w, requestLogger, err)
				readOperationBodySpan.End()
				return
			}

			readOperationBodySpan.End()
		}

		// The variables parser is pooled; return it when the handler is done.
		variablesParser := h.variableParsePool.Get()
		defer h.variableParsePool.Put(variablesParser)

		// If we have authenticators, we try to authenticate the request
		if h.accessController != nil {
			_, authenticateSpan := h.tracer.Start(r.Context(), "Authenticate",
				trace.WithSpanKind(trace.SpanKindServer),
				trace.WithAttributes(requestContext.telemetry.traceAttrs...),
			)

			validatedReq, err := h.accessController.Access(w, r)
			if err != nil {
				requestContext.SetError(err)
				requestLogger.Error("Failed to authenticate request", zap.Error(err))

				// Mark the root span of the router as failed, so we can easily identify failed requests
				rtrace.AttachErrToSpan(routerSpan, err)
				rtrace.AttachErrToSpan(authenticateSpan, err)

				authenticateSpan.End()

				writeOperationError(r, w, requestLogger, &httpGraphqlError{
					message:    err.Error(),
					statusCode: http.StatusUnauthorized,
				})
				return
			}

			authenticateSpan.End()

			// Continue with the request returned by the access controller and expose its
			// auth information to the expression context.
			r = validatedReq

			requestContext.expressionContext.Request.Auth = expr.LoadAuth(r.Context())
		}

		// Second pass over the attribute expressions, now that the auth step has run.
		if requestContext.telemetry.traceAttributeExpressions != nil {
			traceMetrics, err := requestContext.telemetry.traceAttributeExpressions.expressionsAttributesWithAuth(requestContext)
			if err != nil {
				requestLogger.Error("failed to resolve trace attribute", zap.Error(err))
			}
			requestContext.telemetry.addCommonAttribute(
				traceMetrics...,
			)
			routerSpan.SetAttributes(traceMetrics...)
		}

		if requestContext.telemetry.metricAttributeExpressions != nil {
			metricAttrs, err := requestContext.telemetry.metricAttributeExpressions.expressionsAttributesWithAuth(requestContext)
			if err != nil {
				requestLogger.Error("failed to resolve metric attribute", zap.Error(err))
			}
			requestContext.telemetry.addMetricAttribute(
				metricAttrs...,
			)
		}

		err = h.handleOperation(r, variablesParser, &httpOperation{
			requestContext:   requestContext,
			requestLogger:    requestLogger,
			routerSpan:       routerSpan,
			operationMetrics: metrics,
			traceTimings:     traceTimings,
			files:            files,
			body:             body,
		})
		if err != nil {
			requestContext.SetError(err)
			// Mark the root span of the router as failed, so we can easily identify failed requests
			rtrace.AttachErrToSpan(routerSpan, err)

			writeOperationError(r, w, requestLogger, err)
			return
		}

		art.SetRequestTracingStats(r.Context(), traceOptions, traceTimings)

		// When request tracing is enabled, attach a snapshot of the request to the context.
		if traceOptions.Enable {
			reqData := &resolve.RequestData{
				Method:  r.Method,
				URL:     r.URL.String(),
				Headers: r.Header,
				Body: resolve.BodyData{
					Query:         requestContext.operation.rawContent,
					OperationName: requestContext.operation.name,
					Variables:     json.RawMessage(requestContext.operation.variables.String()),
				},
			}
			r = r.WithContext(resolve.SetRequest(r.Context(), reqData))
		}

		// Wrap the writer so the status code and written bytes can be read back afterwards.
		ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)

		// The request context needs to be updated with the latest request to ensure that the context is up to date
		requestContext.request = r
		requestContext.responseWriter = ww

		// Call the final handler that resolves the operation
		// and enrich the context to make it available in the request context as well for metrics etc.
		next.ServeHTTP(ww, r)

		statusCode = ww.Status()
		writtenBytes = ww.BytesWritten()

		// Mark the root span of the router as failed, so we can easily identify failed requests
		if requestContext.error != nil {
			rtrace.AttachErrToSpan(trace.SpanFromContext(r.Context()), requestContext.error)
		}
	})
}