gollm/factory.go (189 lines of code) (raw):

// Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package gollm import ( "context" "errors" "fmt" "math/rand/v2" "net" "net/http" "net/url" "os" "strings" "sync" "time" "k8s.io/klog/v2" ) var globalRegistry registry type registry struct { mutex sync.Mutex providers map[string]FactoryFunc } type FactoryFunc func(ctx context.Context, uri *url.URL) (Client, error) func RegisterProvider(id string, factoryFunc FactoryFunc) error { return globalRegistry.RegisterProvider(id, factoryFunc) } func (r *registry) RegisterProvider(id string, factoryFunc FactoryFunc) error { r.mutex.Lock() defer r.mutex.Unlock() if r.providers == nil { r.providers = make(map[string]FactoryFunc) } _, exists := r.providers[id] if exists { return fmt.Errorf("provider %q is already registered", id) } r.providers[id] = factoryFunc return nil } func (r *registry) NewClient(ctx context.Context, providerID string) (Client, error) { // providerID can be just an ID, for example "gemini" instead of "gemini://" if !strings.Contains(providerID, "/") && !strings.Contains(providerID, ":") { providerID = providerID + "://" } u, err := url.Parse(providerID) if err != nil { return nil, fmt.Errorf("parsing provider id %q: %w", providerID, err) } factoryFunc := r.providers[u.Scheme] if factoryFunc == nil { return nil, fmt.Errorf("provider %q not registered", u.Scheme) } return factoryFunc(ctx, u) } // NewClient builds an Client based on the LLM_CLIENT env var or the provided providerID. ProviderID (if not empty) overrides the provider from LLM_CLIENT env var. func NewClient(ctx context.Context, providerID string) (Client, error) { if providerID == "" { s := os.Getenv("LLM_CLIENT") if s == "" { return nil, fmt.Errorf("LLM_CLIENT is not set") } providerID = s } return globalRegistry.NewClient(ctx, providerID) } // APIError represents an error returned by the LLM client. type APIError struct { StatusCode int Message string Err error } func (e *APIError) Error() string { if e.Err != nil { return fmt.Sprintf("API Error: Status=%d, Message='%s', OriginalErr=%v", e.StatusCode, e.Message, e.Err) } return fmt.Sprintf("API Error: Status=%d, Message='%s'", e.StatusCode, e.Message) } func (e *APIError) Unwrap() error { return e.Err } // IsRetryableFunc defines the signature for functions that check if an error is retryable. // TODO (droot): Adjust the signature to allow underlying client to relay the backoff // delay etc. for example, Gemini's error codes contain retryDelay information. type IsRetryableFunc func(error) bool // DefaultIsRetryableError provides a default implementation based on common HTTP codes and network errors. func DefaultIsRetryableError(err error) bool { if err == nil { return false } var apiErr *APIError if errors.As(err, &apiErr) { switch apiErr.StatusCode { case http.StatusConflict, http.StatusTooManyRequests, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: return true default: return false } } var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { return true } // Add other error checks specific to LLM clients if needed // e.g., if errors.Is(err, specificLLMRateLimitError) { return true } return false } // RetryConfig holds the configuration for the retry mechanism (same as before) type RetryConfig struct { MaxAttempts int InitialBackoff time.Duration MaxBackoff time.Duration BackoffFactor float64 Jitter bool } // DefaultRetryConfig provides sensible defaults (same as before) var DefaultRetryConfig = RetryConfig{ MaxAttempts: 5, InitialBackoff: 200 * time.Millisecond, // Slightly increased default MaxBackoff: 10 * time.Second, BackoffFactor: 2.0, Jitter: true, } // Retry executes the provided operation with retries, returning the result and error. // It's now generic to handle any return type T. func Retry[T any]( ctx context.Context, config RetryConfig, isRetryable IsRetryableFunc, operation func(ctx context.Context) (T, error), ) (T, error) { var lastErr error var zero T // Zero value of the return type T log := klog.FromContext(ctx) backoff := config.InitialBackoff for attempt := 1; attempt <= config.MaxAttempts; attempt++ { // log.Printf("Executing operation, attempt %d of %d", attempt, config.MaxAttempts) // Optional verbose log result, err := operation(ctx) if err == nil { // Success return result, nil } lastErr = err // Store the last error encountered // Check if context was cancelled *after* the operation select { case <-ctx.Done(): log.Info("Context cancelled after attempt %d failed.", "attempt", attempt) return zero, ctx.Err() // Return context error preferentially default: // Context not cancelled, proceed with error checking } if !isRetryable(lastErr) { log.Info("Attempt failed with non-retryable error", "attempt", attempt, "error", lastErr) return zero, lastErr // Return the non-retryable error immediately } log.Info("Attempt failed with retryable error", "attempt", attempt, "error", lastErr) if attempt == config.MaxAttempts { // Max attempts reached break } // Calculate wait time waitTime := backoff if config.Jitter { waitTime += time.Duration(rand.Float64() * float64(backoff) / 2) } log.Info("Waiting before next attempt", "waitTime", waitTime, "attempt", attempt+1, "maxAttempts", config.MaxAttempts) // Wait or react to context cancellation select { case <-time.After(waitTime): // Wait finished case <-ctx.Done(): log.Info("Context cancelled while waiting for retry after attempt %d.", "attempt", attempt) return zero, ctx.Err() } // Increase backoff backoff = time.Duration(float64(backoff) * config.BackoffFactor) if backoff > config.MaxBackoff { backoff = config.MaxBackoff } } // If the loop finished, it means all attempts failed errFinal := fmt.Errorf("operation failed after %d attempts: %w", config.MaxAttempts, lastErr) return zero, errFinal } // retryChat is a generic decorator that adds retry logic to any Chat implementation. type retryChat[C Chat] struct { underlying Chat // The actual client implementation being wrapped config RetryConfig isRetryable IsRetryableFunc } // NewRetryChat creates a new Chat that wraps the given underlying client // with retry logic using the provided configuration. // It returns the Chat interface type, hiding the generic implementation detail. func NewRetryChat[C Chat]( underlying C, config RetryConfig, ) Chat { return &retryChat[C]{ underlying: underlying, config: config, } } // Embed implements the Client interface for the retryClient decorator. func (rc *retryChat[C]) Send(ctx context.Context, contents ...any) (ChatResponse, error) { // Define the operation operation := func(ctx context.Context) (ChatResponse, error) { return rc.underlying.Send(ctx, contents...) } // Execute with retry return Retry[ChatResponse](ctx, rc.config, rc.underlying.IsRetryableError, operation) } // Embed implements the Client interface for the retryClient decorator. func (rc *retryChat[C]) SendStreaming(ctx context.Context, contents ...any) (ChatResponseIterator, error) { // TODO: Retry logic return rc.underlying.SendStreaming(ctx, contents...) } func (rc *retryChat[C]) SetFunctionDefinitions(functionDefinitions []*FunctionDefinition) error { return rc.underlying.SetFunctionDefinitions(functionDefinitions) } func (rc *retryChat[C]) IsRetryableError(err error) bool { return rc.underlying.IsRetryableError(err) }