router/core/operation_processor.go (1,022 lines of code) (raw):

package core import ( "bytes" "context" "crypto/sha256" "encoding/json" "fmt" "hash" "io" "net/http" "net/url" "slices" "sync" "time" "github.com/buger/jsonparser" "github.com/cespare/xxhash/v2" "github.com/dgraph-io/ristretto/v2" "github.com/pkg/errors" "github.com/tidwall/sjson" fastjson "github.com/wundergraph/astjson" "github.com/wundergraph/graphql-go-tools/v2/pkg/apollocompatibility" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/astnormalization" "github.com/wundergraph/graphql-go-tools/v2/pkg/astnormalization/uploads" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" "github.com/wundergraph/graphql-go-tools/v2/pkg/astprinter" "github.com/wundergraph/graphql-go-tools/v2/pkg/astvalidation" "github.com/wundergraph/graphql-go-tools/v2/pkg/middleware/operation_complexity" "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" "github.com/wundergraph/graphql-go-tools/v2/pkg/variablesvalidation" "github.com/wundergraph/cosmo/router/internal/persistedoperation" "github.com/wundergraph/cosmo/router/internal/unsafebytes" "github.com/wundergraph/cosmo/router/pkg/config" ) var ( // staticOperationName is used to replace the operation name in the document when generating the operation ID // this ensures that the operation ID is the same for the same operation regardless of the operation name staticOperationName = []byte("O") ) type ParsedOperation struct { // ID represents a unique-ish ID for the operation calculated by hashing // its normalized representation ID uint64 // InternalID is the internal ID of the operation calculated by hashing // its normalized representation with the original operation name and normalized variables InternalID uint64 // Sha256Hash is the sha256 hash of the original operation query sent by the client Sha256Hash string // Type is a string representing the operation type. 
One of // "query", "mutation", "subscription" Type string Variables *fastjson.Object RemapVariables map[string]string // NormalizedRepresentation is the normalized representation of the operation // as a string. This is provided for modules to be able to access the // operation. Only available after the operation has been normalized. NormalizedRepresentation string Request GraphQLRequest GraphQLRequestExtensions GraphQLRequestExtensions IsPersistedOperation bool PersistedOperationCacheHit bool // NormalizationCacheHit is set to true if the request is a non-persisted operation and the normalized operation was loaded from cache NormalizationCacheHit bool } type invalidExtensionsTypeError jsonparser.ValueType func (e invalidExtensionsTypeError) Error() string { return fmt.Sprintf("invalid extensions type: %s, most be object or null", jsonparser.ValueType(e)) } func (e invalidExtensionsTypeError) Message() string { return e.Error() } func (e invalidExtensionsTypeError) StatusCode() int { return http.StatusBadRequest } func (e invalidExtensionsTypeError) ExtensionCode() string { return "" } var ( _ HttpError = invalidExtensionsTypeError(0) ) type OperationProcessorOptions struct { Executor *Executor MaxOperationSizeInBytes int64 PersistedOperationClient persistedoperation.SaveClient AutomaticPersistedOperationCacheTtl int EnablePersistedOperationsCache bool PersistedOpsNormalizationCache *ristretto.Cache[uint64, NormalizationCacheEntry] NormalizationCache *ristretto.Cache[uint64, NormalizationCacheEntry] QueryDepthCache *ristretto.Cache[uint64, ComplexityCacheEntry] ValidationCache *ristretto.Cache[uint64, bool] OperationHashCache *ristretto.Cache[uint64, string] ParseKitPoolSize int IntrospectionEnabled bool ApolloCompatibilityFlags config.ApolloCompatibilityFlags ApolloRouterCompatibilityFlags config.ApolloRouterCompatibilityFlags } // OperationProcessor provides shared resources to the parseKit and OperationKit. 
// OperationProcessor provides shared resources to the parseKit and OperationKit.
// It should be only instantiated once and shared across requests
type OperationProcessor struct {
	// executor supplies the client schema used during normalization and validation.
	executor                *Executor
	maxOperationSizeInBytes int64
	// persistedOperationClient loads and saves persisted operations; when nil,
	// FetchPersistedOperation reports that the feature is not configured.
	persistedOperationClient persistedoperation.SaveClient
	// operationCache is handed to every OperationKit as its cache.
	operationCache *OperationCache
	// parseKits and parseKitSemaphore presumably implement the kit pool behind
	// getKit/freeKit — those methods are not visible here, TODO confirm.
	parseKits            map[int]*parseKit
	parseKitSemaphore    chan int
	introspectionEnabled bool
	// parseKitOptions is used to build a standalone kit in NewIndependentOperationKit.
	parseKitOptions *parseKitOptions
}

// parseKit is a helper struct to parse, normalize and validate operations
type parseKit struct {
	i int
	// numOperations is the number of operation definitions counted by Parse.
	numOperations int
	parser        *astparser.Parser
	// doc is the document that is parsed, normalized and printed.
	doc *ast.Document
	// keyGen is a reusable xxhash digest for cache keys and operation IDs.
	// Users read Sum64 and Reset it around each use.
	keyGen *xxhash.Digest
	// sha256Hash is used by ComputeOperationSha256 and reset after each use.
	sha256Hash          hash.Hash
	staticNormalizer    *astnormalization.OperationNormalizer
	variablesNormalizer *astnormalization.VariablesNormalizer
	variablesRemapper   *astnormalization.VariablesMapper
	printer             *astprinter.Printer
	// normalizedOperation is the buffer the printer writes the operation into.
	normalizedOperation *bytes.Buffer
	variablesValidator  *variablesvalidation.VariablesValidator
	operationValidator  *astvalidation.OperationValidator
}

// OperationCache groups the caches consulted by an OperationKit.
type OperationCache struct {
	// persistedOperationVariableNames maps a persisted query sha256 hash to the
	// skip/include variable names recorded when the operation was cached.
	persistedOperationVariableNames map[string][]string
	// persistedOperationVariableNamesLock guards persistedOperationVariableNames.
	persistedOperationVariableNamesLock *sync.RWMutex
	// automaticPersistedOperationCacheTtl is interpreted as seconds when APQ
	// entries are stored with a TTL.
	automaticPersistedOperationCacheTtl  float64
	persistedOperationNormalizationCache *ristretto.Cache[uint64, NormalizationCacheEntry]
	normalizationCache                   *ristretto.Cache[uint64, NormalizationCacheEntry]
	complexityCache                      *ristretto.Cache[uint64, ComplexityCacheEntry]
	validationCache                      *ristretto.Cache[uint64, bool]
	// operationHashCache maps the xxhash of a query string to its hex sha256.
	operationHashCache *ristretto.Cache[uint64, string]
}
// OperationKit provides methods to parse, normalize and validate operations.
// After each step, the operation is available as a ParsedOperation.
// It must be created for each request and freed after the request is done.
type OperationKit struct {
	cache *OperationCache
	// operationDefinitionRef is the ref of the selected operation definition.
	// It starts at -1 and is set by Parse, or to 0 after a cache hit.
	operationDefinitionRef int
	// originalOperationNameRef keeps the client-sent operation name so it can be
	// restored after Parse replaces it with the static name.
	originalOperationNameRef ast.ByteSliceReference
	operationProcessor       *OperationProcessor
	kit                      *parseKit
	parsedOperation          *ParsedOperation
	// introspectionEnabled controls whether Parse rejects __schema/__type queries.
	introspectionEnabled bool
}

// GraphQLRequest is the JSON shape of an incoming GraphQL request. Variables
// and Extensions are kept as raw JSON and interpreted later.
type GraphQLRequest struct {
	Query         string          `json:"query,omitempty"`
	OperationName string          `json:"operationName,omitempty"`
	Variables     json.RawMessage `json:"variables,omitempty"`
	Extensions    json.RawMessage `json:"extensions,omitempty"`
}

// GraphQLRequestExtensions is the typed view of the request "extensions" object.
type GraphQLRequestExtensions struct {
	PersistedQuery *GraphQLRequestExtensionsPersistedQuery `json:"persistedQuery"`
}

// GraphQLRequestExtensionsPersistedQuery is the "persistedQuery" extension.
// A non-empty Sha256Hash marks the request as a persisted operation.
type GraphQLRequestExtensionsPersistedQuery struct {
	Version    int    `json:"version"`
	Sha256Hash string `json:"sha256Hash"`
}

// complexityComparison pairs two integer values with an error message.
// NOTE(review): its use is not visible in this part of the file.
type complexityComparison struct {
	field        int
	cachedField  int
	errorMessage string
}

// NewOperationKit creates a new OperationKit. The kit is used to parse, normalize and validate operations.
// It allocates resources that need to be freed by calling OperationKit.Free()
func NewOperationKit(processor *OperationProcessor) *OperationKit {
	return &OperationKit{
		operationProcessor: processor,
		// the parse kit comes from the processor and is returned in Free
		kit: processor.getKit(),
		// -1 means no operation definition has been selected yet
		operationDefinitionRef: -1,
		cache:                  processor.operationCache,
		parsedOperation:        &ParsedOperation{},
		introspectionEnabled:   processor.introspectionEnabled,
	}
}
// NewIndependentOperationKit creates a new OperationKit that does not share resources with other kits.
// Its parse kit is built directly from the processor's parseKitOptions rather
// than obtained through processor.getKit().
func NewIndependentOperationKit(processor *OperationProcessor) *OperationKit {
	return &OperationKit{
		operationProcessor:     processor,
		kit:                    createParseKit(0, processor.parseKitOptions),
		operationDefinitionRef: -1,
		cache:                  processor.operationCache,
		parsedOperation:        &ParsedOperation{},
		introspectionEnabled:   processor.introspectionEnabled,
	}
}

// Free releases the resources used by the OperationKit
func (o *OperationKit) Free() {
	o.operationProcessor.freeKit(o.kit)
}

// UnmarshalOperationFromURL loads the operation from the URL and unmarshal it into the ParsedOperation
// It follows the GraphQL over HTTP specification for GET requests https://graphql.github.io/graphql-over-http/draft/#sec-GET
// We always compact the variables and extensions to ensure that we produce easy to parse JSON for the engine
//
// Empty query parameters are ignored. An error is returned when "variables" or
// "extensions" is not valid JSON, or when unmarshalOperation rejects the request.
func (o *OperationKit) UnmarshalOperationFromURL(url *url.URL) error {
	values := url.Query()

	// compact validates raw as JSON and returns it with insignificant
	// whitespace removed.
	compact := func(raw string) (json.RawMessage, error) {
		buf := bytes.NewBuffer(make([]byte, 0, len(raw)))
		if err := json.Compact(buf, []byte(raw)); err != nil {
			return nil, err
		}
		return buf.Bytes(), nil
	}

	if query := values.Get("query"); query != "" {
		o.parsedOperation.Request.Query = query
	}

	if operationName := values.Get("operationName"); operationName != "" {
		o.parsedOperation.Request.OperationName = operationName
	}

	if variables := values.Get("variables"); variables != "" {
		compacted, err := compact(variables)
		if err != nil {
			return err
		}
		// Store the compacted form; previously the compacted buffer was discarded.
		o.parsedOperation.Request.Variables = compacted
	}

	if extensions := values.Get("extensions"); extensions != "" {
		compacted, err := compact(extensions)
		if err != nil {
			return err
		}
		o.parsedOperation.Request.Extensions = compacted
	}

	return o.unmarshalOperation()
}
as JSON.RawMessage. // We always compact the variables and extensions to ensure that we produce easy to parse JSON for the engine func (o *OperationKit) UnmarshalOperationFromBody(data []byte) error { buf := bytes.NewBuffer(make([]byte, len(data))[:0]) err := json.Compact(buf, data) if err != nil { return err } err = json.Unmarshal(buf.Bytes(), &o.parsedOperation.Request) if err != nil { return err } return o.unmarshalOperation() } // unmarshalOperation unmarshal the extensions and variables from the request body into the ParsedOperation // and does some pre-processing on the operation to ensure that the engine can handle it func (o *OperationKit) unmarshalOperation() error { var err error if o.parsedOperation.Request.Extensions != nil { var mapExtensions map[string]any err = json.Unmarshal(o.parsedOperation.Request.Extensions, &mapExtensions) if err != nil { return &httpGraphqlError{ message: fmt.Sprintf("error parsing extensions: %s", err), statusCode: http.StatusBadRequest, } } err = json.Unmarshal(o.parsedOperation.Request.Extensions, &o.parsedOperation.GraphQLRequestExtensions) if err != nil { return &httpGraphqlError{ message: fmt.Sprintf("error parsing extensions: %s", err), statusCode: http.StatusBadRequest, } } if o.parsedOperation.GraphQLRequestExtensions.PersistedQuery != nil { // Delete persistedQuery from extensions to avoid it being passed to the subgraphs o.parsedOperation.Request.Extensions, err = sjson.DeleteBytes(o.parsedOperation.Request.Extensions, "persistedQuery") if err != nil { return &httpGraphqlError{ message: fmt.Sprintf("error deleting persistedQuery from extensions: %s", err), statusCode: http.StatusBadRequest, } } } } if o.parsedOperation.Request.Variables != nil { // variables must be a valid JSON object or null variables, err := fastjson.ParseBytes(o.parsedOperation.Request.Variables) if err != nil { return &httpGraphqlError{ message: fmt.Sprintf("error parsing variables: %s", err), statusCode: http.StatusBadRequest, } } switch 
variables.Type() { case fastjson.TypeNull: // set variables to empty object if they are null, so we can later add exported defaults // also, other parts of the engine depend on variables being a valid JSON object o.parsedOperation.Request.Variables = []byte("{}") o.parsedOperation.Variables = fastjson.MustParseBytes(o.parsedOperation.Request.Variables).GetObject() case fastjson.TypeObject: o.parsedOperation.Variables = variables.GetObject() default: return &httpGraphqlError{ message: "variables must be an object", statusCode: http.StatusBadRequest, } } } else { // set variables to empty object if they are null, so we can later add exported defaults // also, other parts of the engine depend on variables being a valid JSON object o.parsedOperation.Request.Variables = []byte("{}") o.parsedOperation.Variables = fastjson.MustParseBytes(o.parsedOperation.Request.Variables).GetObject() } // we're doing string matching on the operation name, so we override null with empty string if o.jsonIsNull(unsafebytes.StringToBytes(o.parsedOperation.Request.OperationName)) { o.parsedOperation.Request.OperationName = "" } if o.parsedOperation.GraphQLRequestExtensions.PersistedQuery != nil && len(o.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash) > 0 { o.parsedOperation.IsPersistedOperation = true } return nil } func (o *OperationKit) ComputeOperationSha256() error { // Calculate a fast hash of the operation query to save the // expensive compute on the same request. We can't use the operation id at this point // because the id is generated after normalization. 
We want to have the hash as soon as possible for // observability reasons _, _ = o.kit.keyGen.WriteString(o.parsedOperation.Request.Query) id := o.kit.keyGen.Sum64() o.kit.keyGen.Reset() if v, ok := o.cache.operationHashCache.Get(id); ok { o.parsedOperation.Sha256Hash = v return nil } _, err := o.kit.sha256Hash.Write(unsafebytes.StringToBytes(o.parsedOperation.Request.Query)) defer o.kit.sha256Hash.Reset() if err != nil { return err } // we're using the hex representation of the sha256 hash sha256Hash := fmt.Sprintf("%x", o.kit.sha256Hash.Sum(nil)) o.cache.operationHashCache.Set(id, sha256Hash, 1) o.parsedOperation.Sha256Hash = sha256Hash return nil } // FetchPersistedOperation fetches the persisted operation from the cache or the client. If the operation is fetched from the cache it returns true. // UnmarshalOperationFromBody or UnmarshalOperationFromURL must be called before calling this method. func (o *OperationKit) FetchPersistedOperation(ctx context.Context, clientInfo *ClientInfo) (bool, bool, error) { if o.operationProcessor.persistedOperationClient == nil { return false, false, &httpGraphqlError{ message: "could not resolve persisted query, feature is not configured", statusCode: http.StatusOK, } } fromCache, includeOperationName, err := o.loadPersistedOperationFromCache(clientInfo.Name) if err != nil { return false, false, &httpGraphqlError{ statusCode: http.StatusInternalServerError, message: "error loading persisted operation from cache", } } if fromCache { if isApq, _ := o.persistedOperationCacheKeyHasTtl(clientInfo.Name, includeOperationName); isApq { // if it is an APQ request, we need to save it again to renew the TTL expiration if err = o.operationProcessor.persistedOperationClient.SaveOperation(ctx, clientInfo.Name, o.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash, o.parsedOperation.NormalizedRepresentation); err != nil { return false, false, err } } return true, false, nil } persistedOperationData, isApq, err := 
o.operationProcessor.persistedOperationClient.PersistedOperation(ctx, clientInfo.Name, o.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash) if err != nil { return false, isApq, err } else if isApq && persistedOperationData == nil && o.parsedOperation.Request.Query == "" { // If the client has APQ enabled, throw an error if the operation wasn't attached to the request return false, isApq, &persistedoperation.PersistentOperationNotFoundError{ ClientName: clientInfo.Name, Sha256Hash: o.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash, } } // it's important to make a copy of the persisted operation data, because it's used in the cache // we might modify it later, so we don't want to modify the cached data if persistedOperationData != nil { o.parsedOperation.Request.Query = string(persistedOperationData) } // If the operation was fetched with APQ, save it again to renew the TTL if isApq { if err = o.operationProcessor.persistedOperationClient.SaveOperation(ctx, clientInfo.Name, o.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash, o.parsedOperation.Request.Query); err != nil { return false, isApq, err } } return false, isApq, nil } const ( schemaIntrospectionFieldName = "__schema" typeIntrospectionFieldName = "__type" ) func (o *OperationKit) isIntrospectionQuery() (result bool, err error) { var operationDefinitionRef = ast.InvalidRef var possibleOperationDefinitionRefs = make([]int, 0) for i := 0; i < len(o.kit.doc.RootNodes); i++ { if o.kit.doc.RootNodes[i].Kind == ast.NodeKindOperationDefinition { possibleOperationDefinitionRefs = append(possibleOperationDefinitionRefs, o.kit.doc.RootNodes[i].Ref) } } if len(possibleOperationDefinitionRefs) == 0 { return } else if len(possibleOperationDefinitionRefs) == 1 { operationDefinitionRef = possibleOperationDefinitionRefs[0] } else { for i := 0; i < len(possibleOperationDefinitionRefs); i++ { ref := possibleOperationDefinitionRefs[i] name := 
o.kit.doc.OperationDefinitionNameString(ref) if o.parsedOperation.Request.OperationName == name { operationDefinitionRef = ref break } } } if operationDefinitionRef == ast.InvalidRef { return } operationDef := o.kit.doc.OperationDefinitions[operationDefinitionRef] if operationDef.OperationType != ast.OperationTypeQuery { return } if !operationDef.HasSelections { return } selectionSet := o.kit.doc.SelectionSets[operationDef.SelectionSet] if len(selectionSet.SelectionRefs) == 0 { return } for i := 0; i < len(selectionSet.SelectionRefs); i++ { selection := o.kit.doc.Selections[selectionSet.SelectionRefs[i]] if selection.Kind != ast.SelectionKindField { continue } fieldName := o.kit.doc.FieldNameUnsafeString(selection.Ref) switch fieldName { case schemaIntrospectionFieldName, typeIntrospectionFieldName: return true, nil } } return false, nil } // Parse parses the operation, populate the document and set the operation type. // UnmarshalOperationFromBody must be called before calling this method. func (o *OperationKit) Parse() error { var ( anonymousOperationCount = 0 anonymousOperationDefinitionRef = -1 ) if len(o.parsedOperation.Request.Query) == 0 { return &httpGraphqlError{ message: "error parsing request body", statusCode: http.StatusBadRequest, } } report := &operationreport.Report{} o.kit.doc.Input.ResetInputString(o.parsedOperation.Request.Query) o.kit.parser.Parse(o.kit.doc, report) if report.HasErrors() { return &reportError{ report: report, } } if !o.introspectionEnabled { isIntrospection, err := o.isIntrospectionQuery() if err != nil { return &httpGraphqlError{ message: "could not determine if operation was an introspection query", statusCode: http.StatusOK, } } if isIntrospection { return &httpGraphqlError{ message: "GraphQL introspection is disabled by Cosmo Router, but the query contained __schema or __type. 
To enable introspection, set introspection_enabled: true in the Router configuration", statusCode: http.StatusOK, } } } o.kit.numOperations = 0 for i := range o.kit.doc.RootNodes { if o.kit.doc.RootNodes[i].Kind != ast.NodeKindOperationDefinition { continue } o.kit.numOperations++ ref := o.kit.doc.RootNodes[i].Ref name := string(o.kit.doc.OperationDefinitionNameBytes(ref)) if len(name) == 0 { anonymousOperationCount++ if anonymousOperationDefinitionRef == -1 { anonymousOperationDefinitionRef = ref } continue } if o.parsedOperation.Request.OperationName == "" { o.operationDefinitionRef = ref o.originalOperationNameRef = o.kit.doc.OperationDefinitions[ref].Name o.parsedOperation.Request.OperationName = name continue } if name == o.parsedOperation.Request.OperationName && o.operationDefinitionRef == -1 { o.operationDefinitionRef = ref o.originalOperationNameRef = o.kit.doc.OperationDefinitions[ref].Name } } if o.parsedOperation.Request.OperationName == "" && o.kit.numOperations > 1 { return &httpGraphqlError{ message: "operation name is required when multiple operations are defined", statusCode: http.StatusOK, } } if o.parsedOperation.Request.OperationName != "" && o.kit.numOperations != 0 && o.operationDefinitionRef == -1 { return &httpGraphqlError{ message: fmt.Sprintf("operation with name '%s' not found", o.parsedOperation.Request.OperationName), statusCode: http.StatusOK, } } if o.operationDefinitionRef == -1 { if anonymousOperationCount == 1 { o.operationDefinitionRef = anonymousOperationDefinitionRef } else if anonymousOperationCount > 1 { return &httpGraphqlError{ message: "operation name is required when multiple operations are defined", statusCode: http.StatusOK, } } else { return &httpGraphqlError{ message: fmt.Sprintf("operation with name '%s' not found", o.parsedOperation.Request.OperationName), statusCode: http.StatusOK, } } } switch o.kit.doc.OperationDefinitions[o.operationDefinitionRef].OperationType { case ast.OperationTypeQuery: 
o.parsedOperation.Type = "query" case ast.OperationTypeMutation: o.parsedOperation.Type = "mutation" case ast.OperationTypeSubscription: o.parsedOperation.Type = "subscription" default: return &httpGraphqlError{ message: "operation type not supported", statusCode: http.StatusOK, } } // Replace the operation name with a static name to avoid different IDs for the same operation replaceOperationName := o.kit.doc.Input.AppendInputBytes(staticOperationName) o.kit.doc.OperationDefinitions[o.operationDefinitionRef].Name = replaceOperationName return nil } // NormalizeOperation normalizes the operation. After normalization the normalized representation of the operation // and variables is available. Also, the final operation ID is generated. func (o *OperationKit) NormalizeOperation(clientName string, isApq bool) (bool, error) { if o.parsedOperation.IsPersistedOperation { return o.normalizePersistedOperation(clientName, isApq) } return o.normalizeNonPersistedOperation() } func (o *OperationKit) normalizePersistedOperation(clientName string, isApq bool) (cached bool, err error) { if o.parsedOperation.NormalizedRepresentation != "" { // when dealing with APQ requests which have a TTL set, we need to renew the TTL if shouldRenew, skipIncludeNames := o.persistedOperationCacheKeyHasTtl(clientName, o.kit.numOperations > 1); shouldRenew { o.savePersistedOperationToCache(clientName, true, skipIncludeNames) } // normalized operation was loaded from cache return true, nil } skipIncludeNames := o.skipIncludeVariableNames() report := &operationreport.Report{} o.kit.doc.Input.Variables = o.parsedOperation.Request.Variables o.kit.staticNormalizer.NormalizeNamedOperation(o.kit.doc, o.operationProcessor.executor.ClientSchema, staticOperationName, report) if report.HasErrors() { return false, &reportError{ report: report, } } // Print the operation with the original operation name o.kit.doc.OperationDefinitions[o.operationDefinitionRef].Name = o.originalOperationNameRef err = 
// NormalizationCacheEntry is the value stored in the normalization caches for
// both persisted and non-persisted operations.
type NormalizationCacheEntry struct {
	// operationID is copied from ParsedOperation.InternalID when the entry is written.
	operationID              uint64
	normalizedRepresentation string
	operationType            string
	operationDefinitionRef   int
}

// ComplexityCacheEntry holds depth and field-count figures for an operation.
// NOTE(review): the code that fills and reads it is not visible in this part of the file.
type ComplexityCacheEntry struct {
	Depth            int
	TotalFields      int
	RootFields       int
	RootFieldAliases int
}

// normalizeNonPersistedOperation normalizes an operation that was sent as a
// full query. It returns cached=true when the normalized representation was
// served from the normalization cache; in that case the kit's document is
// re-parsed from the cached representation. On a miss the document is
// normalized, printed with its original operation name, and the result is
// stored in the cache if one is configured.
func (o *OperationKit) normalizeNonPersistedOperation() (cached bool, err error) {
	// The cache key covers the raw query, the operation name and the boolean
	// values of the skip/include variables.
	skipIncludeVariableNames := o.skipIncludeVariableNames()
	cacheKey := o.normalizationCacheKey(skipIncludeVariableNames)
	if o.cache != nil && o.cache.normalizationCache != nil {
		entry, ok := o.cache.normalizationCache.Get(cacheKey)
		if ok {
			o.parsedOperation.NormalizedRepresentation = entry.normalizedRepresentation
			o.parsedOperation.Type = entry.operationType
			o.parsedOperation.NormalizationCacheHit = true
			// replace the kit document with the cached normalized operation
			err = o.setAndParseOperationDoc()
			if err != nil {
				return false, err
			}
			return true, nil
		}
	}
	// normalize the operation
	report := &operationreport.Report{}
	o.kit.doc.Input.Variables = o.parsedOperation.Request.Variables
	o.kit.staticNormalizer.NormalizeNamedOperation(o.kit.doc, o.operationProcessor.executor.ClientSchema, staticOperationName, report)
	if report.HasErrors() {
		return false, &reportError{
			report: report,
		}
	}
	// reset with the original variables
	o.parsedOperation.Request.Variables = o.kit.doc.Input.Variables
	// Hash the normalized operation with the static operation name & original variables to avoid different IDs for the same operation
	// NOTE(review): the digest written into keyGen here is neither read (Sum64)
	// nor reset in this function. NormalizeVariables and RemapVariables reset
	// keyGen before writing, but confirm nothing in between reads it.
	err = o.kit.printer.Print(o.kit.doc, o.kit.keyGen)
	if err != nil {
		return false, errors.WithStack(fmt.Errorf("normalizeNonPersistedOperation (uncached) failed generating operation hash: %w", err))
	}
	// Print the operation with the original operation name
	o.kit.doc.OperationDefinitions[o.operationDefinitionRef].Name = o.originalOperationNameRef
	err = o.kit.printer.Print(o.kit.doc, o.kit.normalizedOperation)
	if err != nil {
		return false, errors.WithStack(fmt.Errorf("normalizeNonPersistedOperation (uncached) failed printing operation: %w", err))
	}
	// Set the normalized representation
	o.parsedOperation.NormalizedRepresentation = o.kit.normalizedOperation.String()
	if o.cache != nil && o.cache.normalizationCache != nil {
		// NOTE(review): in the visible code InternalID is only assigned by
		// RemapVariables and handleFoundPersistedOperationEntry, so operationID
		// is likely still zero here — confirm whether that is intended.
		entry := NormalizationCacheEntry{
			operationID:              o.parsedOperation.InternalID,
			normalizedRepresentation: o.parsedOperation.NormalizedRepresentation,
			operationType:            o.parsedOperation.Type,
		}
		o.cache.normalizationCache.Set(cacheKey, entry, 1)
	}
	return false, nil
}

// setAndParseOperationDoc resets the kit's document and parses the current
// NormalizedRepresentation into it, using the request variables as the
// document's variables. Parser errors are returned as a reportError.
func (o *OperationKit) setAndParseOperationDoc() error {
	o.kit.doc.Reset()
	o.kit.doc.Input.ResetInputString(o.parsedOperation.NormalizedRepresentation)
	o.kit.doc.Input.Variables = o.parsedOperation.Request.Variables
	report := &operationreport.Report{}
	o.kit.parser.Parse(o.kit.doc, report)
	if report.HasErrors() {
		return &reportError{
			report: report,
		}
	}
	return nil
}
without the operation name to get the pure normalized form // Afterward we can calculate the operation ID that is used as a stable identifier for analytics o.kit.normalizedOperation.Reset() // store the original name of the operation nameRef := o.kit.doc.OperationDefinitions[o.operationDefinitionRef].Name staticNameRef := o.kit.doc.Input.AppendInputBytes([]byte("")) o.kit.doc.OperationDefinitions[o.operationDefinitionRef].Name = staticNameRef err := o.kit.printer.Print(o.kit.doc, o.kit.normalizedOperation) if err != nil { return nil, err } // Reset the doc with the original name o.kit.doc.OperationDefinitions[o.operationDefinitionRef].Name = nameRef o.kit.keyGen.Reset() _, err = o.kit.keyGen.Write(o.kit.normalizedOperation.Bytes()) if err != nil { return nil, err } o.parsedOperation.ID = o.kit.keyGen.Sum64() // If the normalized form of the operation didn't change, we don't need to print it again after := len(o.kit.doc.Input.Variables) + len(o.kit.doc.Input.RawBytes) if after == before { return uploadsMapping, nil } o.kit.normalizedOperation.Reset() err = o.kit.printer.Print(o.kit.doc, o.kit.normalizedOperation) if err != nil { return nil, err } o.parsedOperation.NormalizedRepresentation = o.kit.normalizedOperation.String() o.parsedOperation.Request.Variables = o.kit.doc.Input.Variables return uploadsMapping, nil } func (o *OperationKit) RemapVariables(disabled bool) error { report := &operationreport.Report{} // even if the variables are disabled, we still need to execute rest of the method, // as it generates InternalID for the operation, which is used as planner cache key if !disabled { variablesMap := o.kit.variablesRemapper.NormalizeOperation(o.kit.doc, o.operationProcessor.executor.ClientSchema, report) if report.HasErrors() { return &reportError{ report: report, } } o.parsedOperation.RemapVariables = variablesMap } // Print the operation without the operation name to get the pure normalized form // Afterward we can calculate the operation ID that is used as 
a stable identifier for analytics o.kit.normalizedOperation.Reset() // store the original name of the operation nameRef := o.kit.doc.OperationDefinitions[o.operationDefinitionRef].Name staticNameRef := o.kit.doc.Input.AppendInputBytes([]byte("")) o.kit.doc.OperationDefinitions[o.operationDefinitionRef].Name = staticNameRef err := o.kit.printer.Print(o.kit.doc, o.kit.normalizedOperation) if err != nil { return errors.WithStack(fmt.Errorf("RemapVariables failed generating operation hash: %w", err)) } // Reset the doc with the original name o.kit.doc.OperationDefinitions[o.operationDefinitionRef].Name = nameRef o.kit.keyGen.Reset() _, err = o.kit.keyGen.Write(o.kit.normalizedOperation.Bytes()) if err != nil { return err } // Generate the operation ID o.parsedOperation.InternalID = o.kit.keyGen.Sum64() o.kit.keyGen.Reset() o.kit.normalizedOperation.Reset() err = o.kit.printer.Print(o.kit.doc, o.kit.normalizedOperation) if err != nil { return err } o.parsedOperation.NormalizedRepresentation = o.kit.normalizedOperation.String() return nil } func (o *OperationKit) loadPersistedOperationFromCache(clientName string) (ok bool, includeOpName bool, err error) { if o.cache == nil || o.cache.persistedOperationNormalizationCache == nil { return false, false, nil } cacheKey, ok := o.loadPersistedOperationCacheKey(clientName, o.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash, false) if !ok { _, _ = o.cache.persistedOperationNormalizationCache.Get(0) // register cache miss return false, false, nil } entry, ok := o.cache.persistedOperationNormalizationCache.Get(cacheKey) if ok { return true, false, o.handleFoundPersistedOperationEntry(entry) } if o.parsedOperation.Request.OperationName == "" { return false, false, nil } if namedCacheKey, namedOk := o.loadPersistedOperationCacheKey(clientName, o.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash, true); namedOk { if namedEntry, ok := o.cache.persistedOperationNormalizationCache.Get(namedCacheKey); ok 
{ return true, true, o.handleFoundPersistedOperationEntry(namedEntry) } } return false, false, nil } func (o *OperationKit) handleFoundPersistedOperationEntry(entry NormalizationCacheEntry) error { o.parsedOperation.PersistedOperationCacheHit = true o.parsedOperation.NormalizationCacheHit = true o.parsedOperation.InternalID = entry.operationID o.parsedOperation.NormalizedRepresentation = entry.normalizedRepresentation o.parsedOperation.Type = entry.operationType // We will always only have a single operation definition in the document // Because we removed the unused operations during normalization o.operationDefinitionRef = 0 err := o.setAndParseOperationDoc() if err != nil { return err } return nil } func (o *OperationKit) jsonIsNull(variables []byte) bool { if variables == nil { return true } if len(variables) == 4 && unsafebytes.BytesToString(variables) == "null" { return true } value, err := fastjson.ParseBytes(variables) if err != nil { return false } return value.Type() == fastjson.TypeNull } func (o *OperationKit) persistedOperationCacheKeyHasTtl(clientName string, includeOperationName bool) (bool, []string) { if o.cache == nil || o.cache.persistedOperationVariableNames == nil || o.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash == "" { return false, nil } o.cache.persistedOperationVariableNamesLock.RLock() variableNames, present := o.cache.persistedOperationVariableNames[o.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash] o.cache.persistedOperationVariableNamesLock.RUnlock() if !present { return false, variableNames } cacheKey := o.generatePersistedOperationCacheKey(clientName, variableNames, includeOperationName) ttl, ok := o.cache.persistedOperationNormalizationCache.GetTTL(cacheKey) return ok && ttl > 0, variableNames } func (o *OperationKit) savePersistedOperationToCache(clientName string, isApq bool, skipIncludeVariableNames []string) { cacheKey := o.generatePersistedOperationCacheKey(clientName, 
skipIncludeVariableNames, o.kit.numOperations > 1) entry := NormalizationCacheEntry{ operationID: o.parsedOperation.InternalID, normalizedRepresentation: o.parsedOperation.NormalizedRepresentation, operationType: o.parsedOperation.Type, operationDefinitionRef: o.operationDefinitionRef, } if isApq { ttl := o.cache.automaticPersistedOperationCacheTtl ttlD := time.Duration(ttl) * time.Second o.cache.persistedOperationNormalizationCache.SetWithTTL(cacheKey, entry, 1, ttlD) } else { o.cache.persistedOperationNormalizationCache.Set(cacheKey, entry, 1) } o.cache.persistedOperationVariableNamesLock.Lock() o.cache.persistedOperationVariableNames[o.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash] = skipIncludeVariableNames o.cache.persistedOperationVariableNamesLock.Unlock() } func (o *OperationKit) loadPersistedOperationCacheKey(clientName, persistedQuerySha256Hash string, includeOperationName bool) (key uint64, ok bool) { o.cache.persistedOperationVariableNamesLock.RLock() variableNames := o.cache.persistedOperationVariableNames[persistedQuerySha256Hash] o.cache.persistedOperationVariableNamesLock.RUnlock() key = o.generatePersistedOperationCacheKey(clientName, variableNames, includeOperationName) return key, true } func (o *OperationKit) generatePersistedOperationCacheKey(clientName string, skipIncludeVariableNames []string, includeOperationName bool) uint64 { _, _ = o.kit.keyGen.WriteString(o.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash) if includeOperationName { // If there are multiple operations in the document, we need to include the operation name in the cache key _, _ = o.kit.keyGen.WriteString(o.parsedOperation.Request.OperationName) } _, _ = o.kit.keyGen.WriteString(clientName) o.writeSkipIncludeCacheKeyToKeyGen(skipIncludeVariableNames) sum := o.kit.keyGen.Sum64() o.kit.keyGen.Reset() return sum } func (o *OperationKit) normalizationCacheKey(skipIncludeVariableNames []string) uint64 { _, _ = 
o.kit.keyGen.WriteString(o.parsedOperation.Request.Query) _, _ = o.kit.keyGen.WriteString(o.parsedOperation.Request.OperationName) o.writeSkipIncludeCacheKeyToKeyGen(skipIncludeVariableNames) sum := o.kit.keyGen.Sum64() o.kit.keyGen.Reset() return sum } func (o *OperationKit) writeSkipIncludeCacheKeyToKeyGen(skipIncludeVariableNames []string) { for i := range skipIncludeVariableNames { value := o.parsedOperation.Variables.Get(skipIncludeVariableNames[i]) if value == nil { _, _ = o.kit.keyGen.WriteString("x") continue } switch value.Type() { case fastjson.TypeTrue: _, _ = o.kit.keyGen.WriteString("t") case fastjson.TypeFalse: _, _ = o.kit.keyGen.WriteString("f") default: _, _ = o.kit.keyGen.WriteString("x") } } } // Validate validates the operation variables. func (o *OperationKit) Validate(skipLoader bool, remapVariables map[string]string, apolloCompatibilityFlags *config.ApolloCompatibilityFlags) (cacheHit bool, err error) { if !skipLoader { // in case we're skipping the loader, it means that we won't execute the operation // this means that we don't need to validate the variables as they are not used // this is useful to return a query plan without having to provide variables err = o.kit.variablesValidator.ValidateWithRemap(o.kit.doc, o.operationProcessor.executor.ClientSchema, o.kit.doc.Input.Variables, remapVariables) if err != nil { var invalidVarErr *variablesvalidation.InvalidVariableError if errors.As(err, &invalidVarErr) { graphqlErr := &httpGraphqlError{ extensionCode: invalidVarErr.ExtensionCode, message: invalidVarErr.Error(), statusCode: http.StatusOK, } if apolloCompatibilityFlags != nil && apolloCompatibilityFlags.ReplaceValidationErrorStatus.Enabled { graphqlErr.statusCode = http.StatusBadRequest } return false, graphqlErr } return false, &httpGraphqlError{ message: err.Error(), statusCode: http.StatusOK, } } } if o.cache != nil && o.cache.validationCache != nil { var valid bool valid, cacheHit = 
o.cache.validationCache.Get(o.parsedOperation.InternalID) if valid { return } } report := &operationreport.Report{} o.kit.operationValidator.Validate(o.kit.doc, o.operationProcessor.executor.ClientSchema, report) if o.cache != nil && o.cache.validationCache != nil { valid := !report.HasErrors() o.cache.validationCache.Set(o.parsedOperation.InternalID, valid, 1) } if report.HasErrors() { return cacheHit, &reportError{ report: report, } } return } // ValidateQueryComplexity validates that the query complexity is within the limits set in the configuration func (o *OperationKit) ValidateQueryComplexity(complexityLimitConfig *config.ComplexityLimits, operation, definition *ast.Document, isPersisted bool) (bool, ComplexityCacheEntry, error) { if o.cache != nil && o.cache.complexityCache != nil { if cachedComplexity, ok := o.cache.complexityCache.Get(o.parsedOperation.InternalID); ok { return ok, cachedComplexity, o.runComplexityComparisons(complexityLimitConfig, cachedComplexity, isPersisted) } } report := operationreport.Report{} globalComplexityResult, rootFieldStats := operation_complexity.CalculateOperationComplexity(operation, definition, &report) cacheResult := ComplexityCacheEntry{ Depth: globalComplexityResult.Depth, TotalFields: globalComplexityResult.NodeCount, } for _, entry := range rootFieldStats { if entry.Alias == "" { cacheResult.RootFields += 1 } else { cacheResult.RootFieldAliases += 1 } } if o.cache != nil && o.cache.complexityCache != nil { o.cache.complexityCache.Set(o.parsedOperation.InternalID, cacheResult, 1) } return false, cacheResult, o.runComplexityComparisons(complexityLimitConfig, cacheResult, isPersisted) } func (o *OperationKit) runComplexityComparisons(complexityLimitConfig *config.ComplexityLimits, cachedComplexity ComplexityCacheEntry, isPersisted bool) error { testComparisons := []complexityComparison{} if complexityLimitConfig.Depth != nil && complexityLimitConfig.Depth.ApplyLimit(isPersisted) { testComparisons = 
append(testComparisons, complexityComparison{complexityLimitConfig.Depth.Limit, cachedComplexity.Depth, fmt.Sprintf("The query depth %d exceeds the max query depth allowed (%d)", cachedComplexity.Depth, complexityLimitConfig.Depth.Limit)}) } if complexityLimitConfig.TotalFields != nil && complexityLimitConfig.TotalFields.ApplyLimit(isPersisted) { testComparisons = append(testComparisons, complexityComparison{complexityLimitConfig.TotalFields.Limit, cachedComplexity.TotalFields, fmt.Sprintf("The total number of fields %d exceeds the limit allowed (%d)", cachedComplexity.TotalFields, complexityLimitConfig.TotalFields.Limit)}) } if complexityLimitConfig.RootFields != nil && complexityLimitConfig.RootFields.ApplyLimit(isPersisted) { testComparisons = append(testComparisons, complexityComparison{complexityLimitConfig.RootFields.Limit, cachedComplexity.RootFields, fmt.Sprintf("The number of root fields %d exceeds the root field limit allowed (%d)", cachedComplexity.RootFields, complexityLimitConfig.RootFields.Limit)}) } if complexityLimitConfig.RootFieldAliases != nil && complexityLimitConfig.RootFieldAliases.ApplyLimit(isPersisted) { testComparisons = append(testComparisons, complexityComparison{complexityLimitConfig.RootFieldAliases.Limit, cachedComplexity.RootFieldAliases, fmt.Sprintf("The number of root field aliases %d exceeds the root field aliases limit allowed (%d)", cachedComplexity.RootFieldAliases, complexityLimitConfig.RootFieldAliases.Limit)}) } for _, comparison := range testComparisons { valid := comparison.field <= 0 || comparison.cachedField <= comparison.field if !valid { return &httpGraphqlError{ message: comparison.errorMessage, statusCode: http.StatusBadRequest, } } } return nil } var ( literalIF = []byte("if") ) func (o *OperationKit) skipIncludeVariableNames() []string { if len(o.kit.doc.Directives) == 0 { return nil } variableNames := make(map[string]struct{}, len(o.kit.doc.Directives)) for i := range o.kit.doc.Directives { name := 
o.kit.doc.DirectiveNameBytes(i) switch string(name) { case "skip", "include": if value, ok := o.kit.doc.DirectiveArgumentValueByName(i, literalIF); ok { if value.Kind != ast.ValueKindVariable { continue } variableName := o.kit.doc.VariableValueNameString(value.Ref) variableNames[variableName] = struct{}{} } } } names := make([]string, len(variableNames)) i := 0 for name := range variableNames { names[i] = name i++ } slices.Sort(names) return names } type parseKitOptions struct { apolloCompatibilityFlags config.ApolloCompatibilityFlags apolloRouterCompatibilityFlags config.ApolloRouterCompatibilityFlags } func createParseKit(i int, options *parseKitOptions) *parseKit { return &parseKit{ i: i, parser: astparser.NewParser(), doc: ast.NewSmallDocument(), keyGen: xxhash.New(), sha256Hash: sha256.New(), staticNormalizer: astnormalization.NewWithOpts( astnormalization.WithRemoveNotMatchingOperationDefinitions(), astnormalization.WithInlineFragmentSpreads(), astnormalization.WithRemoveFragmentDefinitions(), astnormalization.WithRemoveUnusedVariables(), ), variablesNormalizer: astnormalization.NewVariablesNormalizer(), variablesRemapper: astnormalization.NewVariablesMapper(), printer: &astprinter.Printer{}, normalizedOperation: &bytes.Buffer{}, variablesValidator: variablesvalidation.NewVariablesValidator(variablesvalidation.VariablesValidatorOptions{ ApolloCompatibilityFlags: apollocompatibility.Flags{ ReplaceInvalidVarError: options.apolloCompatibilityFlags.ReplaceInvalidVarErrors.Enabled, }, ApolloRouterCompatibilityFlags: apollocompatibility.ApolloRouterFlags{ ReplaceInvalidVarError: options.apolloRouterCompatibilityFlags.ReplaceInvalidVarErrors.Enabled, }, }), operationValidator: astvalidation.DefaultOperationValidator(astvalidation.WithApolloCompatibilityFlags( apollocompatibility.Flags{ ReplaceUndefinedOpFieldError: options.apolloCompatibilityFlags.ReplaceUndefinedOpFieldErrors.Enabled, }, )), } } func NewOperationProcessor(opts OperationProcessorOptions) 
*OperationProcessor { if opts.ParseKitPoolSize <= 0 { opts.ParseKitPoolSize = 1 } processor := &OperationProcessor{ executor: opts.Executor, maxOperationSizeInBytes: opts.MaxOperationSizeInBytes, persistedOperationClient: opts.PersistedOperationClient, parseKits: make(map[int]*parseKit, opts.ParseKitPoolSize), parseKitSemaphore: make(chan int, opts.ParseKitPoolSize), introspectionEnabled: opts.IntrospectionEnabled, parseKitOptions: &parseKitOptions{ apolloCompatibilityFlags: opts.ApolloCompatibilityFlags, apolloRouterCompatibilityFlags: opts.ApolloRouterCompatibilityFlags, }, } for i := 0; i < opts.ParseKitPoolSize; i++ { processor.parseKitSemaphore <- i processor.parseKits[i] = createParseKit(i, processor.parseKitOptions) } if opts.NormalizationCache != nil || opts.ValidationCache != nil || opts.QueryDepthCache != nil || opts.OperationHashCache != nil || opts.EnablePersistedOperationsCache { processor.operationCache = &OperationCache{ normalizationCache: opts.NormalizationCache, validationCache: opts.ValidationCache, complexityCache: opts.QueryDepthCache, operationHashCache: opts.OperationHashCache, } } if opts.EnablePersistedOperationsCache { processor.operationCache.automaticPersistedOperationCacheTtl = float64(opts.AutomaticPersistedOperationCacheTtl) processor.operationCache.persistedOperationVariableNames = map[string][]string{} processor.operationCache.persistedOperationVariableNamesLock = &sync.RWMutex{} processor.operationCache.persistedOperationNormalizationCache = opts.PersistedOpsNormalizationCache } return processor } func (p *OperationProcessor) getKit() *parseKit { i := <-p.parseKitSemaphore return p.parseKits[i] } func (p *OperationProcessor) freeKit(kit *parseKit) { kit.keyGen.Reset() kit.doc.Reset() kit.sha256Hash.Reset() kit.normalizedOperation.Reset() // because we're re-using the kit, and we're having a static number of kits based on the number of CPUs // we're resetting the doc, parser, and buffer for the normalized operation if they grow too 
large (>1MB of query size) if cap(kit.doc.Input.RawBytes) > 1024*1024 { kit.doc = ast.NewSmallDocument() kit.parser = astparser.NewParser() kit.normalizedOperation = &bytes.Buffer{} } p.parseKitSemaphore <- kit.i } func (p *OperationProcessor) ReadBody(reader io.Reader, buf *bytes.Buffer) ([]byte, error) { if _, err := io.Copy(buf, reader); err != nil { // Set when http.MaxBytesReader is used before var maxBytesErr *http.MaxBytesError if errors.As(err, &maxBytesErr) { return nil, &httpGraphqlError{ message: fmt.Sprintf("request body too large, max size is %d bytes", p.maxOperationSizeInBytes), statusCode: http.StatusRequestEntityTooLarge, } } return nil, fmt.Errorf("failed to read request body: %w", err) } return buf.Bytes(), nil } // NewKit creates a new OperationKit. The kit is used to parse, normalize and // validate operations. It also validates if the operation size is within the // limit. func (p *OperationProcessor) NewKit() (*OperationKit, error) { return NewOperationKit(p), nil } // NewIndependentKit creates a new OperationKit which will not be pooled. // This is useful, e.g. for warming up the caches func (p *OperationProcessor) NewIndependentKit() (*OperationKit, error) { return NewIndependentOperationKit(p), nil }