router/core/batch.go (197 lines of code) (raw):

package core import ( "bufio" "bytes" "context" "errors" "fmt" "github.com/cespare/xxhash/v2" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/goccy/go-json" "github.com/wundergraph/cosmo/router/pkg/config" "github.com/wundergraph/cosmo/router/pkg/logging" ctrace "github.com/wundergraph/cosmo/router/pkg/trace" "github.com/wundergraph/graphql-go-tools/v2/pkg/graphqlerrors" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" "go.uber.org/zap" "io" "net/http" ) type BatchedOperationId struct{} const defaultBufioReaderSize = 4096 const ( ExtensionCodeBatchSizeExceeded = "BATCH_LIMIT_EXCEEDED" ExtensionCodeBatchSubscriptionsUnsupported = "BATCHING_SUBSCRIPTION_UNSUPPORTED" ) type HandlerOpts struct { MaxEntriesPerBatch int MaxRoutines int HandlerSent http.Handler Tracer trace.Tracer ClientHeader config.ClientHeader BaseOtelAttributes []attribute.KeyValue RouterConfigVersion string Digest *xxhash.Digest OmitExtensions bool Logger *zap.Logger } func Handler(handlerOpts HandlerOpts) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { requestLogger := handlerOpts.Logger.With(logging.WithRequestID(middleware.GetReqID(r.Context()))) err := processBatchedRequest(w, r, handlerOpts, requestLogger) if err != nil { processBatchError(w, r, err, requestLogger) return } } return http.HandlerFunc(fn) } func processBatchedRequest(w http.ResponseWriter, r *http.Request, handlerOpts HandlerOpts, requestLogger *zap.Logger) error { firstChar, bufReader, err := getFirstNonWhitespaceChar(r.Body, defaultBufioReaderSize) if err != nil { requestLogger.Error("failed to read request", zap.Error(err)) return &httpGraphqlError{ message: "failed to read request", statusCode: http.StatusOK, } } // When the first non whitespace character is not // an array start assume it's a non batched request if firstChar == nil || *firstChar != '[' { // if firstChar is nil we have downstream handle it // which is the current behaviour r.Body = io.NopCloser(bufReader) handlerOpts.HandlerSent.ServeHTTP(w, r) return nil } bodyBytes, err := io.ReadAll(bufReader) if err != nil { requestLogger.Error("failed to read request body", zap.Error(err)) return &httpGraphqlError{ message: "failed to read request body", statusCode: http.StatusOK, } } var batchOperations []json.RawMessage if err = json.Unmarshal(bodyBytes, &batchOperations); err != nil { // If there is an error, it's likely a malformed json array // as the start character is "[" requestLogger.Error("failed to read request body", zap.Error(err)) return &httpGraphqlError{ message: "failed to read request body", statusCode: http.StatusOK, } } batchOperationsLength := len(batchOperations) ctrace.AddBatchTracing(r, bodyBytes, handlerOpts.ClientHeader, batchOperationsLength, handlerOpts.BaseOtelAttributes, handlerOpts.RouterConfigVersion, handlerOpts.Digest, ) // When a max batch limit has been specified if batchOperationsLength > handlerOpts.MaxEntriesPerBatch { requestLogger.Error("max batch size has been exceeded") maxError := &httpGraphqlError{ message: "Invalid GraphQL request", statusCode: http.StatusOK, } if !handlerOpts.OmitExtensions { maxError.extensionCode = ExtensionCodeBatchSizeExceeded } return maxError } // We have a batched request. responses := make([]json.RawMessage, batchOperationsLength) sem := make(chan struct{}, handlerOpts.MaxRoutines) // Process each operation in parallel. for i, singleOp := range batchOperations { sem <- struct{}{} go func() { defer func() { <-sem }() batchOperationStr := fmt.Sprintf("batch-operation-%d", i) spanCtx, span := handlerOpts.Tracer.Start(r.Context(), batchOperationStr) ctx := context.WithValue(spanCtx, BatchedOperationId{}, batchOperationStr) defer span.End() // Create a new request for the single operation. rCopy := r.Clone(ctx) // Reset the route context to avoid sharing mutable state. rCopy = rCopy.WithContext(context.WithValue(rCopy.Context(), chi.RouteCtxKey, chi.NewRouteContext())) rCopy.Body = io.NopCloser(bytes.NewBuffer(singleOp)) // Create a ResponseWriter that captures the output. rw := newBufferingResponseWriter() // Execute the mux handler for this operation. handlerOpts.HandlerSent.ServeHTTP(rw, rCopy) // Store the response (assuming the response is valid JSON). responses[i] = rw.Body.Bytes() }() } // Wait for all operations to be completed by blocking. for n := handlerOpts.MaxRoutines; n > 0; n-- { sem <- struct{}{} } // Drain and close the semaphore. for len(sem) > 0 { <-sem } close(sem) // Write out the batched response as a JSON array. w.Header().Set("Content-Type", "application/json; charset=utf-8") if err := json.NewEncoder(w).Encode(responses); err != nil { return &httpGraphqlError{ message: "failed to encode batched response", statusCode: http.StatusInternalServerError, } } return nil } func processBatchError(w http.ResponseWriter, r *http.Request, err error, requestLogger *zap.Logger) { ctrace.AttachErrToSpanFromContext(r.Context(), err) requestError := graphqlerrors.RequestError{ Message: err.Error(), } statusCode := http.StatusOK var httpGqlError *httpGraphqlError if errors.As(err, &httpGqlError) { statusCode = httpGqlError.statusCode if httpGqlError.extensionCode != "" { requestError.Extensions = &graphqlerrors.Extensions{ Code: httpGqlError.extensionCode, } } } writeRequestErrors(r, w, statusCode, []graphqlerrors.RequestError{requestError}, requestLogger) } func getFirstNonWhitespaceChar(r io.Reader, readerSize int) (*byte, *bufio.Reader, error) { bufReader := bufio.NewReaderSize(r, readerSize) for { peeked, err := bufReader.Peek(1) if err != nil { if err == io.EOF { return nil, bufReader, nil } return nil, nil, err } if len(peeked) == 0 { return nil, bufReader, nil } peekByte := peeked[0] switch peekByte { // we check the characters based on this RFC https://datatracker.ietf.org/doc/html/rfc8259 // and also the array decode function in goccy/go-json (which is the library we used to decode) case ' ', '\n', '\t', '\r': bufReader.ReadByte() continue default: return &peekByte, bufReader, nil } } } type bufferingResponseWriter struct { HeaderMap http.Header Body *bytes.Buffer Status int } func newBufferingResponseWriter() *bufferingResponseWriter { return &bufferingResponseWriter{ HeaderMap: make(http.Header), Body: &bytes.Buffer{}, Status: http.StatusOK, } } func (brw *bufferingResponseWriter) Header() http.Header { return brw.HeaderMap } func (brw *bufferingResponseWriter) Write(b []byte) (int, error) { return brw.Body.Write(b) } func (brw *bufferingResponseWriter) WriteHeader(statusCode int) { brw.Status = statusCode }