router/core/transport.go (331 lines of code) (raw):

package core import ( "bytes" "fmt" "io" "net/http" "net/url" "sort" "strconv" "sync" "go.opentelemetry.io/otel/propagation" otelmetric "go.opentelemetry.io/otel/metric" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "go.opentelemetry.io/otel/attribute" sdktrace "go.opentelemetry.io/otel/sdk/trace" semconv "go.opentelemetry.io/otel/semconv/v1.21.0" "github.com/wundergraph/cosmo/router/pkg/metric" "github.com/wundergraph/cosmo/router/pkg/otel" "github.com/wundergraph/cosmo/router/pkg/trace" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" otrace "go.opentelemetry.io/otel/trace" "go.uber.org/zap" "github.com/wundergraph/cosmo/router/internal/docker" "github.com/wundergraph/cosmo/router/internal/retrytransport" ) type ( TransportPreHandler func(req *http.Request, ctx RequestContext) (*http.Request, *http.Response) TransportPostHandler func(resp *http.Response, ctx RequestContext) *http.Response ) type CustomTransport struct { roundTripper http.RoundTripper preHandlers []TransportPreHandler postHandlers []TransportPostHandler metricStore metric.Store logger *zap.Logger sf map[uint64]*sfCacheItem sfMu *sync.RWMutex } type sfCacheItem struct { loaded chan struct{} response *http.Response body []byte err error } func NewCustomTransport( logger *zap.Logger, roundTripper http.RoundTripper, retryOptions retrytransport.RetryOptions, metricStore metric.Store, enableSingleFlight bool, ) *CustomTransport { ct := &CustomTransport{ metricStore: metricStore, } if retryOptions.Enabled { ct.roundTripper = retrytransport.NewRetryHTTPTransport(roundTripper, retryOptions, logger) } else { ct.roundTripper = roundTripper } if enableSingleFlight { ct.sf = make(map[uint64]*sfCacheItem) ct.sfMu = &sync.RWMutex{} } return ct } func (ct *CustomTransport) measureSubgraphMetrics(req *http.Request) func(err error, resp *http.Response) { reqContext := getRequestContext(req.Context()) activeSubgraph := reqContext.ActiveSubgraph(req) attributes := *reqContext.telemetry.AcquireAttributes() if activeSubgraph != nil { attributes = append(attributes, otel.WgSubgraphName.String(activeSubgraph.Name), otel.WgSubgraphID.String(activeSubgraph.Id), ) } attributes = append(attributes, reqContext.telemetry.metricAttrs...) if reqContext.telemetry.metricAttributeExpressions != nil { additionalAttrs, err := reqContext.telemetry.metricAttributeExpressions.expressionsAttributes(reqContext) if err != nil { ct.logger.Error("failed to resolve metric attribute expressions", zap.Error(err)) } attributes = append(attributes, additionalAttrs...) } o := otelmetric.WithAttributeSet(attribute.NewSet(attributes...)) inFlightDone := ct.metricStore.MeasureInFlight(req.Context(), reqContext.telemetry.metricSliceAttrs, o) ct.metricStore.MeasureRequestSize(req.Context(), req.ContentLength, reqContext.telemetry.metricSliceAttrs, o) return func(err error, resp *http.Response) { defer reqContext.telemetry.ReleaseAttributes(&attributes) inFlightDone() if resp != nil { attributes = append(attributes, semconv.HTTPStatusCode(resp.StatusCode)) o = otelmetric.WithAttributeSet(attribute.NewSet(attributes...)) ct.metricStore.MeasureResponseSize(req.Context(), resp.ContentLength, reqContext.telemetry.metricSliceAttrs, o) } } } // RoundTrip of the engine upstream requests. The handler is called concurrently for each request. // Be aware that multiple modules can be active at the same time. Must be concurrency safe. func (ct *CustomTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) { moduleContext := &moduleRequestContext{ requestContext: getRequestContext(req.Context()), sendError: nil, } done := ct.measureSubgraphMetrics(req) defer func() { done(err, resp) }() if ct.preHandlers != nil { for _, preHandler := range ct.preHandlers { r, resp := preHandler(req, moduleContext) // Non nil response means the handler decided to skip sending the request if resp != nil { return resp, nil } req = r } } if !ct.allowSingleFlight(req) { resp, err = ct.roundTripper.RoundTrip(req) } else { resp, err = ct.roundTripSingleFlight(req) } // Set the error on the request context so that it can be checked by the post handlers if err != nil { moduleContext.sendError = err } if ct.postHandlers != nil { for _, postHandler := range ct.postHandlers { newResp := postHandler(resp, moduleContext) // Abort with the first handler that returns a non-nil response if newResp != nil { return newResp, nil } } } if err != nil { return nil, err } return resp, err } func (ct *CustomTransport) allowSingleFlight(req *http.Request) bool { if ct.sf == nil { // Single flight is disabled return false } if req.Header.Get("Upgrade") != "" { // Websocket requests are not idempotent return false } if req.Header.Get("Accept") == "text/event-stream" { // SSE requests are not idempotent return false } if resolve.SingleFlightDisallowed(req.Context()) { // Single flight is disallowed for this request (e.g. because it is a Mutation) return false } return true } func (ct *CustomTransport) roundTripSingleFlight(req *http.Request) (*http.Response, error) { key := ct.singleFlightKey(req) ct.sfMu.RLock() item, shared := ct.sf[key] ct.sfMu.RUnlock() sfStats := resolve.GetSingleFlightStats(req.Context()) if sfStats != nil { sfStats.SingleFlightUsed = true sfStats.SingleFlightSharedResponse = shared } if shared { select { case <-item.loaded: case <-req.Context().Done(): return nil, req.Context().Err() } // If the single flight item has an error, return it immediately // This happens e.g. on network errors if item.err != nil { return nil, item.err } res := &http.Response{} res.Status = item.response.Status res.StatusCode = item.response.StatusCode res.Header = item.response.Header res.Trailer = item.response.Trailer res.ContentLength = item.response.ContentLength res.TransferEncoding = item.response.TransferEncoding res.Close = item.response.Close res.Uncompressed = item.response.Uncompressed res.Request = req // Restore the body res.Body = io.NopCloser(bytes.NewReader(item.body)) return res, item.err } if sfStats != nil { sfStats.SingleFlightUsed = true sfStats.SingleFlightSharedResponse = false } item = &sfCacheItem{ loaded: make(chan struct{}), } ct.sfMu.Lock() ct.sf[key] = item ct.sfMu.Unlock() defer func() { close(item.loaded) ct.sfMu.Lock() delete(ct.sf, key) ct.sfMu.Unlock() }() res, err := ct.roundTripper.RoundTrip(req) if err != nil { item.err = err return nil, err } defer res.Body.Close() item.body, err = io.ReadAll(res.Body) if err != nil { item.err = err return nil, err } item.response = res // Restore the body res.Body = io.NopCloser(bytes.NewReader(item.body)) return res, nil } func (ct *CustomTransport) singleFlightKey(req *http.Request) uint64 { keyGen := pool.Hash64.Get() defer pool.Hash64.Put(keyGen) if bodyHash, ok := httpclient.BodyHashFromContext(req.Context()); ok { _, _ = keyGen.WriteString(strconv.FormatUint(bodyHash, 10)) } unsortedHeaders := make([]string, 0, len(req.Header)) for key := range req.Header { value := req.Header.Get(key) unsortedHeaders = append(unsortedHeaders, key+value) } sort.Strings(unsortedHeaders) for i := range unsortedHeaders { _, _ = keyGen.WriteString(unsortedHeaders[i]) } sum := keyGen.Sum64() return sum } type TransportFactory struct { preHandlers []TransportPreHandler postHandlers []TransportPostHandler subgraphTransportOptions *SubgraphTransportOptions retryOptions retrytransport.RetryOptions localhostFallbackInsideDocker bool metricStore metric.Store logger *zap.Logger tracerProvider *sdktrace.TracerProvider tracePropagators propagation.TextMapPropagator proxy ProxyFunc } var _ ApiTransportFactory = TransportFactory{} type TransportOptions struct { PreHandlers []TransportPreHandler PostHandlers []TransportPostHandler SubgraphTransportOptions *SubgraphTransportOptions Proxy ProxyFunc RetryOptions retrytransport.RetryOptions LocalhostFallbackInsideDocker bool MetricStore metric.Store Logger *zap.Logger TracerProvider *sdktrace.TracerProvider TracePropagators propagation.TextMapPropagator } func NewTransport(opts *TransportOptions) *TransportFactory { return &TransportFactory{ preHandlers: opts.PreHandlers, postHandlers: opts.PostHandlers, retryOptions: opts.RetryOptions, subgraphTransportOptions: opts.SubgraphTransportOptions, localhostFallbackInsideDocker: opts.LocalhostFallbackInsideDocker, metricStore: opts.MetricStore, logger: opts.Logger, tracerProvider: opts.TracerProvider, proxy: opts.Proxy, tracePropagators: opts.TracePropagators, } } func (t TransportFactory) RoundTripper(enableSingleFlight bool, baseTransport http.RoundTripper) http.RoundTripper { if t.subgraphTransportOptions != nil && t.subgraphTransportOptions.SubgraphMap != nil && len(t.subgraphTransportOptions.SubgraphMap) > 0 { baseTransport = NewSubgraphTransport(t.subgraphTransportOptions, baseTransport, t.logger, t.proxy) } if t.localhostFallbackInsideDocker && docker.Inside() { baseTransport = docker.NewLocalhostFallbackRoundTripper(baseTransport) } otelHttpOptions := []otelhttp.Option{ otelhttp.WithSpanNameFormatter(SpanNameFormatter), otelhttp.WithSpanOptions(otrace.WithAttributes(otel.EngineTransportAttribute)), otelhttp.WithTracerProvider(t.tracerProvider), } if t.tracePropagators != nil { otelHttpOptions = append(otelHttpOptions, otelhttp.WithPropagators(t.tracePropagators)) } traceTransport := trace.NewTransport( baseTransport, otelHttpOptions, trace.WithPreHandler(func(r *http.Request) { span := otrace.SpanFromContext(r.Context()) reqContext := getRequestContext(r.Context()) var attributes []attribute.KeyValue subgraph := reqContext.ActiveSubgraph(r) if subgraph != nil { attributes = append(attributes, otel.WgSubgraphID.String(subgraph.Id)) attributes = append(attributes, otel.WgSubgraphName.String(subgraph.Name)) } attributes = append(attributes, reqContext.telemetry.traceAttrs...) span.SetAttributes(attributes...) }), ) tp := NewCustomTransport( t.logger, traceTransport, t.retryOptions, t.metricStore, enableSingleFlight, ) tp.preHandlers = t.preHandlers tp.postHandlers = t.postHandlers tp.logger = t.logger return tp } func (t TransportFactory) DefaultHTTPProxyURL() *url.URL { return nil } // SpanNameFormatter formats the span name based on the http request func SpanNameFormatter(_ string, r *http.Request) string { requestContext := getRequestContext(r.Context()) if requestContext != nil && requestContext.operation != nil { return GetSpanName(requestContext.operation.Name(), requestContext.operation.Type()) } return fmt.Sprintf("%s %s", r.Method, r.URL.Path) } func GetSpanName(operationName string, operationType string) string { if operationName != "" { return fmt.Sprintf("%s %s", operationType, operationName) } return fmt.Sprintf("%s %s", operationType, "unnamed") }