graphql/executor/executor.go (192 lines of code) (raw):

package executor

import (
	"context"

	"github.com/vektah/gqlparser/v2/ast"
	"github.com/vektah/gqlparser/v2/gqlerror"
	"github.com/vektah/gqlparser/v2/parser"
	"github.com/vektah/gqlparser/v2/validator"
	"github.com/vektah/gqlparser/v2/validator/rules"

	"github.com/99designs/gqlgen/graphql"
	"github.com/99designs/gqlgen/graphql/errcode"
)

// parserTokenNoLimit is the default parser token limit installed by New.
const parserTokenNoLimit = 0

// Executor executes graphql queries against a schema.
type Executor struct {
	es         graphql.ExecutableSchema
	extensions []graphql.HandlerExtension
	ext        extensions

	errorPresenter graphql.ErrorPresenterFunc
	recoverFunc    graphql.RecoverFunc
	queryCache     graphql.Cache[*ast.QueryDocument]

	parserTokenLimit  int
	disableSuggestion bool
}

var _ graphql.GraphExecutor = &Executor{}

// New creates a new Executor with the given schema, the default error
// presenter and recover callbacks, a no-op query cache, and no extensions.
func New(es graphql.ExecutableSchema) *Executor {
	return &Executor{
		es:               es,
		errorPresenter:   graphql.DefaultErrorPresenter,
		recoverFunc:      graphql.DefaultRecover,
		queryCache:       graphql.NoCache[*ast.QueryDocument]{},
		ext:              processExtensions(nil),
		parserTokenLimit: parserTokenNoLimit,
	}
}

// CreateOperationContext builds the OperationContext for a single request.
//
// It runs the operation parameter mutators, parses and validates the query,
// selects the operation named by params.OperationName, coerces the request
// variables, and finally runs the operation context mutators. The operation
// context built so far is returned together with the error list whenever any
// of these steps fails; on success the error list is nil.
func (e *Executor) CreateOperationContext(
	ctx context.Context,
	params *graphql.RawParams,
) (*graphql.OperationContext, gqlerror.List) {
	opCtx := &graphql.OperationContext{
		DisableIntrospection:   true,
		RecoverFunc:            e.recoverFunc,
		ResolverMiddleware:     e.ext.fieldMiddleware,
		RootResolverMiddleware: e.ext.rootFieldMiddleware,
		Stats: graphql.Stats{
			Read:           params.ReadTime,
			OperationStart: graphql.GetStartTime(ctx),
		},
	}
	ctx = graphql.WithOperationContext(ctx, opCtx)

	// Parameter mutators run first so that the values copied onto opCtx below
	// are the mutated ones.
	for _, p := range e.ext.operationParameterMutators {
		if err := p.MutateOperationParameters(ctx, params); err != nil {
			return opCtx, gqlerror.List{err}
		}
	}

	opCtx.RawQuery = params.Query
	opCtx.OperationName = params.OperationName
	opCtx.Headers = params.Headers

	var listErr gqlerror.List
	opCtx.Doc, listErr = e.parseQuery(ctx, &opCtx.Stats, params.Query)
	if len(listErr) != 0 {
		return opCtx, listErr
	}

	opCtx.Operation = opCtx.Doc.Operations.ForName(params.OperationName)
	if opCtx.Operation == nil {
		err := gqlerror.Errorf("operation %s not found", params.OperationName)
		errcode.Set(err, errcode.ValidationFailed)
		return opCtx, gqlerror.List{err}
	}

	var err error
	opCtx.Variables, err = validator.VariableValues(e.es.Schema(), opCtx.Operation, params.Variables)
	if err != nil {
		// Every failure must abort the request. An error that is not already a
		// *gqlerror.Error is wrapped rather than dropped, otherwise execution
		// would continue with variables that failed validation.
		gqlErr, ok := err.(*gqlerror.Error)
		if !ok {
			gqlErr = gqlerror.Wrap(err)
		}
		errcode.Set(gqlErr, errcode.ValidationFailed)
		return opCtx, gqlerror.List{gqlErr}
	}
	opCtx.Stats.Validation.End = graphql.Now()

	for _, p := range e.ext.operationContextMutators {
		if err := p.MutateOperationContext(ctx, opCtx); err != nil {
			return opCtx, gqlerror.List{err}
		}
	}

	return opCtx, nil
}

// DispatchOperation executes the operation described by opCtx through the
// operation middleware chain.
//
// It returns the resulting response handler and the context that the
// middleware chain passed to the innermost handler. That context is nil if
// the middleware never invoked the handler.
func (e *Executor) DispatchOperation(
	ctx context.Context,
	opCtx *graphql.OperationContext,
) (graphql.ResponseHandler, context.Context) {
	ctx = graphql.WithOperationContext(ctx, opCtx)

	var innerCtx context.Context
	res := e.ext.operationMiddleware(ctx, func(ctx context.Context) graphql.ResponseHandler {
		innerCtx = ctx

		// Errors recorded while Exec sets up the operation short-circuit into
		// a single error-only response.
		tmpResponseContext := graphql.WithResponseContext(ctx, e.errorPresenter, e.recoverFunc)
		responses := e.es.Exec(tmpResponseContext)
		if errs := graphql.GetErrors(tmpResponseContext); errs != nil {
			return graphql.OneShot(&graphql.Response{Errors: errs})
		}

		return func(ctx context.Context) *graphql.Response {
			// Each response gets a fresh response context for its errors and
			// extensions.
			ctx = graphql.WithResponseContext(ctx, e.errorPresenter, e.recoverFunc)
			return e.ext.responseMiddleware(ctx, func(ctx context.Context) *graphql.Response {
				resp := responses(ctx)
				if resp == nil {
					return nil
				}
				resp.Errors = append(resp.Errors, graphql.GetErrors(ctx)...)
				resp.Extensions = graphql.GetExtensions(ctx)
				return resp
			})
		}
	})

	return res, innerCtx
}

// DispatchError builds a response carrying the given errors. The errors are
// added to a fresh response context and the response is produced through the
// response middleware chain.
func (e *Executor) DispatchError(ctx context.Context, list gqlerror.List) *graphql.Response {
	ctx = graphql.WithResponseContext(ctx, e.errorPresenter, e.recoverFunc)
	for _, gErr := range list {
		graphql.AddError(ctx, gErr)
	}

	return e.ext.responseMiddleware(ctx, func(ctx context.Context) *graphql.Response {
		return &graphql.Response{
			Errors:     graphql.GetErrors(ctx),
			Extensions: graphql.GetExtensions(ctx),
		}
	})
}

// PresentRecoveredError passes a recovered panic value through the configured
// recover function and then through the error presenter.
func (e *Executor) PresentRecoveredError(ctx context.Context, err any) error {
	return e.errorPresenter(ctx, e.recoverFunc(ctx, err))
}

// SetQueryCache replaces the cache used for parsed and validated query documents.
func (e *Executor) SetQueryCache(cache graphql.Cache[*ast.QueryDocument]) {
	e.queryCache = cache
}

// SetErrorPresenter replaces the function used to present errors in responses.
func (e *Executor) SetErrorPresenter(f graphql.ErrorPresenterFunc) {
	e.errorPresenter = f
}

// SetRecoverFunc replaces the function used to convert recovered panics into errors.
func (e *Executor) SetRecoverFunc(f graphql.RecoverFunc) {
	e.recoverFunc = f
}

// SetParserTokenLimit sets the token limit passed to the query parser.
func (e *Executor) SetParserTokenLimit(limit int) {
	e.parserTokenLimit = limit
}

// SetDisableSuggestion controls whether the FieldsOnCorrectType validation
// rule is swapped for its variant that does not provide suggestions.
func (e *Executor) SetDisableSuggestion(value bool) {
	e.disableSuggestion = value
}

// parseQuery decodes the incoming query and validates it, pulling from cache if present.
//
// NOTE: This should NOT look at variables, they will change per request. It should only parse and
// validate the raw query string.
func (e *Executor) parseQuery( ctx context.Context, stats *graphql.Stats, query string, ) (*ast.QueryDocument, gqlerror.List) { stats.Parsing.Start = graphql.Now() if doc, ok := e.queryCache.Get(ctx, query); ok { now := graphql.Now() stats.Parsing.End = now stats.Validation.Start = now return doc, nil } doc, err := parser.ParseQueryWithTokenLimit(&ast.Source{Input: query}, e.parserTokenLimit) if err != nil { gqlErr, ok := err.(*gqlerror.Error) if ok { errcode.Set(gqlErr, errcode.ParseFailed) return nil, gqlerror.List{gqlErr} } } stats.Parsing.End = graphql.Now() stats.Validation.Start = graphql.Now() if len(doc.Operations) == 0 { err = gqlerror.Errorf("no operation provided") gqlErr, _ := err.(*gqlerror.Error) errcode.Set(err, errcode.ValidationFailed) return nil, gqlerror.List{gqlErr} } // swap out the FieldsOnCorrectType rule with one that doesn't provide suggestions if e.disableSuggestion { validator.RemoveRule("FieldsOnCorrectType") rule := rules.FieldsOnCorrectTypeRuleWithoutSuggestions validator.AddRule(rule.Name, rule.RuleFunc) } listErr := validator.Validate(e.es.Schema(), doc) if len(listErr) != 0 { for _, e := range listErr { errcode.Set(e, errcode.ValidationFailed) } return nil, listErr } e.queryCache.Add(ctx, query, doc) return doc, nil }