router/pkg/plan_generator/plan_generator.go (175 lines of code) (raw):
package plan_generator
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"runtime"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/wundergraph/cosmo/router/core"
"go.uber.org/zap"
)
const (
	// ReportFileName is the name of the JSON report that PlanGenerator writes
	// into the output directory when QueryPlanConfig.OutputReport is set.
	ReportFileName = "report.json"
)
// QueryPlanConfig holds the settings consumed by PlanGenerator.
type QueryPlanConfig struct {
	// ExecutionConfig is the path to the router execution config the plans
	// are generated against; it is resolved to an absolute path first.
	ExecutionConfig string
	// SourceDir is the directory whose top-level operation files are planned.
	SourceDir string
	// OutDir receives per-operation plan files and the JSON report; it is
	// created if it does not exist.
	OutDir string
	// Concurrency is the number of planning goroutines; 0 selects
	// runtime.GOMAXPROCS(0).
	Concurrency int
	// Filter is an optional path to a file listing operation file names, one
	// per line; when empty, every operation file is planned.
	Filter string
	// Timeout bounds the whole run and must be parseable by time.ParseDuration.
	Timeout string
	// OutputFiles writes each plan (or its error text) to OutDir under the
	// operation's file name.
	OutputFiles bool
	// OutputReport writes a JSON report named ReportFileName into OutDir.
	OutputReport bool
	// FailOnPlanError makes PlanGenerator return an error if any operation
	// failed to plan.
	FailOnPlanError bool
	// FailFast stops the remaining work after the first planning failure.
	FailFast bool
	// LogLevel is not read by PlanGenerator itself.
	// NOTE(review): presumably used by the caller when building Logger — confirm.
	LogLevel string
	// Logger is handed to core.NewPlanGenerator.
	Logger *zap.Logger
	// MaxDataSourceCollectorsConcurrency is passed through unchanged to
	// core.NewPlanGenerator; its exact semantics are defined there.
	MaxDataSourceCollectorsConcurrency uint
}
// QueryPlanResults is the JSON document PlanGenerator writes as its report.
type QueryPlanResults struct {
	// Plans holds one entry per planned operation, sorted by file name.
	Plans []QueryPlanResult `json:"plans,omitempty"`
	// Error describes why the run stopped early; it is empty on a complete run.
	Error string `json:"error,omitempty"`
}
// QueryPlanResult is the outcome of planning a single operation file.
type QueryPlanResult struct {
	// FileName is the operation file's base name inside the source directory.
	FileName string `json:"file_name,omitempty"`
	// Plan is the planner's textual output for the operation.
	Plan string `json:"plan,omitempty"`
	// Error is the planning error message, empty when planning succeeded.
	Error string `json:"error,omitempty"`
}
// PlanGenerator builds a query plan for every GraphQL operation file found
// directly inside cfg.SourceDir, using the router execution config at
// cfg.ExecutionConfig.
//
// Operation files are those with a .graphql, .gql or .graphqls extension;
// directories are skipped. When cfg.Filter names a file, only operations whose
// file name is listed in it (one name per line, surrounding whitespace
// ignored) are planned. The work is spread over cfg.Concurrency goroutines
// (runtime.GOMAXPROCS(0) when the value is not positive), each using its own
// planner obtained from a shared core.PlanGenerator.
//
// Outputs, written into cfg.OutDir (created if missing):
//   - cfg.OutputFiles: one file per operation, named like the operation file,
//     containing the plan, or "Error: <message>" if planning failed.
//   - cfg.OutputReport: ReportFileName, a JSON QueryPlanResults holding all
//     results sorted by file name plus the reason the run stopped early, if any.
//
// The run is bounded by cfg.Timeout (a time.ParseDuration string). With
// cfg.FailFast the first planning error stops the remaining work.
//
// It returns an error when setup fails, when the report cannot be written, or
// when cfg.FailOnPlanError is set and any operation failed to plan. Otherwise
// it returns the reason the run stopped early (timeout, cancellation of ctx,
// fail-fast, planner acquisition or output-file write failure), or nil.
func PlanGenerator(ctx context.Context, cfg QueryPlanConfig) error {
	// Negative values are treated like 0: sync.WaitGroup.Add panics on a
	// negative counter.
	if cfg.Concurrency <= 0 {
		cfg.Concurrency = runtime.GOMAXPROCS(0)
	}

	// Validate the timeout before touching the filesystem.
	duration, err := time.ParseDuration(cfg.Timeout)
	if err != nil {
		return fmt.Errorf("failed to parse timeout: %w", err)
	}

	queriesPath, err := filepath.Abs(cfg.SourceDir)
	if err != nil {
		return fmt.Errorf("failed to get absolute path for queries: %w", err)
	}
	outPath, err := filepath.Abs(cfg.OutDir)
	if err != nil {
		return fmt.Errorf("failed to get absolute path for output: %w", err)
	}
	if err := os.MkdirAll(outPath, 0755); err != nil {
		return fmt.Errorf("failed to create output directory: %w", err)
	}
	executionConfigPath, err := filepath.Abs(cfg.ExecutionConfig)
	if err != nil {
		return fmt.Errorf("failed to get absolute path for execution config: %w", err)
	}

	var filter []string
	if cfg.Filter != "" {
		filterContent, err := os.ReadFile(cfg.Filter)
		if err != nil {
			return fmt.Errorf("failed to read filter file: %w", err)
		}
		// Trim every entry so CRLF line endings or stray spaces do not prevent
		// matches. Blank entries are kept; they can never match a file name.
		filter = strings.Split(string(filterContent), "\n")
		for i := range filter {
			filter[i] = strings.TrimSpace(filter[i])
		}
	}

	queries, err := os.ReadDir(queriesPath)
	if err != nil {
		return fmt.Errorf("failed to read queries directory: %w", err)
	}
	// Buffered to hold every entry so the queue can be filled and closed
	// before any worker starts.
	queriesQueue := make(chan os.DirEntry, len(queries))
	for _, queryFile := range queries {
		queriesQueue <- queryFile
	}
	close(queriesQueue)

	ctx, cancel := context.WithTimeout(ctx, duration)
	defer cancel()
	// ctxError records why the run stopped early; its cause is reported and
	// returned.
	ctxError, cancelError := context.WithCancelCause(ctx)
	defer cancelError(nil)

	pg, err := core.NewPlanGenerator(executionConfigPath, cfg.Logger, cfg.MaxDataSourceCollectorsConcurrency)
	if err != nil {
		return fmt.Errorf("failed to create plan generator: %w", err)
	}

	queryExtensions := []string{".graphql", ".gql", ".graphqls"}

	var (
		results    []QueryPlanResult
		resultsMux sync.Mutex
		planError  atomic.Bool
		wg         sync.WaitGroup
	)
	wg.Add(cfg.Concurrency)
	for i := 0; i < cfg.Concurrency; i++ {
		go func() {
			defer wg.Done()
			planner, err := pg.GetPlanner()
			if err != nil {
				cancelError(fmt.Errorf("failed to get planner: %w", err))
				// Without a planner this worker cannot do anything. Continuing
				// could still dequeue a file (select picks randomly among ready
				// cases) and call PlanOperation on an invalid planner.
				return
			}
			for {
				select {
				case <-ctxError.Done():
					return
				case queryFile, ok := <-queriesQueue:
					if !ok {
						return
					}
					if queryFile.IsDir() || !slices.Contains(queryExtensions, filepath.Ext(queryFile.Name())) {
						continue
					}
					if len(filter) > 0 && !slices.Contains(filter, queryFile.Name()) {
						continue
					}
					queryFilePath := filepath.Join(queriesPath, queryFile.Name())
					outContent, err := planner.PlanOperation(queryFilePath)
					res := QueryPlanResult{
						FileName: queryFile.Name(),
						Plan:     outContent,
					}
					if err != nil {
						res.Error = err.Error()
						outContent = fmt.Sprintf("Error: %v", err)
						planError.Store(true)
						if cfg.FailFast {
							// Cancelling the timeout context also ends ctxError,
							// which stops every worker.
							cancel()
						}
					}
					if cfg.OutputFiles {
						outFileName := filepath.Join(outPath, queryFile.Name())
						if err := os.WriteFile(outFileName, []byte(outContent), 0644); err != nil {
							cancelError(fmt.Errorf("failed to write file: %w", err))
						}
					}
					resultsMux.Lock()
					results = append(results, res)
					resultsMux.Unlock()
				}
			}
		}()
	}
	wg.Wait()

	// Snapshot the early-stop reason once, so the report and the return value
	// agree even if the deadline expires while the report is being written.
	runErr := context.Cause(ctxError)

	if cfg.OutputReport {
		slices.SortFunc(results, func(a, b QueryPlanResult) int {
			return strings.Compare(a.FileName, b.FileName)
		})
		resultData := QueryPlanResults{
			Plans: results,
		}
		if runErr != nil {
			resultData.Error = runErr.Error()
		}
		data, err := json.Marshal(resultData)
		if err != nil {
			return fmt.Errorf("failed to marshal result: %w", err)
		}

		reportFilePath := filepath.Join(outPath, ReportFileName)
		reportFile, err := os.Create(reportFilePath)
		if err != nil {
			return fmt.Errorf("failed to create results file: %w", err)
		}
		_, writeErr := reportFile.Write(append(data, '\n'))
		// Close explicitly: a close error can mean the report was not fully
		// persisted, and must not be reported as success.
		closeErr := reportFile.Close()
		if writeErr != nil {
			return fmt.Errorf("failed to write result: %w", writeErr)
		}
		if closeErr != nil {
			return fmt.Errorf("failed to close results file: %w", closeErr)
		}
	}

	if cfg.FailOnPlanError && planError.Load() {
		return fmt.Errorf("some queries failed to generate plan")
	}
	return runErr
}