monitored_cache.go (107 lines of code) (raw):

// Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package cloudsqlconn import ( "context" "sync" "sync/atomic" "time" "cloud.google.com/go/cloudsqlconn/debug" "cloud.google.com/go/cloudsqlconn/instance" ) // monitoredCache is a wrapper around a connectionInfoCache that tracks the // number of connections to the associated instance. type monitoredCache struct { openConnsCount *uint64 cn instance.ConnName resolver instance.ConnectionNameResolver logger debug.ContextLogger // domainNameTicker periodically checks any domain names to see if they // changed. domainNameTicker *time.Ticker closedCh chan struct{} mu sync.Mutex openConns []*instrumentedConn closed bool connectionInfoCache } func newMonitoredCache( ctx context.Context, cache connectionInfoCache, cn instance.ConnName, failoverPeriod time.Duration, resolver instance.ConnectionNameResolver, logger debug.ContextLogger) *monitoredCache { c := &monitoredCache{ openConnsCount: new(uint64), closedCh: make(chan struct{}), cn: cn, resolver: resolver, logger: logger, connectionInfoCache: cache, } if cn.HasDomainName() { c.domainNameTicker = time.NewTicker(failoverPeriod) go func() { for { select { case <-c.domainNameTicker.C: c.purgeClosedConns() c.checkDomainName(ctx) case <-c.closedCh: return } } }() } return c } func (c *monitoredCache) isClosed() bool { c.mu.Lock() defer c.mu.Unlock() return c.closed } func (c *monitoredCache) Close() error { c.mu.Lock() defer c.mu.Unlock() if c.closed { return nil } c.closed = true close(c.closedCh) if c.domainNameTicker != nil { c.domainNameTicker.Stop() } if atomic.LoadUint64(c.openConnsCount) > 0 { for _, socket := range c.openConns { if !socket.isClosed() { _ = socket.Close() // force socket closed, ok to ignore error. } } atomic.StoreUint64(c.openConnsCount, 0) } return c.connectionInfoCache.Close() } func (c *monitoredCache) purgeClosedConns() { c.mu.Lock() defer c.mu.Unlock() var open []*instrumentedConn for _, s := range c.openConns { if !s.isClosed() { open = append(open, s) } } c.openConns = open } func (c *monitoredCache) checkDomainName(ctx context.Context) { if !c.cn.HasDomainName() { return } newCn, err := c.resolver.Resolve(ctx, c.cn.DomainName()) if err != nil { // The domain name could not be resolved. c.logger.Debugf(ctx, "domain name %s for instance %s did not resolve, "+ "closing all connections: %v", c.cn.DomainName(), c.cn.Name(), err) c.Close() } if newCn != c.cn { // The instance changed. c.logger.Debugf(ctx, "domain name %s changed from %s to %s, "+ "closing all connections.", c.cn.DomainName(), c.cn.Name(), newCn.Name()) c.Close() } }