graphqlmetrics/pkg/batchprocessor/batchprocessor.go (124 lines of code) (raw):

package batchprocessor import ( "context" "errors" "sync" "time" ) // Options defines configuration options for the BatchProcessor. type Options[T any] struct { MaxQueueSize int // Size of the internal queue CostFunc func([]T) int // Function to calculate the cost of a batch CostThreshold int // Threshold at which the batch should be dispatched Dispatcher func(context.Context, []T) // Function to process the batch Interval time.Duration // Interval at which the batch should be dispatched if the cost threshold is not met MaxWorkers int // Number of worker goroutines } // BatchProcessor is a generic batch processing module. type BatchProcessor[T any] struct { queue chan T batch []T costFunction func([]T) int dispatcherFunc func(context.Context, []T) interval time.Duration doneChan chan struct{} costThreshold int ctx context.Context cancel context.CancelFunc dispatchChan chan []T workerCount int wg sync.WaitGroup } // New creates a new BatchProcessor with the provided options. func New[T any](opts Options[T]) *BatchProcessor[T] { ctx, cancel := context.WithCancel(context.Background()) if opts.MaxWorkers <= 0 { opts.MaxWorkers = 1 // Ensure at least one worker } bp := &BatchProcessor[T]{ queue: make(chan T, opts.MaxQueueSize), batch: make([]T, 0), costFunction: opts.CostFunc, costThreshold: opts.CostThreshold, dispatcherFunc: opts.Dispatcher, interval: opts.Interval, doneChan: make(chan struct{}), ctx: ctx, cancel: cancel, dispatchChan: make(chan []T), workerCount: opts.MaxWorkers, } // Start the batch manager goroutine go bp.runBatchManager() // Start worker goroutines bp.wg.Add(bp.workerCount) for i := 0; i < bp.workerCount; i++ { go bp.runWorker() } return bp } // Push adds an item to the queue. Returns an error if the processor is stopped. func (bp *BatchProcessor[T]) Push(item T) error { select { case bp.queue <- item: return nil case <-bp.doneChan: // Processor stopped when doneChan is closed return errors.New("batch processor stopped") } } // StopAndWait stops the processor and waits until all items are processed or the context is done. func (bp *BatchProcessor[T]) StopAndWait(ctx context.Context) error { close(bp.doneChan) // Signal the manager to stop // Wait for worker goroutines to finish done := make(chan struct{}) go func() { bp.wg.Wait() close(done) }() select { case <-done: // All workers have finished return nil case <-ctx.Done(): // Context is canceled; cancel the context for dispatchers bp.cancel() if errors.Is(ctx.Err(), context.DeadlineExceeded) { return nil } else { return ctx.Err() } } } func (bp *BatchProcessor[T]) runBatchManager() { ticker := time.NewTicker(bp.interval) defer ticker.Stop() defer close(bp.queue) // Stop the queue defer close(bp.dispatchChan) // Stop the workers after draining / close the queue for { select { case item := <-bp.queue: bp.batch = append(bp.batch, item) cost := bp.costFunction(bp.batch) if cost >= bp.costThreshold { bp.dispatch() ticker.Reset(bp.interval) // Reset the timer after dispatching } case <-ticker.C: if len(bp.batch) > 0 { bp.dispatch() } case <-bp.doneChan: // Queue closed, process any remaining items if len(bp.batch) > 0 { bp.dispatch() } return case <-bp.ctx.Done(): // Context canceled, exit batch manager return } } } func (bp *BatchProcessor[T]) dispatch() { // Create a copy of the batch to avoid data races batchCopy := make([]T, len(bp.batch)) copy(batchCopy, bp.batch) // Reset the batch bp.batch = bp.batch[:0] // Send the batch to the dispatch channel select { case bp.dispatchChan <- batchCopy: case <-bp.ctx.Done(): } } func (bp *BatchProcessor[T]) runWorker() { defer bp.wg.Done() for batch := range bp.dispatchChan { // Process the batch with the context bp.dispatcherFunc(bp.ctx, batch) } }