router/internal/circuit/manager.go (157 lines of code) (raw):
package circuit
import (
"errors"
"sync"
"time"
"github.com/cep21/circuit/v4"
"github.com/cep21/circuit/v4/closers/hystrix"
"github.com/wundergraph/cosmo/router/pkg/metric"
"go.opentelemetry.io/otel/attribute"
)
// CircuitBreakerConfig holds the circuit breaker settings for either the base
// (default) configuration or a single subgraph override.
// It exists so the circuit package does not depend on the config package.
type CircuitBreakerConfig struct {
// Enabled switches the circuit breaker on. NewManager only installs default
// circuit properties when the base config is enabled, and Initialize only
// creates a per-subgraph circuit for overrides that are enabled.
Enabled bool
// ErrorThresholdPercentage is forwarded to the hystrix opener as
// ErrorThresholdPercentage.
ErrorThresholdPercentage int64
// RequestThreshold is forwarded to the hystrix opener as
// RequestVolumeThreshold.
RequestThreshold int64
// SleepWindow is forwarded to the hystrix closer as SleepWindow.
SleepWindow time.Duration
// HalfOpenAttempts is forwarded to the hystrix closer as HalfOpenAttempts.
HalfOpenAttempts int64
// RequiredSuccessfulAttempts is forwarded to the hystrix closer as
// RequiredConcurrentSuccessful.
RequiredSuccessfulAttempts int64
// RollingDuration is forwarded to the hystrix opener as RollingDuration.
// When NumBuckets is positive it must be evenly divisible by NumBuckets.
RollingDuration time.Duration
// NumBuckets is forwarded to the hystrix opener as NumBuckets.
NumBuckets int
// ExecutionTimeout is forwarded to the circuit execution config as Timeout.
ExecutionTimeout time.Duration
// MaxConcurrentRequests is forwarded to the circuit execution config as
// MaxConcurrentRequests.
MaxConcurrentRequests int64
}
// Manager stores circuit breakers keyed by subgraph name and creates them
// through an underlying circuit.Manager.
type Manager struct {
// circuits maps a subgraph name to its circuit breaker. Several names may
// point at the same instance: Initialize shares one circuit between the
// default-config subgraphs of a routing URL.
circuits map[string]*circuit.Circuit
// internalManager creates the circuits. When the base config is enabled,
// NewManager sets its DefaultCircuitProperties from that config.
internalManager *circuit.Manager
// isBaseConfigEnabled records whether the base config passed to NewManager
// was enabled. Initialize only creates default circuits when it is true.
isBaseConfigEnabled bool
// lock guards circuits.
lock sync.RWMutex
}
// NewManager builds a Manager with an empty circuit registry.
//
// When baseConfig is enabled, it is validated and installed as the default
// circuit properties of the underlying circuit.Manager; a validation failure
// is returned as the error and no Manager is returned. When baseConfig is
// disabled, it is neither validated nor installed.
func NewManager(baseConfig CircuitBreakerConfig) (*Manager, error) {
	// Stays nil unless the base configuration is enabled.
	var defaultProperties []circuit.CommandPropertiesConstructor

	if baseConfig.Enabled {
		constructor, err := createConfiguration(baseConfig)
		if err != nil {
			return nil, err
		}
		defaultProperties = []circuit.CommandPropertiesConstructor{constructor}
	}

	manager := &Manager{
		circuits:            make(map[string]*circuit.Circuit),
		internalManager:     &circuit.Manager{DefaultCircuitProperties: defaultProperties},
		isBaseConfigEnabled: baseConfig.Enabled,
	}
	return manager, nil
}
// GetCircuitBreaker returns the circuit registered under name, or nil when
// there is none. It is safe to call on a nil *Manager, which also yields nil.
func (c *Manager) GetCircuitBreaker(name string) *circuit.Circuit {
	if c == nil {
		return nil
	}

	c.lock.RLock()
	// Indexing a map with a missing key yields the zero value, i.e. a nil circuit.
	breaker := c.circuits[name]
	c.lock.RUnlock()

	return breaker
}
// AddCircuitBreaker registers createCircuit under name, replacing any circuit
// previously stored for that name.
//
// It is a no-op on a nil *Manager. A Manager that was not built by NewManager
// (and therefore has no registry map yet) gets one allocated on first use
// instead of panicking on a nil-map write.
func (c *Manager) AddCircuitBreaker(name string, createCircuit *circuit.Circuit) {
	if c == nil {
		return
	}

	c.lock.Lock()
	defer c.lock.Unlock()

	if c.circuits == nil {
		// Writing to a nil map panics; allocate lazily under the write lock.
		c.circuits = make(map[string]*circuit.Circuit)
	}
	c.circuits[name] = createCircuit
}
// HasCircuits reports whether at least one circuit is registered.
// A nil *Manager reports false.
func (c *Manager) HasCircuits() bool {
	if c == nil {
		return false
	}

	c.lock.RLock()
	registered := len(c.circuits)
	c.lock.RUnlock()

	return registered > 0
}
// ManagerOpts carries the inputs Manager.Initialize needs to create circuits.
type ManagerOpts struct {
// SubgraphCircuitBreakers holds per-subgraph overrides keyed by subgraph
// name. A subgraph without an entry falls back to the base configuration;
// a subgraph whose entry is not enabled gets no circuit.
SubgraphCircuitBreakers map[string]CircuitBreakerConfig
// MetricStore is passed to metric.NewCircuitBreakerMetricsConfig when
// UseMetrics is true.
MetricStore metric.CircuitMetricStore
// UseMetrics controls whether a metrics config is added to each circuit
// that Initialize creates.
UseMetrics bool
// BaseOtelAttributes is passed to metric.NewCircuitBreakerMetricsConfig
// when UseMetrics is true.
BaseOtelAttributes []attribute.KeyValue
// AllGroupings maps a routing URL to the set of subgraph names grouped
// under it. Initialize only reads the keys of the inner map; the bool
// values are not inspected.
AllGroupings map[string]map[string]bool
}
// Initialize creates the circuit breakers for every subgraph listed in
// opts.AllGroupings and registers them on the manager by subgraph name.
//
// AllGroupings is walked as routing URL -> set of subgraph names. Within one
// group:
//   - Subgraphs without an entry in opts.SubgraphCircuitBreakers use the base
//     configuration. If the base configuration is enabled they all share ONE
//     circuit, created under the routing URL as its name.
//   - Subgraphs with an enabled entry each get their own circuit, created
//     under the subgraph name from that entry.
//   - Subgraphs with an entry that is not enabled get no circuit.
//
// Failures to create a circuit are accumulated with errors.Join and do not
// stop the remaining circuits (including the custom ones of the same group)
// from being created. An invalid per-subgraph configuration aborts
// immediately; the returned error then also carries everything accumulated up
// to that point. The result is nil when nothing failed.
//
// Like the other Manager methods, Initialize is a no-op on a nil *Manager.
func (c *Manager) Initialize(opts ManagerOpts) error {
	if c == nil {
		return nil
	}

	var joinErr error

	for routingURL, sgNames := range opts.AllGroupings {
		defaultSgNames := make([]string, 0, len(sgNames))
		customSgNames := make([]string, 0, len(sgNames))

		for sgName := range sgNames {
			entry, ok := opts.SubgraphCircuitBreakers[sgName]
			switch {
			case !ok:
				// No override: falls back to the base configuration.
				defaultSgNames = append(defaultSgNames, sgName)
			case entry.Enabled:
				customSgNames = append(customSgNames, sgName)
			}
			// An override that is present but disabled lands in neither list,
			// so an explicitly disabled subgraph never gets a circuit.
		}

		// Default configuration: one shared circuit per routing URL.
		if len(defaultSgNames) > 0 && c.isBaseConfigEnabled {
			configs := make([]circuit.Config, 0, 1)
			if opts.UseMetrics {
				configs = append(configs, metric.NewCircuitBreakerMetricsConfig(defaultSgNames, opts.MetricStore, opts.BaseOtelAttributes))
			}

			sharedCircuit, err := c.internalManager.CreateCircuit(routingURL, configs...)
			if err != nil {
				// Record the failure but keep going: the custom subgraphs of
				// this group are independent of the shared circuit.
				joinErr = errors.Join(joinErr, err)
			} else {
				for _, sgName := range defaultSgNames {
					// Every default subgraph of the group points at the same instance.
					c.AddCircuitBreaker(sgName, sharedCircuit)
				}
			}
		}

		// Custom overrides: one dedicated circuit per subgraph.
		for _, sgName := range customSgNames {
			// Validate first so an invalid override does no further work.
			configFunc, err := createConfiguration(opts.SubgraphCircuitBreakers[sgName])
			if err != nil {
				// Abort, but do not lose the errors gathered so far.
				return errors.Join(joinErr, err)
			}

			// Order is kept as before: metrics config first, then the override.
			configs := make([]circuit.Config, 0, 2)
			if opts.UseMetrics {
				configs = append(configs, metric.NewCircuitBreakerMetricsConfig([]string{sgName}, opts.MetricStore, opts.BaseOtelAttributes))
			}
			configs = append(configs, configFunc(sgName))

			dedicatedCircuit, err := c.internalManager.CreateCircuit(sgName, configs...)
			if err != nil {
				joinErr = errors.Join(joinErr, err)
				continue
			}
			c.AddCircuitBreaker(sgName, dedicatedCircuit)
		}
	}

	return joinErr
}
// createConfiguration validates opts and returns a constructor that produces
// the circuit.Config described by opts.
//
// Validation: when NumBuckets is positive, RollingDuration must be an exact
// multiple of it, otherwise an error is returned. A non-positive NumBuckets
// skips the check.
//
// The returned constructor does not use the circuit name it is given; every
// call builds a fresh config with new hystrix opener and closer factories.
func createConfiguration(opts CircuitBreakerConfig) (circuit.CommandPropertiesConstructor, error) {
	// Per the original note, the config schema already blocks this case, so
	// only tests are expected to hit it. The NumBuckets > 0 test short-circuits
	// and keeps the modulo from dividing by zero.
	if opts.NumBuckets > 0 && int64(opts.RollingDuration)%int64(opts.NumBuckets) != 0 {
		return nil, errors.New("rolling duration must be divisible by num buckets")
	}

	build := func(_ string) circuit.Config {
		// Settings that decide when an open circuit may close again.
		closerSettings := hystrix.ConfigureCloser{
			SleepWindow:                  opts.SleepWindow,
			HalfOpenAttempts:             opts.HalfOpenAttempts,
			RequiredConcurrentSuccessful: opts.RequiredSuccessfulAttempts,
		}
		// Settings that decide when a closed circuit opens.
		openerSettings := hystrix.ConfigureOpener{
			ErrorThresholdPercentage: opts.ErrorThresholdPercentage,
			RequestVolumeThreshold:   opts.RequestThreshold,
			RollingDuration:          opts.RollingDuration,
			NumBuckets:               opts.NumBuckets,
		}

		general := circuit.GeneralConfig{
			OpenToClosedFactory: hystrix.CloserFactory(closerSettings),
			ClosedToOpenFactory: hystrix.OpenerFactory(openerSettings),
		}
		execution := circuit.ExecutionConfig{
			Timeout:               opts.ExecutionTimeout,
			MaxConcurrentRequests: opts.MaxConcurrentRequests,
		}

		return circuit.Config{
			General:   general,
			Execution: execution,
		}
	}

	return build, nil
}