registry/datastore/db.go
//go:generate mockgen -package mocks -destination mocks/db.go . Handler,Transactor,LoadBalancer,Connector,DNSResolver
package datastore
import (
"context"
"database/sql"
"errors"
"fmt"
"io"
"net"
"regexp"
"strconv"
"strings"
"sync"
"time"
"github.com/docker/distribution/configuration"
"github.com/docker/distribution/log"
"github.com/docker/distribution/registry/datastore/metrics"
"github.com/docker/distribution/registry/datastore/models"
"github.com/hashicorp/go-multierror"
"github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/stdlib"
"github.com/jackc/pgx/v5/tracelog"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"gitlab.com/gitlab-org/labkit/errortracking"
"gitlab.com/gitlab-org/labkit/metrics/sqlmetrics"
)
const (
driverName = "pgx"
defaultReplicaCheckInterval = 1 * time.Minute
HostTypePrimary = "primary"
HostTypeReplica = "replica"
HostTypeUnknown = "unknown"
// upToDateReplicaTimeout establishes the maximum amount of time we're willing to wait for an up-to-date database
// replica to be identified during load balancing. If a replica is not identified within this threshold, then it's
// likely that there is a performance degradation going on, in which case we want to gracefully fall back to the
// primary database to avoid further processing delays. The current 100ms value is a starting point/educated guess
// that matches the one used in GitLab Rails (https://gitlab.com/gitlab-org/gitlab/-/merge_requests/159633).
upToDateReplicaTimeout = 100 * time.Millisecond
// ReplicaResolveTimeout sets a global limit on wait time for resolving replicas in load balancing, covering both DNS
// lookups and connection attempts for all identified replicas.
ReplicaResolveTimeout = 2 * time.Second
// InitReplicaResolveTimeout is a stricter limit used only during startup in NewDBLoadBalancer, taking precedence over
// ReplicaResolveTimeout. A quick failure here prevents startup delays, allowing asynchronous retries in
// StartPoolRefresh. While these timeouts are currently similar, they remain separate to allow independent tuning.
InitReplicaResolveTimeout = 1 * time.Second
// minLivenessProbeInterval is the minimum time between replica host liveness probes during load balancing.
minLivenessProbeInterval = 1 * time.Second
// minResolveReplicasInterval is the default minimum time between replicas resolution calls during load balancing.
minResolveReplicasInterval = 10 * time.Second
// livenessProbeTimeout is the maximum time for a replica liveness probe to run.
livenessProbeTimeout = 100 * time.Millisecond
// replicaLagCheckTimeout is the default timeout for checking replica lag.
replicaLagCheckTimeout = 100 * time.Millisecond
// MaxReplicaLagTime is the default maximum replication lag time
MaxReplicaLagTime = 30 * time.Second
// MaxReplicaLagBytes is the default maximum replication lag in bytes. This matches the Rails default, see
// https://gitlab.com/gitlab-org/gitlab/blob/5c68653ce8e982e255277551becb3270a92f5e9e/lib/gitlab/database/load_balancing/configuration.rb#L48-48
MaxReplicaLagBytes = 8 * 1024 * 1024
)
// Queryer is the common interface to execute queries on a database.
type Queryer interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *Row
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
// Handler represents a database connection handler.
type Handler interface {
Queryer
Stats() sql.DBStats
Close() error
BeginTx(ctx context.Context, opts *sql.TxOptions) (Transactor, error)
}
// Transactor represents a database transaction.
type Transactor interface {
Queryer
Commit() error
Rollback() error
}
// QueryErrorProcessor defines the interface for handling database query errors.
type QueryErrorProcessor interface {
ProcessQueryError(ctx context.Context, db *DB, query string, err error)
}
// DB implements Handler.
type DB struct {
*sql.DB
DSN *DSN
errorProcessor QueryErrorProcessor
}
// BeginTx wraps sql.Tx from the inner sql.DB within a datastore.Tx.
func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Transactor, error) {
tx, err := db.DB.BeginTx(ctx, opts)
return &Tx{tx, db, db.errorProcessor}, err
}
// Begin wraps sql.Tx from the inner sql.DB within a datastore.Tx.
func (db *DB) Begin() (Transactor, error) {
return db.BeginTx(context.Background(), nil)
}
// Address returns the database host network address.
func (db *DB) Address() string {
if db.DSN == nil {
return ""
}
return db.DSN.Address()
}
// QueryContext wraps the underlying QueryContext.
func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
rows, err := db.DB.QueryContext(ctx, query, args...)
if err != nil && db.errorProcessor != nil {
db.errorProcessor.ProcessQueryError(ctx, db, query, err)
}
return rows, err
}
// ExecContext wraps the underlying ExecContext.
func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
res, err := db.DB.ExecContext(ctx, query, args...)
if err != nil && db.errorProcessor != nil {
db.errorProcessor.ProcessQueryError(ctx, db, query, err)
}
return res, err
}
// Row is a wrapper around sql.Row that allows us to intercept errors during the row's Scan.
type Row struct {
*sql.Row
db *DB
errorProcessor QueryErrorProcessor
query string
ctx context.Context
}
// QueryRowContext wraps the underlying QueryRowContext with our custom Row.
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
row := db.DB.QueryRowContext(ctx, query, args...)
return &Row{
Row: row,
db: db,
errorProcessor: db.errorProcessor,
query: query,
ctx: ctx,
}
}
// Scan implements the sql.Row.Scan method and processes errors.
func (r *Row) Scan(dest ...any) error {
err := r.Row.Scan(dest...)
if err != nil && r.errorProcessor != nil {
r.errorProcessor.ProcessQueryError(r.ctx, r.db, r.query, err)
}
return err
}
// Tx implements Transactor.
type Tx struct {
*sql.Tx
db *DB
errorProcessor QueryErrorProcessor
}
// QueryContext wraps the underlying QueryContext.
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
rows, err := tx.Tx.QueryContext(ctx, query, args...)
if err != nil && tx.errorProcessor != nil {
tx.errorProcessor.ProcessQueryError(ctx, tx.db, query, err)
}
return rows, err
}
// QueryRowContext wraps the underlying Tx.QueryRowContext with our custom Row.
func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
row := tx.Tx.QueryRowContext(ctx, query, args...)
return &Row{
Row: row,
db: tx.db,
errorProcessor: tx.errorProcessor,
query: query,
ctx: ctx,
}
}
// ExecContext wraps the underlying ExecContext.
func (tx *Tx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
res, err := tx.Tx.ExecContext(ctx, query, args...)
if err != nil && tx.errorProcessor != nil {
tx.errorProcessor.ProcessQueryError(ctx, tx.db, query, err)
}
return res, err
}
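// Illustrative usage sketch (the table, column and variable names below are hypothetical), showing how the
// Transactor wrapper returned by BeginTx is typically driven by a caller:
//
//	tx, err := db.BeginTx(ctx, nil)
//	if err != nil {
//		return err
//	}
//	defer func() { _ = tx.Rollback() }()
//	if _, err := tx.ExecContext(ctx, "UPDATE repositories SET name = $1 WHERE id = $2", name, id); err != nil {
//		return err
//	}
//	return tx.Commit()
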
// DSN represents the Data Source Name parameters for a DB connection.
type DSN struct {
Host string
Port int
User string
Password string
DBName string
SSLMode string
SSLCert string
SSLKey string
SSLRootCert string
ConnectTimeout time.Duration
}
// String builds the string representation of a DSN.
func (dsn *DSN) String() string {
var params []string
port := ""
if dsn.Port > 0 {
port = strconv.Itoa(dsn.Port)
}
connectTimeout := ""
if dsn.ConnectTimeout > 0 {
connectTimeout = fmt.Sprintf("%.0f", dsn.ConnectTimeout.Seconds())
}
for _, param := range []struct{ k, v string }{
{"host", dsn.Host},
{"port", port},
{"user", dsn.User},
{"password", dsn.Password},
{"dbname", dsn.DBName},
{"sslmode", dsn.SSLMode},
{"sslcert", dsn.SSLCert},
{"sslkey", dsn.SSLKey},
{"sslrootcert", dsn.SSLRootCert},
{"connect_timeout", connectTimeout},
} {
if len(param.v) == 0 {
continue
}
param.v = strings.ReplaceAll(param.v, "'", `\'`)
param.v = strings.ReplaceAll(param.v, " ", `\ `)
params = append(params, param.k+"="+param.v)
}
return strings.Join(params, " ")
}
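// As an illustrative sketch (hypothetical values), a DSN such as
//
//	DSN{Host: "db.internal", Port: 5432, User: "registry", DBName: "registry_production", SSLMode: "require"}
//
// is rendered by String as:
//
//	host=db.internal port=5432 user=registry dbname=registry_production sslmode=require
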
// Address returns the host:port segment of a DSN.
func (dsn *DSN) Address() string {
return net.JoinHostPort(dsn.Host, strconv.Itoa(dsn.Port))
}
type opts struct {
logger *logrus.Entry
logLevel tracelog.LogLevel
pool *PoolConfig
preferSimpleProtocol bool
loadBalancing *LoadBalancingConfig
metricsEnabled bool
promRegisterer prometheus.Registerer
}
type PoolConfig struct {
MaxIdle int
MaxOpen int
MaxLifetime time.Duration
MaxIdleTime time.Duration
}
// LoadBalancingConfig represents the database load balancing configuration.
type LoadBalancingConfig struct {
active bool
hosts []string
resolver DNSResolver
connector Connector
replicaCheckInterval time.Duration
lsnStore RepositoryCache
minResolveReplicasInterval time.Duration
}
// Option is used to configure the database connections.
type Option func(*opts)
// WithLogger configures the logger for the database connection driver.
func WithLogger(l *logrus.Entry) Option {
return func(opts *opts) {
opts.logger = l
}
}
// WithLogLevel configures the logger level for the database connection driver.
func WithLogLevel(l configuration.Loglevel) Option {
var lvl tracelog.LogLevel
switch l {
case configuration.LogLevelTrace:
lvl = tracelog.LogLevelTrace
case configuration.LogLevelDebug:
lvl = tracelog.LogLevelDebug
case configuration.LogLevelInfo:
lvl = tracelog.LogLevelInfo
case configuration.LogLevelWarn:
lvl = tracelog.LogLevelWarn
default:
lvl = tracelog.LogLevelError
}
return func(opts *opts) {
opts.logLevel = lvl
}
}
// WithPoolConfig configures the settings for the database connection pool.
func WithPoolConfig(c *PoolConfig) Option {
return func(opts *opts) {
opts.pool = c
}
}
// WithPoolMaxIdle configures the maximum number of idle pool connections.
func WithPoolMaxIdle(c int) Option {
return func(opts *opts) {
if opts.pool == nil {
opts.pool = &PoolConfig{}
}
opts.pool.MaxIdle = c
}
}
// WithPoolMaxOpen configures the maximum number of open pool connections.
func WithPoolMaxOpen(c int) Option {
return func(opts *opts) {
if opts.pool == nil {
opts.pool = &PoolConfig{}
}
opts.pool.MaxOpen = c
}
}
// WithPreparedStatements configures the settings to allow the database
// driver to use prepared statements.
func WithPreparedStatements(b bool) Option {
return func(opts *opts) {
// The registry configuration uses the opposite semantics to pgx for prepared statements.
opts.preferSimpleProtocol = !b
}
}
func applyOptions(input []Option) opts {
l := logrus.New()
l.SetOutput(io.Discard)
config := opts{
logger: logrus.NewEntry(l),
pool: &PoolConfig{},
loadBalancing: &LoadBalancingConfig{
connector: NewConnector(),
replicaCheckInterval: defaultReplicaCheckInterval,
},
promRegisterer: prometheus.DefaultRegisterer,
}
for _, v := range input {
v(&config)
}
return config
}
type logger struct {
*logrus.Entry
}
// logMinifyPattern is used to minify SQL statements in log entries by collapsing runs of spaces, tabs and newlines into a single space.
var logMinifyPattern = regexp.MustCompile(`\s+|\t+|\n+`)
// Log implements the tracelog.Logger interface.
func (l *logger) Log(_ context.Context, level tracelog.LogLevel, msg string, data map[string]any) {
// silence if debug level is not enabled, unless it's a warn or error
if !l.Logger.IsLevelEnabled(logrus.DebugLevel) && level != tracelog.LogLevelWarn && level != tracelog.LogLevelError {
return
}
var configuredLogger *logrus.Entry
if data != nil {
// minify SQL statement, if any
if _, ok := data["sql"]; ok {
raw := fmt.Sprintf("%v", data["sql"])
data["sql"] = logMinifyPattern.ReplaceAllString(raw, " ")
}
// use milliseconds for query duration
if _, ok := data["time"]; ok {
raw := fmt.Sprintf("%v", data["time"])
d, err := time.ParseDuration(raw)
if err == nil { // this should never happen, but let's make sure to avoid panics and missing log entries
data["duration_ms"] = d.Milliseconds()
delete(data, "time")
}
}
// convert known keys to snake_case notation for consistency
if _, ok := data["rowCount"]; ok {
data["row_count"] = data["rowCount"]
delete(data, "rowCount")
}
configuredLogger = l.WithFields(data)
} else {
configuredLogger = l.Entry
}
switch level {
case tracelog.LogLevelTrace:
configuredLogger.Trace(msg)
case tracelog.LogLevelDebug:
configuredLogger.Debug(msg)
case tracelog.LogLevelInfo:
configuredLogger.Info(msg)
case tracelog.LogLevelWarn:
configuredLogger.Warn(msg)
case tracelog.LogLevelError:
configuredLogger.Error(msg)
default:
// this should never happen, but if it does, something went wrong and we need to notice it
configuredLogger.WithField("invalid_log_level", level).Error(msg)
}
}
// Connector is an interface for opening database connections. This enables low-level testing of how connections are
// established during load balancing.
type Connector interface {
Open(ctx context.Context, dsn *DSN, opts ...Option) (*DB, error)
}
// sqlConnector is the default implementation of Connector using sql.Open.
type sqlConnector struct{}
// NewConnector creates a new sqlConnector.
func NewConnector() Connector {
return &sqlConnector{}
}
// Open opens a new database connection with the given DSN and options.
func (*sqlConnector) Open(ctx context.Context, dsn *DSN, opts ...Option) (*DB, error) {
config := applyOptions(opts)
pgxConfig, err := pgx.ParseConfig(dsn.String())
if err != nil {
return nil, fmt.Errorf("parsing connection string failed: %w", err)
}
pgxConfig.Tracer = &tracelog.TraceLog{
Logger: &logger{config.logger},
LogLevel: config.logLevel,
}
if config.preferSimpleProtocol {
// TODO: there are more query execution modes that we may want to consider in the future
// https://pkg.go.dev/github.com/jackc/pgx/v5#QueryExecMode
pgxConfig.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
}
connStr := stdlib.RegisterConnConfig(pgxConfig)
db, err := sql.Open(driverName, connStr)
if err != nil {
return nil, fmt.Errorf("open connection handle failed: %w", err)
}
db.SetMaxOpenConns(config.pool.MaxOpen)
db.SetMaxIdleConns(config.pool.MaxIdle)
db.SetConnMaxLifetime(config.pool.MaxLifetime)
db.SetConnMaxIdleTime(config.pool.MaxIdleTime)
if err := db.PingContext(ctx); err != nil {
return nil, fmt.Errorf("verification failed: %w", err)
}
return &DB{DB: db, DSN: dsn}, nil
}
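// Illustrative sketch of opening a standalone handler through the default Connector; the DSN values are
// hypothetical and logEntry stands for a caller-provided *logrus.Entry:
//
//	db, err := NewConnector().Open(ctx,
//		&DSN{Host: "db.internal", Port: 5432, User: "registry", DBName: "registry_production"},
//		WithLogger(logEntry),
//		WithLogLevel(configuration.LogLevelWarn),
//		WithPoolMaxOpen(10),
//		WithPreparedStatements(false),
//	)
//
// Open applies the pool settings and pings the target before returning, so a non-nil error here already
// implies a connectivity problem.
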
// LoadBalancer represents a database load balancer.
type LoadBalancer interface {
Primary() *DB
Replica(context.Context) *DB
UpToDateReplica(context.Context, *models.Repository) *DB
Replicas() []*DB
Close() error
RecordLSN(context.Context, *models.Repository) error
StartPoolRefresh(context.Context) error
StartLagCheck(context.Context) error
TypeOf(*DB) string
}
// DBLoadBalancer manages connections to a primary database and multiple replicas.
type DBLoadBalancer struct {
active bool
primary *DB
replicas []*DB
lsnCache RepositoryCache
connector Connector
resolver DNSResolver
fixedHosts []string
// replicaIndex and replicaMutex are used to implement a round-robin selection of replicas.
replicaIndex int
replicaMutex sync.Mutex
replicaOpenOpts []Option
replicaCheckInterval time.Duration
// primaryDSN is stored separately to ensure we can derive replica DSNs, even if the initial connection to the
// primary database fails. This is necessary as DB.DSN is only set after successfully establishing a connection.
primaryDSN *DSN
metricsEnabled bool
promRegisterer prometheus.Registerer
replicaPromCollectors map[string]prometheus.Collector
// For controlling replica liveness probing
livenessProber *LivenessProber
// For controlling concurrent replica pool resolution
throttledPoolResolver *ThrottledPoolResolver
// For tracking replication lag
lagTracker LagTracker
}
// WithFixedHosts configures the list of static hosts to use for read replicas during database load balancing.
func WithFixedHosts(hosts []string) Option {
return func(opts *opts) {
opts.loadBalancing.hosts = hosts
opts.loadBalancing.active = true
}
}
// WithServiceDiscovery enables and configures service discovery for read replicas during database load balancing.
func WithServiceDiscovery(resolver DNSResolver) Option {
return func(opts *opts) {
opts.loadBalancing.resolver = resolver
opts.loadBalancing.active = true
}
}
// WithConnector allows specifying a custom database Connector implementation to be used to establish connections,
// otherwise sql.Open is used.
func WithConnector(connector Connector) Option {
return func(opts *opts) {
opts.loadBalancing.connector = connector
}
}
// WithReplicaCheckInterval configures a custom refresh interval for the replica list when using service discovery.
// Defaults to 1 minute.
func WithReplicaCheckInterval(interval time.Duration) Option {
return func(opts *opts) {
opts.loadBalancing.replicaCheckInterval = interval
}
}
// WithLSNCache allows providing a RepositoryCache implementation to be used for recording WAL insert Log Sequence
// Numbers (LSNs) that are used to enable primary sticking during database load balancing.
func WithLSNCache(cache RepositoryCache) Option {
return func(opts *opts) {
opts.loadBalancing.lsnStore = cache
}
}
// WithMinResolveReplicasInterval configures the minimum time between resolve replicas calls. This prevents excessive
// replica resolution operations during periods of connection instability.
func WithMinResolveReplicasInterval(interval time.Duration) Option {
return func(opts *opts) {
opts.loadBalancing.minResolveReplicasInterval = interval
}
}
// WithMetricsCollection enables metrics collection.
func WithMetricsCollection() Option {
return func(opts *opts) {
opts.metricsEnabled = true
}
}
// WithPrometheusRegisterer allows specifying a custom Prometheus Registerer for metrics registration.
func WithPrometheusRegisterer(r prometheus.Registerer) Option {
return func(opts *opts) {
opts.promRegisterer = r
}
}
// DNSResolver is an interface for DNS resolution operations. This enables low-level testing of how connections are
// established during load balancing.
type DNSResolver interface {
// LookupSRV looks up SRV records.
LookupSRV(ctx context.Context) ([]*net.SRV, error)
// LookupHost looks up IP addresses for a given host.
LookupHost(ctx context.Context, host string) ([]string, error)
}
// dnsResolver is the default implementation of DNSResolver using net.Resolver.
type dnsResolver struct {
resolver *net.Resolver
record string
}
// LookupSRV performs an SRV record lookup.
func (r *dnsResolver) LookupSRV(ctx context.Context) ([]*net.SRV, error) {
report := metrics.SRVLookup()
_, addrs, err := r.resolver.LookupSRV(ctx, "", "", r.record)
report(err)
return addrs, err
}
// LookupHost performs an IP address lookup for the given host.
func (r *dnsResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
report := metrics.HostLookup()
addrs, err := r.resolver.LookupHost(ctx, host)
report(err)
return addrs, err
}
// NewDNSResolver creates a new dnsResolver for the specified nameserver, port, and record.
func NewDNSResolver(nameserver string, port int, record string) DNSResolver {
dialer := &net.Dialer{}
return &dnsResolver{
resolver: &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp", fmt.Sprintf("%s:%d", nameserver, port))
},
},
record: record,
}
}
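// Illustrative sketch (hypothetical nameserver and record): a resolver that discovers replicas via an SRV
// record served by a local Consul agent could be built as
//
//	resolver := NewDNSResolver("127.0.0.1", 8600, "db-replica.service.consul")
//
// and handed to NewDBLoadBalancer through WithServiceDiscovery(resolver).
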
func resolveHosts(ctx context.Context, resolver DNSResolver) ([]*net.TCPAddr, error) {
srvs, err := resolver.LookupSRV(ctx)
if err != nil {
return nil, fmt.Errorf("error resolving DNS SRV record: %w", err)
}
var result *multierror.Error
var addrs []*net.TCPAddr
for _, srv := range srvs {
// TODO: consider allowing partial successes where only a subset of replicas is reachable
ips, err := resolver.LookupHost(ctx, srv.Target)
if err != nil {
result = multierror.Append(result, fmt.Errorf("error resolving host %q address: %v", srv.Target, err))
continue
}
for _, ip := range ips {
addr := &net.TCPAddr{
IP: net.ParseIP(ip),
Port: int(srv.Port),
}
addrs = append(addrs, addr)
}
}
if result.ErrorOrNil() != nil {
return nil, result
}
return addrs, nil
}
// logger returns a log.Logger decorated with a key/value pair that uniquely identifies all entries as being emitted
// by the database load balancer component. Instead of relying on a fixed log.Logger instance, this method allows
// retrieving and extending a base logger embedded in the input context (if any) to preserve relevant key/value
// pairs introduced upstream (such as a correlation ID, present when calling from the API handlers).
func (*DBLoadBalancer) logger(ctx context.Context) log.Logger {
return log.GetLogger(log.WithContext(ctx)).WithFields(log.Fields{
"component": "registry.datastore.DBLoadBalancer",
})
}
func (lb *DBLoadBalancer) metricsCollector(db *DB, hostType string) *sqlmetrics.DBStatsCollector {
return sqlmetrics.NewDBStatsCollector(
lb.primaryDSN.DBName,
db,
sqlmetrics.WithExtraLabels(map[string]string{
"host_type": hostType,
"host_addr": db.Address(),
}),
)
}
// LivenessProber manages liveness probes for database hosts.
type LivenessProber struct {
sync.Mutex
inProgress map[string]time.Time // Maps host address to probe start time
minInterval time.Duration // Minimum time between probes for the same host
timeout time.Duration // Maximum time for a probe to run
onUnhealthy func(context.Context, *DB) // Callback for unhealthy hosts
}
// LivenessProberOption configures a LivenessProber.
type LivenessProberOption func(*LivenessProber)
// WithMinProbeInterval sets the minimum interval between probes for the same host.
func WithMinProbeInterval(interval time.Duration) LivenessProberOption {
return func(p *LivenessProber) {
p.minInterval = interval
}
}
// WithProbeTimeout sets the timeout for a probe operation.
func WithProbeTimeout(timeout time.Duration) LivenessProberOption {
return func(p *LivenessProber) {
p.timeout = timeout
}
}
// WithUnhealthyCallback sets the callback to be invoked when a host is determined to be unhealthy.
func WithUnhealthyCallback(callback func(context.Context, *DB)) LivenessProberOption {
return func(p *LivenessProber) {
p.onUnhealthy = callback
}
}
// NewLivenessProber creates a new LivenessProber with the given options.
func NewLivenessProber(opts ...LivenessProberOption) *LivenessProber {
prober := &LivenessProber{
inProgress: make(map[string]time.Time),
minInterval: minLivenessProbeInterval,
timeout: livenessProbeTimeout,
}
for _, opt := range opts {
opt(prober)
}
return prober
}
// BeginCheck checks if a probe is allowed for the given host. It returns false if another probe for this host is in
// progress or if a probe was completed too recently (within minInterval). If the probe is allowed, it
// marks the host as being probed and returns true.
func (p *LivenessProber) BeginCheck(hostAddr string) bool {
p.Lock()
defer p.Unlock()
lastProbe, exists := p.inProgress[hostAddr]
if exists && time.Since(lastProbe) < p.minInterval {
return false
}
p.inProgress[hostAddr] = time.Now()
return true
}
// EndCheck marks a probe as completed by removing its entry from the in-progress tracking map,
// allowing future probes for this host to proceed (subject to timing constraints).
func (p *LivenessProber) EndCheck(hostAddr string) {
p.Lock()
defer p.Unlock()
delete(p.inProgress, hostAddr)
}
// Probe performs a health check on a database host. The probe rate is limited to prevent excessive
// concurrent probing of the same host. If the probe fails, the onUnhealthy callback is invoked.
func (p *LivenessProber) Probe(ctx context.Context, db *DB) {
addr := db.Address()
l := log.GetLogger(log.WithContext(ctx)).WithFields(log.Fields{"db_host_addr": addr})
// Check if this host is already being probed and mark it if not
if !p.BeginCheck(addr) {
l.Info("skipping liveness probe, already in progress or too recent")
return
}
// When we're done with this function, remove the in-progress marker
defer p.EndCheck(addr)
// Perform a lightweight health check with a short timeout
probeCtx, cancel := context.WithTimeout(ctx, p.timeout)
defer cancel()
l.Info("performing liveness probe")
start := time.Now()
err := db.PingContext(probeCtx)
duration := time.Since(start).Seconds()
if err == nil {
l.WithFields(log.Fields{"duration_s": duration}).Info("host passed liveness probe")
return
}
// If we get here, the liveness probe failed
l.WithFields(log.Fields{"duration_s": duration}).WithError(err).Warn("host failed liveness probe; invoking callback")
if p.onUnhealthy != nil {
p.onUnhealthy(ctx, db)
}
}
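// Illustrative sketch of assembling a standalone prober; the callback shown only logs, whereas the load
// balancer wires in removeReplica (see NewDBLoadBalancer), and replicaDB stands for an existing *DB:
//
//	prober := NewLivenessProber(
//		WithMinProbeInterval(2*time.Second),
//		WithProbeTimeout(200*time.Millisecond),
//		WithUnhealthyCallback(func(ctx context.Context, db *DB) {
//			log.GetLogger(log.WithContext(ctx)).Warn("replica deemed unhealthy")
//		}),
//	)
//	go prober.Probe(ctx, replicaDB)
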
// ThrottledPoolResolver manages resolution of database replicas with throttling to prevent excessive operations.
type ThrottledPoolResolver struct {
sync.Mutex
inProgress bool
lastComplete time.Time
minInterval time.Duration
resolveFn func(context.Context) error
}
// ThrottledPoolResolverOption configures a ThrottledPoolResolver.
type ThrottledPoolResolverOption func(*ThrottledPoolResolver)
// WithMinInterval sets the minimum interval between replica resolutions.
func WithMinInterval(interval time.Duration) ThrottledPoolResolverOption {
return func(r *ThrottledPoolResolver) {
r.minInterval = interval
}
}
// WithResolveFunction sets the function that performs the actual resolution.
func WithResolveFunction(fn func(context.Context) error) ThrottledPoolResolverOption {
return func(r *ThrottledPoolResolver) {
r.resolveFn = fn
}
}
// NewThrottledPoolResolver creates a new ThrottledPoolResolver with the given options.
func NewThrottledPoolResolver(opts ...ThrottledPoolResolverOption) *ThrottledPoolResolver {
resolver := &ThrottledPoolResolver{
minInterval: minResolveReplicasInterval,
}
for _, opt := range opts {
opt(resolver)
}
return resolver
}
// Begin checks if a replica pool resolution operation is currently allowed. It returns false if another resolution is
// already in progress or if one was completed too recently (within minInterval). If resolution is allowed, it marks
// the operation as in progress and returns true.
func (r *ThrottledPoolResolver) Begin() bool {
r.Lock()
defer r.Unlock()
if r.inProgress || (time.Since(r.lastComplete) < r.minInterval) {
return false
}
r.inProgress = true
return true
}
// Complete marks a replica resolution operation as complete and updates the timestamp of the last completed operation
// to enforce the minimum interval between operations.
func (r *ThrottledPoolResolver) Complete() {
r.Lock()
defer r.Unlock()
r.inProgress = false
r.lastComplete = time.Now()
}
// Resolve triggers a resolution of the replica pool if allowed by throttling constraints.
// It returns true if the resolution was performed, false if it was skipped due to throttling.
func (r *ThrottledPoolResolver) Resolve(ctx context.Context) bool {
l := log.GetLogger(log.WithContext(ctx))
// Check if another resolution is already in progress or happened too recently
if !r.Begin() {
l.Info("skipping replica pool resolution, already in progress or too recent")
return false
}
// When we're done, mark the operation as complete
defer r.Complete()
// Perform the resolution
l.Info("resolving replicas")
if err := r.resolveFn(ctx); err != nil {
l.WithError(err).Error("failed to resolve replicas")
} else {
l.Info("successfully resolved replicas")
}
return true
}
// isConnectivityError determines if the given error is related to database connectivity issues. It checks for specific
// PostgreSQL error classes that may indicate severe connection problems:
// - Class 08: Connection Exception (connection failures, protocol violations)
// - Class 57: Operator Intervention (server shutdown, crash)
// - Class 53: Insufficient Resources (self-explanatory)
// - Class 58: System Error (errors external to PostgreSQL itself)
// - Class XX: Internal Error (self-explanatory)
// Note: We intentionally don't check for raw network timeout as those can lead to a high rate of false positives. The
// list of watched errors should be fine-tuned as we experience connectivity issues in production.
// See https://www.postgresql.org/docs/current/errcodes-appendix.html
func isConnectivityError(err error) bool {
if err == nil {
return false
}
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
if pgerrcode.IsConnectionException(pgErr.Code) ||
pgerrcode.IsOperatorIntervention(pgErr.Code) ||
pgerrcode.IsInsufficientResources(pgErr.Code) ||
pgerrcode.IsSystemError(pgErr.Code) ||
pgerrcode.IsInternalError(pgErr.Code) {
return true
}
}
return false
}
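// As an illustrative sketch, a wrapped Class 08 error is matched while a plain context error is not:
//
//	connErr := fmt.Errorf("query failed: %w", &pgconn.PgError{Code: pgerrcode.ConnectionFailure})
//	isConnectivityError(connErr)                  // true
//	isConnectivityError(context.DeadlineExceeded) // false
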
// ProcessQueryError handles database connectivity errors during query executions by triggering appropriate responses:
// - For primary database errors, it initiates a full replica resolution to refresh the pool
// - For replica database errors, it initiates an individual liveness probe that may retire the potentially faulty replica
func (lb *DBLoadBalancer) ProcessQueryError(ctx context.Context, db *DB, query string, err error) {
if err != nil && isConnectivityError(err) {
hostType := lb.TypeOf(db)
l := lb.logger(ctx).WithError(err).WithFields(log.Fields{
"db_host_type": hostType,
"db_host_addr": db.Address(),
"query": query,
})
switch hostType {
case HostTypePrimary:
// If the primary connection fails (a failover event is possible), proactively refresh all replicas
l.Warn("primary database connection error during query execution; initiating replica resolution")
go lb.throttledPoolResolver.Resolve(context.WithoutCancel(ctx)) // detach from outer context to avoid external cancellation
case HostTypeReplica:
// For a replica, run a liveness probe and retire it if necessary
l.Warn("replica database connection error during query execution; initiating liveness probe")
go lb.livenessProber.Probe(context.WithoutCancel(ctx), db)
default:
// This is not supposed to happen, log and report
err := fmt.Errorf("unknown database host type: %w", err)
l.Error(err)
errortracking.Capture(err, errortracking.WithContext(ctx), errortracking.WithStackTrace())
}
}
}
// unregisterReplicaMetricsCollector removes the Prometheus metrics collector associated with a database replica.
// This should be called when a replica is retired from the pool to ensure metrics are properly cleaned up.
// If metrics collection is disabled or if no collector exists for the given replica, this is a no-op.
func (lb *DBLoadBalancer) unregisterReplicaMetricsCollector(r *DB) {
if lb.metricsEnabled {
if collector, exists := lb.replicaPromCollectors[r.Address()]; exists {
lb.promRegisterer.Unregister(collector)
delete(lb.replicaPromCollectors, r.Address())
}
}
}
// ResolveReplicas initializes or updates the list of available replicas atomically by resolving the provided hosts
// either through service discovery or using a fixed hosts list. As a result, the load balancer replica pool will be
// up to date. Replicas to which we failed to establish a connection are not included in the pool.
func (lb *DBLoadBalancer) ResolveReplicas(ctx context.Context) error {
lb.replicaMutex.Lock()
defer lb.replicaMutex.Unlock()
ctx, cancel := context.WithTimeout(ctx, ReplicaResolveTimeout)
defer cancel()
var result *multierror.Error
l := lb.logger(ctx)
// Resolve replica DSNs
var resolvedDSNs []DSN
if lb.resolver != nil {
l.Info("resolving replicas with service discovery")
addrs, err := resolveHosts(ctx, lb.resolver)
if err != nil {
return fmt.Errorf("failed to resolve replica hosts: %w", err)
}
for _, addr := range addrs {
dsn := *lb.primaryDSN
dsn.Host = addr.IP.String()
dsn.Port = addr.Port
resolvedDSNs = append(resolvedDSNs, dsn)
}
} else if len(lb.fixedHosts) > 0 {
l.Info("resolving replicas with fixed hosts list")
for _, host := range lb.fixedHosts {
dsn := *lb.primaryDSN
dsn.Host = host
resolvedDSNs = append(resolvedDSNs, dsn)
}
}
// Open connections for _added_ replicas
var outputReplicas []*DB
var added, removed []string
for i := range resolvedDSNs {
var err error
dsn := &resolvedDSNs[i]
l = l.WithFields(logrus.Fields{"db_replica_addr": dsn.Address()})
r := dbByAddress(lb.replicas, dsn.Address())
if r != nil {
// check if connection to existing replica is still usable
if err := r.PingContext(ctx); err != nil {
l.WithError(err).Warn("replica is known but connection is stale, attempting to reconnect")
r, err = lb.connector.Open(ctx, dsn, lb.replicaOpenOpts...)
if err != nil {
result = multierror.Append(result, fmt.Errorf("reopening replica %q database connection: %w", dsn.Address(), err))
continue
}
} else {
l.Info("replica is known and healthy, reusing connection")
}
} else {
l.Info("replica is new, opening connection")
if r, err = lb.connector.Open(ctx, dsn, lb.replicaOpenOpts...); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to open replica %q database connection: %w", dsn.Address(), err))
continue
}
added = append(added, r.Address())
metrics.ReplicaAdded()
// Register metrics collector for the added replica
if lb.metricsEnabled {
collector := lb.metricsCollector(r, HostTypeReplica)
// Unlike the primary host metrics collector, replica collectors will be registered in the background
// whenever the pool changes. We don't want to cause a panic here, so we'll rely on prometheus.Register
// instead of prometheus.MustRegister and gracefully handle an error by logging and reporting it.
if err := lb.promRegisterer.Register(collector); err != nil {
l.WithError(err).WithFields(log.Fields{"db_replica_addr": r.Address()}).
Error("failed to register collector for database replica metrics")
errortracking.Capture(err, errortracking.WithContext(ctx), errortracking.WithStackTrace())
}
lb.replicaPromCollectors[r.Address()] = collector
}
}
r.errorProcessor = lb
outputReplicas = append(outputReplicas, r)
}
// Identify removed replicas
for _, r := range lb.replicas {
if dbByAddress(outputReplicas, r.Address()) == nil {
removed = append(removed, r.Address())
metrics.ReplicaRemoved()
// Unregister the metrics collector for the removed replica
lb.unregisterReplicaMetricsCollector(r)
// Close handlers for retired replicas
l.WithFields(log.Fields{"db_replica_addr": r.Address()}).Info("closing connection handler for retired replica")
if err := r.Close(); err != nil {
err = fmt.Errorf("failed to close retired replica %q connection: %w", r.Address(), err)
result = multierror.Append(result, err)
errortracking.Capture(err, errortracking.WithContext(ctx), errortracking.WithStackTrace())
}
}
}
l.WithFields(logrus.Fields{
"added_hosts": strings.Join(added, ","),
"removed_hosts": strings.Join(removed, ","),
}).Info("updating replicas list")
metrics.ReplicaPoolSize(len(outputReplicas))
lb.replicas = outputReplicas
return result.ErrorOrNil()
}
func dbByAddress(dbs []*DB, addr string) *DB {
for _, r := range dbs {
if r.Address() == addr {
return r
}
}
return nil
}
// StartPoolRefresh synchronously refreshes the list of replica servers at the configured interval.
func (lb *DBLoadBalancer) StartPoolRefresh(ctx context.Context) error {
// If the check interval was set to zero (no recurring checks) or the resolver is not set (service discovery
// was not enabled), then exit early as there is nothing to do
if lb.replicaCheckInterval == 0 || lb.resolver == nil {
return nil
}
l := lb.logger(ctx)
t := time.NewTicker(lb.replicaCheckInterval)
defer t.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-t.C:
l.WithFields(log.Fields{"interval_ms": lb.replicaCheckInterval.Milliseconds()}).
Info("scheduled refresh of replicas list")
lb.throttledPoolResolver.Resolve(ctx)
}
}
}
// StartLagCheck periodically checks replication lag for all replicas at the configured interval. Like
// StartPoolRefresh, it blocks until the context is canceled and is intended to be run in its own goroutine.
func (lb *DBLoadBalancer) StartLagCheck(ctx context.Context) error {
// If the check interval was set to zero (no recurring checks) or there are no replicas to check,
// then exit early as there is nothing to do
if lb.replicaCheckInterval == 0 || len(lb.replicas) == 0 {
return nil
}
l := lb.logger(ctx)
t := time.NewTicker(lb.replicaCheckInterval)
defer t.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-t.C:
l.Info("checking replication lag for all replicas")
// Get current primary LSN
primaryLSN, err := lb.primaryLSN(ctx)
if err != nil {
l.WithError(err).Error("failed to query primary LSN")
continue
}
// Check lag for each replica
lb.replicaMutex.Lock()
replicas := lb.replicas
lb.replicaMutex.Unlock()
for _, replica := range replicas {
replicaAddr := replica.Address()
l = l.WithFields(log.Fields{"replica_addr": replicaAddr})
if err := lb.lagTracker.Check(ctx, primaryLSN, replica); err != nil {
l.WithError(err).Error("failed to check database replica lag")
}
}
}
}
}
// removeReplica removes a replica from the pool and closes its connection.
func (lb *DBLoadBalancer) removeReplica(ctx context.Context, r *DB) {
replicaAddr := r.Address()
l := lb.logger(ctx).WithFields(log.Fields{"db_replica_addr": replicaAddr})
lb.replicaMutex.Lock()
defer lb.replicaMutex.Unlock()
for i, replica := range lb.replicas {
if replica.Address() == replicaAddr {
l.Warn("removing replica from pool")
lb.replicas = append(lb.replicas[:i], lb.replicas[i+1:]...)
lb.unregisterReplicaMetricsCollector(r)
metrics.ReplicaRemoved()
metrics.ReplicaPoolSize(len(lb.replicas))
if err := r.Close(); err != nil {
l.WithError(err).Error("error closing retired replica connection")
}
break
}
}
}
// NewDBLoadBalancer initializes a DBLoadBalancer with primary and replica connections. An error is returned if the
// connection to the primary server fails. Failures to connect to replica server(s) are handled gracefully, that is,
// logged, reported and ignored. This prevents halting the app start, as it can function with the primary server only.
// DBLoadBalancer.StartPoolRefresh can be used to periodically refresh the list of replicas, potentially leading to
// the self-healing of transient connection failures encountered during this initialization.
func NewDBLoadBalancer(ctx context.Context, primaryDSN *DSN, opts ...Option) (*DBLoadBalancer, error) {
config := applyOptions(opts)
lb := &DBLoadBalancer{
active: config.loadBalancing.active,
primaryDSN: primaryDSN,
connector: config.loadBalancing.connector,
resolver: config.loadBalancing.resolver,
fixedHosts: config.loadBalancing.hosts,
replicaOpenOpts: opts,
replicaCheckInterval: config.loadBalancing.replicaCheckInterval,
lsnCache: config.loadBalancing.lsnStore,
metricsEnabled: config.metricsEnabled,
promRegisterer: config.promRegisterer,
replicaPromCollectors: make(map[string]prometheus.Collector),
}
// Initialize the replicas liveness prober with a callback to retire unhealthy hosts
lb.livenessProber = NewLivenessProber(WithUnhealthyCallback(lb.removeReplica))
// Initialize the throttled replica pool resolver
lb.throttledPoolResolver = NewThrottledPoolResolver(
WithMinInterval(minResolveReplicasInterval),
WithResolveFunction(lb.ResolveReplicas),
)
// Initialize the replica lag tracker using the same interval as replica checking
lb.lagTracker = NewReplicaLagTracker(
WithLagCheckInterval(config.loadBalancing.replicaCheckInterval),
)
primary, err := lb.connector.Open(ctx, primaryDSN, opts...)
if err != nil {
return nil, fmt.Errorf("failed to open primary database connection: %w", err)
}
primary.errorProcessor = lb
lb.primary = primary
// Conditionally register metrics for the primary database handle
if lb.metricsEnabled {
lb.promRegisterer.MustRegister(lb.metricsCollector(primary, HostTypePrimary))
}
if lb.active {
ctx, cancel := context.WithTimeout(ctx, InitReplicaResolveTimeout)
defer cancel()
if err := lb.ResolveReplicas(ctx); err != nil {
lb.logger(ctx).WithError(err).Error("failed to resolve database load balancing replicas")
errortracking.Capture(err, errortracking.WithContext(ctx), errortracking.WithStackTrace())
}
}
return lb, nil
}
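// Illustrative sketch of wiring a load balancer with service discovery; the nameserver, record and lsnCache
// value are hypothetical:
//
//	resolver := NewDNSResolver("127.0.0.1", 8600, "db-replica.service.consul")
//	lb, err := NewDBLoadBalancer(ctx, primaryDSN,
//		WithServiceDiscovery(resolver),
//		WithReplicaCheckInterval(30*time.Second),
//		WithLSNCache(lsnCache),
//		WithMetricsCollection(),
//	)
//	if err != nil {
//		return err
//	}
//	go func() { _ = lb.StartPoolRefresh(ctx) }()
//	go func() { _ = lb.StartLagCheck(ctx) }()
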
// Primary returns the primary database handler.
func (lb *DBLoadBalancer) Primary() *DB {
return lb.primary
}
// Replica returns a round-robin elected replica database handler. If no replicas are configured, then the primary
// database handler is returned.
func (lb *DBLoadBalancer) Replica(ctx context.Context) *DB {
lb.replicaMutex.Lock()
defer lb.replicaMutex.Unlock()
if len(lb.replicas) == 0 {
lb.logger(ctx).Info("no replicas available, falling back to primary")
return lb.primary
}
replica := lb.replicas[lb.replicaIndex]
lb.replicaIndex = (lb.replicaIndex + 1) % len(lb.replicas)
return replica
}
// Replicas returns all replica database handlers currently in the pool.
func (lb *DBLoadBalancer) Replicas() []*DB {
return lb.replicas
}
// Close closes all database connections managed by the DBLoadBalancer.
func (lb *DBLoadBalancer) Close() error {
var result *multierror.Error
if err := lb.primary.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed closing primary connection: %w", err))
}
for _, replica := range lb.replicas {
if err := replica.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed closing replica %q connection: %w", replica.Address(), err))
}
}
return result.ErrorOrNil()
}
// RecordLSN queries the current primary database WAL insert Log Sequence Number (LSN) and records it in the LSN cache
// in association with a given models.Repository.
// See https://gitlab.com/gitlab-org/container-registry/-/blob/master/docs/spec/gitlab/database-load-balancing.md?ref_type=heads#primary-sticking
func (lb *DBLoadBalancer) RecordLSN(ctx context.Context, r *models.Repository) error {
if lb.lsnCache == nil {
return fmt.Errorf("LSN cache is not configured")
}
lsn, err := lb.primaryLSN(ctx)
if err != nil {
return fmt.Errorf("failed to query current WAL insert LSN: %w", err)
}
if err := lb.lsnCache.SetLSN(ctx, r, lsn); err != nil {
return fmt.Errorf("failed to cache WAL insert LSN: %w", err)
}
lb.logger(ctx).WithFields(log.Fields{"repository": r.Path, "lsn": lsn}).Info("current WAL insert LSN recorded")
return nil
}
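// Illustrative sketch of pairing RecordLSN with UpToDateReplica (repo is a hypothetical *models.Repository
// and the query is made up for the example): after a write against Primary(), the caller records the LSN so
// that subsequent reads only land on replicas that have replayed it:
//
//	if err := lb.RecordLSN(ctx, repo); err != nil {
//		return err
//	}
//	db := lb.UpToDateReplica(ctx, repo)
//	var id int64
//	err := db.QueryRowContext(ctx, "SELECT id FROM repositories WHERE path = $1", repo.Path).Scan(&id)
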
// UpToDateReplica returns the most suitable database connection handle for serving a read request for a given
// models.Repository based on the last recorded primary Log Sequence Number (LSN) for that same repository. All errors
// during this method execution are handled gracefully with a fallback to the primary connection handle.
// Relevant errors during query execution should be handled _explicitly_ by the caller. For example, if the caller
// obtained a connection handle for replica `R` at time `T`, and `R` is retired from the load balancer pool at `T+1`,
// then any queries attempted against `R` after `T+1` would result in a `sql: database is closed` error raised by the
// `database/sql` package. In such case, it is the caller's responsibility to fall back to Primary and retry.
func (lb *DBLoadBalancer) UpToDateReplica(ctx context.Context, r *models.Repository) *DB {
primary := lb.primary
primary.errorProcessor = lb
if !lb.active {
return lb.primary
}
l := lb.logger(ctx).WithFields(log.Fields{"repository": r.Path})
if lb.lsnCache == nil {
l.Info("no LSN cache configured, falling back to primary")
metrics.PrimaryFallbackNoCache()
return lb.primary
}
// Do not let the LSN cache lookup and subsequent DB comparison (total) take more than upToDateReplicaTimeout,
// effectively enforcing a graceful fallback to the primary database if so.
ctx, cancel := context.WithTimeout(ctx, upToDateReplicaTimeout)
defer cancel()
// Get the next replica using round-robin. For simplicity, on the first iteration of DLB, we only check against
// the first returned replica candidate, not against (potentially) all replicas in the pool. If the elected replica
// candidate is behind the previously recorded primary LSN, then we fall back to the connection handle for
// the primary database.
replica := lb.Replica(ctx)
if replica == lb.primary {
metrics.PrimaryFallbackNoReplica()
return lb.primary
}
l = l.WithFields(log.Fields{"db_replica_addr": replica.Address()})
// Fetch the primary LSN from cache
primaryLSN, err := lb.lsnCache.GetLSN(ctx, r)
if err != nil {
l.WithError(err).Error("failed to fetch primary LSN from cache, falling back to primary")
metrics.PrimaryFallbackError()
return lb.primary
}
// If the record does not exist in cache, the replica is considered suitable
if primaryLSN == "" {
metrics.LSNCacheMiss()
l.Info("no primary LSN found in cache, replica is eligible")
metrics.ReplicaTarget()
return replica
}
metrics.LSNCacheHit()
l = l.WithFields(log.Fields{"primary_lsn": primaryLSN})
// Query to check if the candidate replica is up-to-date with the primary LSN
defer metrics.InstrumentQuery("lb_replica_up_to_date")()
query := `
WITH replica_lsn AS (
SELECT pg_last_wal_replay_lsn () AS lsn
)
SELECT
pg_wal_lsn_diff ($1::pg_lsn, lsn) <= 0
FROM
replica_lsn`
var upToDate bool
if err := replica.QueryRowContext(ctx, query, primaryLSN).Scan(&upToDate); err != nil {
l.WithError(err).Error("failed to calculate LSN diff, falling back to primary")
metrics.PrimaryFallbackError()
return lb.primary
}
if upToDate {
l.Info("replica is up-to-date")
metrics.ReplicaTarget()
return replica
}
l.Info("replica is not up-to-date, falling back to primary")
metrics.PrimaryFallbackNotUpToDate()
return lb.primary
}
// TypeOf returns the type of the provided *DB instance: HostTypePrimary, HostTypeReplica or HostTypeUnknown.
func (lb *DBLoadBalancer) TypeOf(db *DB) string {
lb.replicaMutex.Lock()
defer lb.replicaMutex.Unlock()
if db == lb.primary {
return HostTypePrimary
}
for _, replica := range lb.replicas {
if db == replica {
return HostTypeReplica
}
}
return HostTypeUnknown
}
// primaryLSN returns the primary database's current write location
func (lb *DBLoadBalancer) primaryLSN(ctx context.Context) (string, error) {
defer metrics.InstrumentQuery("lb_primary_lsn")()
var lsn string
query := "SELECT pg_current_wal_insert_lsn()::text AS location"
if err := lb.primary.QueryRowContext(ctx, query).Scan(&lsn); err != nil {
return "", err
}
return lsn, nil
}
// QueryBuilder helps in building SQL queries with parameters.
type QueryBuilder struct {
sql strings.Builder
params []any
newLine bool
}
func NewQueryBuilder() *QueryBuilder {
return &QueryBuilder{
params: make([]any, 0),
}
}
// Build takes the given SQL string, replaces any ? with the equivalent indexed
// parameter, and appends the given arguments to the params slice.
func (qb *QueryBuilder) Build(q string, qArgs ...any) error {
placeholderCount := strings.Count(q, "?")
if placeholderCount != len(qArgs) {
return fmt.Errorf(
"number of placeholders (%d) in query %q does not match the number of arguments (%d) passed",
placeholderCount, q, len(qArgs),
)
}
if q == "" {
return nil
}
for _, elem := range qArgs {
qb.params = append(qb.params, elem)
paramName := fmt.Sprintf("$%d", len(qb.params))
q = strings.Replace(q, "?", paramName, 1)
}
q = strings.Trim(q, " \t")
newLine := q[len(q)-1] == '\n'
// If the query ends in a newline, don't add a space before it. Adding a
// space is a convenience for chaining expressions.
switch {
case qb.newLine, qb.sql.Len() == 0, q == "\n":
_, _ = qb.sql.WriteString(q)
default:
_, _ = fmt.Fprintf(&qb.sql, " %s", q)
}
qb.newLine = newLine
return nil
}
// WrapIntoSubqueryOf wraps existing query as a subquery of the given query.
// The outerQuery param needs to have a single %s where the current query will
// be copied into.
func (qb *QueryBuilder) WrapIntoSubqueryOf(outerQuery string) error {
if !strings.Contains(outerQuery, "%s") || strings.Count(outerQuery, "%s") != 1 {
return fmt.Errorf("outerQuery must contain exactly one %%s placeholder. Query: %v", outerQuery)
}
newSQL := strings.Builder{}
_, _ = fmt.Fprintf(&newSQL, outerQuery, qb.sql.String())
qb.sql = newSQL
return nil
}
// SQL returns the rendered SQL query.
func (qb *QueryBuilder) SQL() string {
if qb.newLine {
return strings.TrimRight(qb.sql.String(), "\n")
}
return qb.sql.String()
}
// Params returns the slice of query literals to be used in the SQL query.
func (qb *QueryBuilder) Params() []any {
ret := make([]any, len(qb.params))
copy(ret, qb.params)
return ret
}
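// Illustrative sketch of composing a query with the builder; the table, filter and variables are hypothetical:
//
//	qb := NewQueryBuilder()
//	_ = qb.Build("SELECT id, name FROM repositories")
//	_ = qb.Build("WHERE parent_id = ? AND name LIKE ?", parentID, pattern)
//	rows, err := db.QueryContext(ctx, qb.SQL(), qb.Params()...)
//
// After both Build calls, SQL returns "SELECT id, name FROM repositories WHERE parent_id = $1 AND name LIKE $2"
// and Params returns []any{parentID, pattern}.
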
// IsInRecovery checks whether the provided database is in recovery (read-only standby) mode.
func IsInRecovery(ctx context.Context, db *DB) (bool, error) {
var inRecovery bool
// https://www.postgresql.org/docs/9.0/functions-admin.html#:~:text=Table%209%2D58.%20Recovery%20Information%20Functions
query := `SELECT pg_is_in_recovery()`
err := db.QueryRowContext(ctx, query).Scan(&inRecovery)
return inRecovery, err
}
// ReplicaLagInfo stores lag information for a replica
type ReplicaLagInfo struct {
Address string
TimeLag time.Duration
BytesLag int64
LastChecked time.Time
}
// LagTracker represents a component that can track database replication lag.
type LagTracker interface {
Check(ctx context.Context, primaryLSN string, replica *DB) error
}
// ReplicaLagTracker manages replication lag tracking
type ReplicaLagTracker struct {
sync.Mutex
lagInfo map[string]*ReplicaLagInfo
checkInterval time.Duration
}
// ReplicaLagTrackerOption configures a ReplicaLagTracker.
type ReplicaLagTrackerOption func(*ReplicaLagTracker)
// WithLagCheckInterval sets the interval for checking replication lag.
func WithLagCheckInterval(interval time.Duration) ReplicaLagTrackerOption {
return func(t *ReplicaLagTracker) {
if interval > 0 {
t.checkInterval = interval
}
}
}
// NewReplicaLagTracker creates a new ReplicaLagTracker with the given options.
func NewReplicaLagTracker(opts ...ReplicaLagTrackerOption) *ReplicaLagTracker {
tracker := &ReplicaLagTracker{
lagInfo: make(map[string]*ReplicaLagInfo),
checkInterval: defaultReplicaCheckInterval,
}
for _, opt := range opts {
opt(tracker)
}
return tracker
}
// Get returns the replication lag info for a replica
func (t *ReplicaLagTracker) Get(replicaAddr string) *ReplicaLagInfo {
t.Lock()
defer t.Unlock()
info, exists := t.lagInfo[replicaAddr]
if !exists {
return nil
}
// Return a copy to avoid race conditions
lagInfo := *info
return &lagInfo
}
// set updates the replication lag information for a replica
func (t *ReplicaLagTracker) set(_ context.Context, db *DB, timeLag time.Duration, bytesLag int64) {
t.Lock()
defer t.Unlock()
addr := db.Address()
now := time.Now()
info, exists := t.lagInfo[addr]
if !exists {
info = &ReplicaLagInfo{
Address: addr,
}
t.lagInfo[addr] = info
}
info.TimeLag = timeLag
info.BytesLag = bytesLag
info.LastChecked = now
metrics.ReplicaLagBytes(addr, float64(bytesLag))
metrics.ReplicaLagSeconds(addr, timeLag.Seconds())
}
// CheckBytesLag retrieves the data-based replication lag for a replica
func (*ReplicaLagTracker) CheckBytesLag(ctx context.Context, primaryLSN string, replica *DB) (int64, error) {
defer metrics.InstrumentQuery("lb_check_bytes_lag")()
queryCtx, cancel := context.WithTimeout(ctx, replicaLagCheckTimeout)
defer cancel()
// Calculate bytes lag on the replica using the provided primary LSN
var bytesLag int64
query := `SELECT pg_wal_lsn_diff($1, pg_last_wal_replay_lsn())::bigint AS diff`
err := replica.QueryRowContext(queryCtx, query, primaryLSN).Scan(&bytesLag)
if err != nil {
return 0, fmt.Errorf("failed to calculate replica bytes lag: %w", err)
}
return bytesLag, nil
}
// CheckTimeLag retrieves the time-based replication lag for a replica
func (*ReplicaLagTracker) CheckTimeLag(ctx context.Context, replica *DB) (time.Duration, error) {
defer metrics.InstrumentQuery("lb_check_time_lag")()
queryCtx, cancel := context.WithTimeout(ctx, replicaLagCheckTimeout)
defer cancel()
query := `SELECT EXTRACT(EPOCH FROM (now() - pg_last_xact_replay_timestamp()))::float AS lag`
var timeLagSeconds float64
err := replica.QueryRowContext(queryCtx, query).Scan(&timeLagSeconds)
if err != nil {
return 0, fmt.Errorf("failed to check replica time lag: %w", err)
}
return time.Duration(timeLagSeconds * float64(time.Second)), nil
}
// Check checks replication lag for a specific replica and stores it.
func (t *ReplicaLagTracker) Check(ctx context.Context, primaryLSN string, replica *DB) error {
l := log.GetLogger(log.WithContext(ctx)).WithFields(log.Fields{
"db_replica_addr": replica.Address(),
})
timeLag, err := t.CheckTimeLag(ctx, replica)
if err != nil {
l.WithError(err).Error("failed to check time-based replication lag")
return err
}
bytesLag, err := t.CheckBytesLag(ctx, primaryLSN, replica)
if err != nil {
l.WithError(err).Error("failed to check data-based replication lag")
return err
}
// Log at appropriate level based on max thresholds
l = l.WithFields(log.Fields{"lag_time_s": timeLag.Seconds(), "lag_bytes": bytesLag})
if timeLag > MaxReplicaLagTime {
l.Warn("replica time-based replication lag above max threshold")
}
if bytesLag > MaxReplicaLagBytes {
l.Warn("replica data-based replication lag above max threshold")
}
if timeLag <= MaxReplicaLagTime && bytesLag <= MaxReplicaLagBytes {
l.Info("replica replication lag below max thresholds")
}
t.set(ctx, replica, timeLag, bytesLag)
return nil
}