router/core/cache_warmup.go (304 lines of code) (raw):

package core import ( "context" "errors" "time" "go.uber.org/ratelimit" "go.uber.org/zap" "google.golang.org/protobuf/encoding/protojson" "github.com/wundergraph/astjson" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" "github.com/wundergraph/cosmo/router/pkg/config" ) type CacheWarmupItem struct { Request GraphQLRequest Client *ClientInfo } type CacheWarmupSource interface { LoadItems(ctx context.Context, log *zap.Logger) ([]*nodev1.Operation, error) } type CacheWarmupProcessor interface { ProcessOperation(ctx context.Context, item *nodev1.Operation) (*CacheWarmupOperationPlanResult, error) } type CacheWarmupConfig struct { Log *zap.Logger Source CacheWarmupSource Workers int ItemsPerSecond int Timeout time.Duration Processor CacheWarmupProcessor AfterOperation func(item *CacheWarmupOperationPlanResult) } func WarmupCaches(ctx context.Context, cfg *CacheWarmupConfig) (err error) { w := &cacheWarmup{ log: cfg.Log.With(zap.String("component", "cache_warmup")), source: cfg.Source, workers: cfg.Workers, itemsPerSecond: cfg.ItemsPerSecond, timeout: cfg.Timeout, processor: cfg.Processor, afterOperation: cfg.AfterOperation, } if cfg.Workers < 1 { w.workers = 4 } if cfg.ItemsPerSecond < 1 { w.itemsPerSecond = 0 } if cfg.Timeout <= 0 { w.timeout = time.Second * 30 } w.log.Info("Warmup started", zap.Int("workers", cfg.Workers), zap.Int("items_per_second", cfg.ItemsPerSecond), zap.Duration("timeout", cfg.Timeout), ) start := time.Now() completed, err := w.run(ctx) if err != nil { if errors.Is(err, context.DeadlineExceeded) { w.log.Error("Warmup timeout", zap.Error(err), zap.Int("processed_items", completed), zap.String("tip", "Consider to increase the timeout, increase the number of workers, increase the items per second limit, or reduce the number of items to process"), ) return err } w.log.Error("Warmup error", zap.Error(err), zap.Int("processed_items", completed), ) return err } w.log.Info("Warmup completed", zap.Int("processed_items", completed), zap.Duration("duration", time.Since(start)), ) return nil } type cacheWarmup struct { log *zap.Logger source CacheWarmupSource workers int itemsPerSecond int timeout time.Duration processor CacheWarmupProcessor afterOperation func(item *CacheWarmupOperationPlanResult) } func (w *cacheWarmup) run(ctx context.Context) (int, error) { ctx, cancel := context.WithTimeout(ctx, w.timeout) defer cancel() items, err := w.source.LoadItems(ctx, w.log) if err != nil { return 0, err } if len(items) == 0 { w.log.Debug("No items to process") return 0, nil } w.log.Info("Starting processing", zap.Int("items", len(items)), ) defaultClientInfo := &nodev1.ClientInfo{} done := ctx.Done() index := make(chan int, len(items)) defer close(index) itemCompleted := make(chan struct{}) for i, item := range items { if item.Client == nil { item.Client = defaultClientInfo } index <- i } var ( rl ratelimit.Limiter ) if w.itemsPerSecond > 0 { rl = ratelimit.New(w.itemsPerSecond) } else { rl = ratelimit.NewUnlimited() } for i := 0; i < w.workers; i++ { go func(i int) { for { select { case <-done: return case idx, ok := <-index: if !ok { return } rl.Take() item := items[idx] res, err := w.processor.ProcessOperation(ctx, item) if err != nil { w.log.Warn("Failed to process operation, skipping", zap.Error(err), zap.String("client_name", item.Client.Name), zap.String("client_version", item.Client.Version), zap.String("query", item.Request.Query), zap.String("operation_name", item.Request.OperationName), ) } if err == nil && w.afterOperation != nil { w.afterOperation(res) } select { case <-done: return case itemCompleted <- struct{}{}: } } } }(i) } for i := 0; i < len(items); i++ { processed := i + 1 select { case <-done: return processed, ctx.Err() case <-itemCompleted: if processed%100 == 0 { w.log.Info("Processing completed", zap.Int("processed_items", processed), ) } } } return len(items), nil } type CacheWarmupPlanningProcessorOptions struct { OperationProcessor *OperationProcessor OperationPlanner *OperationPlanner ComplexityLimits *config.ComplexityLimits RouterSchema *ast.Document TrackSchemaUsage bool DisableVariablesRemapping bool } func NewCacheWarmupPlanningProcessor(options *CacheWarmupPlanningProcessorOptions) *CacheWarmupPlanningProcessor { return &CacheWarmupPlanningProcessor{ operationProcessor: options.OperationProcessor, operationPlanner: options.OperationPlanner, complexityLimits: options.ComplexityLimits, routerSchema: options.RouterSchema, trackSchemaUsage: options.TrackSchemaUsage, disableVariablesRemapping: options.DisableVariablesRemapping, } } type CacheWarmupOperationPlanResult struct { OperationHash string OperationName string OperationType string ClientName string ClientVersion string PlanningTime time.Duration } type CacheWarmupPlanningProcessor struct { operationProcessor *OperationProcessor operationPlanner *OperationPlanner complexityLimits *config.ComplexityLimits routerSchema *ast.Document trackSchemaUsage bool disableVariablesRemapping bool } func (c *CacheWarmupPlanningProcessor) ProcessOperation(ctx context.Context, operation *nodev1.Operation) (*CacheWarmupOperationPlanResult, error) { var ( isAPQ bool ) k, err := c.operationProcessor.NewIndependentKit() if err != nil { return nil, err } var s []byte if operation.Request.GetExtensions() != nil { s, err = protojson.Marshal(operation.Request.GetExtensions()) if err != nil { return nil, err } } item := &CacheWarmupItem{ Request: GraphQLRequest{ Query: operation.Request.GetQuery(), OperationName: operation.Request.GetOperationName(), Extensions: s, }, Client: &ClientInfo{ Name: operation.GetClient().GetName(), Version: operation.GetClient().GetVersion(), }, } k.parsedOperation.Request = item.Request err = k.unmarshalOperation() if err != nil { return nil, err } err = k.ComputeOperationSha256() if err != nil { return nil, err } if k.parsedOperation.IsPersistedOperation && k.parsedOperation.Request.Query == "" { _, isAPQ, err = k.FetchPersistedOperation(ctx, item.Client) if err != nil { return nil, err } } err = k.Parse() if err != nil { return nil, err } _, err = k.NormalizeOperation(item.Client.Name, isAPQ) if err != nil { return nil, err } _, err = k.NormalizeVariables() if err != nil { return nil, err } err = k.RemapVariables(c.disableVariablesRemapping) if err != nil { return nil, err } // NOTE: we do not validate query complexity here, because queries come from analytics, so they should be valid _, err = k.Validate(true, k.parsedOperation.RemapVariables, nil) if err != nil { return nil, err } planOptions := PlanOptions{ ClientInfo: item.Client, TraceOptions: resolve.TraceOptions{ Enable: false, }, ExecutionOptions: resolve.ExecutionOptions{ SkipLoader: true, IncludeQueryPlanInResponse: false, SendHeartbeat: false, }, TrackSchemaUsageInfo: c.trackSchemaUsage, } opContext := &operationContext{ clientInfo: item.Client, name: k.parsedOperation.Request.OperationName, opType: k.parsedOperation.Type, hash: k.parsedOperation.ID, content: k.parsedOperation.NormalizedRepresentation, internalHash: k.parsedOperation.InternalID, } opContext.variables, err = astjson.ParseBytes(k.parsedOperation.Request.Variables) if err != nil { return nil, err } planningStart := time.Now() err = c.operationPlanner.plan(opContext, planOptions) if err != nil { return nil, err } return &CacheWarmupOperationPlanResult{ OperationHash: k.parsedOperation.IDString(), OperationName: k.parsedOperation.Request.OperationName, OperationType: k.parsedOperation.Type, ClientName: item.Client.Name, ClientVersion: item.Client.Version, PlanningTime: time.Since(planningStart), }, nil }