dialer.go (556 lines of code) (raw):

// Copyright 2020 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package alloydbconn import ( "context" "crypto/rand" "crypto/rsa" "crypto/tls" _ "embed" "encoding/binary" "errors" "fmt" "io" "net" "strings" "sync" "sync/atomic" "time" "cloud.google.com/go/alloydb/connectors/apiv1alpha/connectorspb" "cloud.google.com/go/alloydbconn/debug" "cloud.google.com/go/alloydbconn/errtype" "cloud.google.com/go/alloydbconn/internal/alloydb" "cloud.google.com/go/alloydbconn/internal/tel" "github.com/google/uuid" "golang.org/x/net/proxy" "golang.org/x/oauth2" "golang.org/x/oauth2/google" "google.golang.org/api/option" "google.golang.org/protobuf/proto" alloydbadmin "cloud.google.com/go/alloydb/apiv1alpha" telv2 "cloud.google.com/go/alloydbconn/internal/tel/v2" ) const ( // defaultTCPKeepAlive is the default keep alive value used on connections // to a AlloyDB instance defaultTCPKeepAlive = 30 * time.Second // serverProxyPort is the port the server-side proxy receives connections on. serverProxyPort = "5433" // ioTimeout is the maximum amount of time to wait before aborting a // metadata exhange ioTimeout = 30 * time.Second // metricShutdownTimeout is the maximum amount of time to wait to flush any // remaining metrics when the dialer closes. metricShutdownTimeout = 3 * time.Second ) var ( // ErrDialerClosed is used when a caller invokes Dial after closing the // Dialer. 
ErrDialerClosed = errors.New("alloydbconn: dialer is closed") // versionString indicates the version of this library. //go:embed version.txt versionString string userAgent = "alloydb-go-connector/" + strings.TrimSpace(versionString) ) // keyGenerator encapsulates the details of RSA key generation to provide lazy // generation, custom keys, or a default RSA generator. type keyGenerator struct { once sync.Once key *rsa.PrivateKey err error genFunc func() (*rsa.PrivateKey, error) } // newKeyGenerator initializes a keyGenerator that will (in order): // - always return the RSA key if one is provided, or // - generate an RSA key lazily when it's requested, or // - (default) immediately generate an RSA key as part of the initializer. func newKeyGenerator( k *rsa.PrivateKey, lazy bool, genFunc func() (*rsa.PrivateKey, error), ) (*keyGenerator, error) { g := &keyGenerator{genFunc: genFunc} switch { case k != nil: // If the caller has provided a key, initialize the key and consume the // sync.Once now. g.once.Do(func() { g.key, g.err = k, nil }) case lazy: // If lazy refresh is enabled, do nothing and wait for the call to // rsaKey. default: // If no key has been provided and lazy refresh isn't enabled, generate // the key and consume the sync.Once now. g.once.Do(func() { g.key, g.err = g.genFunc() }) } return g, g.err } // rsaKey will generate an RSA key if one is not already cached. Otherwise, it // will return the cached key. func (g *keyGenerator) rsaKey() (*rsa.PrivateKey, error) { g.once.Do(func() { g.key, g.err = g.genFunc() }) return g.key, g.err } type connectionInfoCache interface { ConnectionInfo(context.Context) (alloydb.ConnectionInfo, error) ForceRefresh() io.Closer } // monitoredCache is a wrapper around a connectionInfoCache that tracks the // number of connections to the associated instance. type monitoredCache struct { openConns *uint64 connectionInfoCache } // A Dialer is used to create connections to AlloyDB instance. 
//
// Use NewDialer to initialize a Dialer.
type Dialer struct {
	// lock guards cache.
	lock sync.RWMutex

	// cache holds one monitored connection info cache per instance URI.
	// Entries are created on demand by connectionInfoCache and deleted by
	// removeCached.
	cache map[alloydb.InstanceURI]monitoredCache

	// keyGenerator supplies the RSA key passed to newly created caches.
	keyGenerator *keyGenerator

	// refreshTimeout is passed to each lazy or refresh ahead cache when it is
	// created.
	refreshTimeout time.Duration

	// closed reports if the dialer has been closed. The channel is closed by
	// Close; Dial returns ErrDialerClosed once that has happened.
	closed chan struct{}

	// lazyRefresh determines what kind of caching is used for ephemeral
	// certificates. When lazyRefresh is true, the dialer will use a lazy
	// cache, refreshing certificates only when a connection attempt needs a
	// fresh certificate. Otherwise, a refresh ahead cache will be used. The
	// refresh ahead cache assumes a background goroutine may run consistently.
	lazyRefresh bool

	// disableMetadataExchange is a temporary addition to help clients who
	// cannot use the metadata exchange yet. In future versions, this field
	// should be removed.
	disableMetadataExchange bool

	// disableBuiltInMetrics turns the internal metric export into a no-op.
	disableBuiltInMetrics bool

	// staticConnInfo, when non-nil and lazyRefresh is false, is the source
	// handed to the static connection info cache instead of using a
	// refresh ahead cache.
	staticConnInfo io.Reader

	// client is the AlloyDB Admin API client passed to the lazy and refresh
	// ahead caches.
	client *alloydbadmin.AlloyDBAdminClient

	// clientOpts are options for all Google Cloud API clients. There should be
	// no AlloyDB-specific configuration in these options.
	clientOpts []option.ClientOption

	// logger receives debug output from the Dialer and is passed on to the
	// caches and metric recorders it creates.
	logger debug.ContextLogger

	// defaultDialCfg holds the constructor level DialOptions, so that it can
	// be copied and mutated by the Dial function.
	defaultDialCfg dialCfg

	// dialerID uniquely identifies a Dialer. Used for monitoring purposes,
	// *only* when a client has configured OpenCensus exporters.
	dialerID string

	// metricsMu guards metricRecorders.
	metricsMu sync.Mutex

	// metricRecorders holds the lazily created metric recorder for each
	// instance URI.
	metricRecorders map[alloydb.InstanceURI]telv2.MetricRecorder

	// dialFunc is the function used to connect to the address on the named
	// network. By default it is golang.org/x/net/proxy#Dial.
	dialFunc func(cxt context.Context, network, addr string) (net.Conn, error)

	// useIAMAuthN selects the AUTO_IAM authentication type in the metadata
	// exchange; otherwise DB_NATIVE is requested.
	useIAMAuthN bool

	// iamTokenSource provides the OAuth2 token sent in the metadata exchange.
	iamTokenSource oauth2.TokenSource

	// userAgent is the space-joined user agent string sent in the metadata
	// exchange and attached to metrics.
	userAgent string

	// buffer is a pool of scratch byte slices used by the metadata exchange.
	buffer *buffer
}

// nullLogger is the default logger; it discards every message.
type nullLogger struct{}

// Debugf implements debug.ContextLogger by doing nothing.
func (nullLogger) Debugf(context.Context, string, ...any) {}

// NewDialer creates a new Dialer.
// // Initial calls to NewDialer make take longer than normal because generation of an // RSA keypair is performed. Calls with a WithRSAKeyPair DialOption or after a default // RSA keypair is generated will be faster. func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { cfg := &dialerConfig{ refreshTimeout: alloydb.RefreshTimeout, dialFunc: proxy.Dial, logger: nullLogger{}, userAgents: []string{userAgent}, } for _, opt := range opts { opt(cfg) if cfg.err != nil { return nil, cfg.err } } if cfg.disableMetadataExchange && cfg.useIAMAuthN { return nil, errors.New("incompatible options: WithOptOutOfAdvancedConnection " + "check cannot be used with WithIAMAuthN") } userAgent := strings.Join(cfg.userAgents, " ") // Add user agent to the end to make sure it's not overridden. cfg.clientOpts = append(cfg.clientOpts, option.WithUserAgent(userAgent)) // If no token source is configured, use ADC's token source. ts := cfg.tokenSource if ts == nil { var err error ts, err = google.DefaultTokenSource(ctx, CloudPlatformScope) if err != nil { return nil, err } } cOpts := append(cfg.alloydbClientOpts, cfg.clientOpts...) client, err := alloydbadmin.NewAlloyDBAdminRESTClient(ctx, cOpts...) 
if err != nil { return nil, fmt.Errorf("failed to create AlloyDB Admin API client: %v", err) } dialCfg := dialCfg{ ipType: alloydb.PrivateIP, tcpKeepAlive: defaultTCPKeepAlive, } for _, opt := range cfg.dialOpts { opt(&dialCfg) } if err := tel.InitMetrics(); err != nil { return nil, err } dialerID := uuid.New().String() g, err := newKeyGenerator(cfg.rsaKey, cfg.lazyRefresh, func() (*rsa.PrivateKey, error) { return rsa.GenerateKey(rand.Reader, 2048) }) if err != nil { return nil, err } d := &Dialer{ closed: make(chan struct{}), cache: make(map[alloydb.InstanceURI]monitoredCache), lazyRefresh: cfg.lazyRefresh, disableMetadataExchange: cfg.disableMetadataExchange, disableBuiltInMetrics: cfg.disableBuiltInTelemetry, staticConnInfo: cfg.staticConnInfo, keyGenerator: g, refreshTimeout: cfg.refreshTimeout, client: client, clientOpts: cfg.clientOpts, logger: cfg.logger, defaultDialCfg: dialCfg, dialerID: dialerID, metricRecorders: map[alloydb.InstanceURI]telv2.MetricRecorder{}, dialFunc: cfg.dialFunc, useIAMAuthN: cfg.useIAMAuthN, iamTokenSource: ts, userAgent: userAgent, buffer: newBuffer(), } return d, nil } // metricRecorder does a lazy initialization of the metric exporter. func (d *Dialer) metricRecorder(ctx context.Context, inst alloydb.InstanceURI) telv2.MetricRecorder { d.metricsMu.Lock() defer d.metricsMu.Unlock() if mr, ok := d.metricRecorders[inst]; ok { return mr } cfg := telv2.Config{ Enabled: !d.disableBuiltInMetrics, Version: versionString, ClientID: d.dialerID, ProjectID: inst.Project(), Location: inst.Region(), Cluster: inst.Cluster(), Instance: inst.Name(), } mr := telv2.NewMetricRecorder(ctx, d.logger, cfg, d.clientOpts...) d.metricRecorders[inst] = mr return mr } // Dial returns a net.Conn connected to the specified AlloyDB instance. 
The // instance argument must be the instance's URI, which is in the format // projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE> func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption) (conn net.Conn, err error) { select { case <-d.closed: return nil, ErrDialerClosed default: } inst, err := alloydb.ParseInstURI(instance) if err != nil { return nil, err } mr := d.metricRecorder(ctx, inst) var ( startTime = time.Now() endDial tel.EndSpanFunc attrs = telv2.Attributes{ IAMAuthN: d.useIAMAuthN, UserAgent: d.userAgent, RefreshType: telv2.RefreshAheadType, } ) if d.lazyRefresh { attrs.RefreshType = telv2.RefreshLazyType } ctx, endDial = tel.StartSpan(ctx, "cloud.google.com/go/alloydbconn.Dial", tel.AddInstanceName(instance), tel.AddDialerID(d.dialerID), ) defer func() { go tel.RecordDialError(context.Background(), instance, d.dialerID, err) go mr.RecordDialCount(ctx, attrs) endDial(err) }() cfg := d.defaultDialCfg for _, opt := range opts { opt(&cfg) } var endInfo tel.EndSpanFunc ctx, endInfo = tel.StartSpan(ctx, "cloud.google.com/go/alloydbconn/internal.InstanceInfo") cache, cacheHit, err := d.connectionInfoCache(ctx, inst, mr) attrs.CacheHit = cacheHit if err != nil { attrs.DialStatus = telv2.DialCacheError endInfo(err) return nil, err } ci, err := cache.ConnectionInfo(ctx) if err != nil { attrs.DialStatus = telv2.DialCacheError d.removeCached(ctx, inst, cache, err) endInfo(err) return nil, err } endInfo(err) // If the client certificate has expired (as when the computer goes to // sleep, and the refresh cycle cannot run), force a refresh immediately. // The TLS handshake will not fail on an expired client certificate. It's // not until the first read where the client cert error will be surfaced. // So check that the certificate is valid before proceeding. 
if invalidClientCert(ctx, inst, d.logger, ci.Expiration) { d.logger.Debugf(ctx, "[%v] Refreshing certificate now", inst.String()) cache.ForceRefresh() // Block on refreshed connection info ci, err = cache.ConnectionInfo(ctx) if err != nil { d.removeCached(ctx, inst, cache, err) attrs.DialStatus = telv2.DialCacheError return nil, err } } addr, ok := ci.IPAddrs[cfg.ipType] if !ok { d.removeCached(ctx, inst, cache, err) err := errtype.NewConfigError( fmt.Sprintf("instance does not have IP of type %q", cfg.ipType), inst.String(), ) attrs.DialStatus = telv2.DialUserError return nil, err } var connectEnd tel.EndSpanFunc ctx, connectEnd = tel.StartSpan(ctx, "cloud.google.com/go/alloydbconn/internal.Connect") defer func() { connectEnd(err) }() hostPort := net.JoinHostPort(addr, serverProxyPort) f := d.dialFunc if cfg.dialFunc != nil { f = cfg.dialFunc } d.logger.Debugf(ctx, "[%v] Dialing %v", inst.String(), hostPort) conn, err = f(ctx, "tcp", hostPort) if err != nil { d.logger.Debugf(ctx, "[%v] Dialing %v failed: %v", inst.String(), hostPort, err) // refresh the instance info in case it caused the connection failure cache.ForceRefresh() attrs.DialStatus = telv2.DialTCPError return nil, errtype.NewDialError("failed to dial", inst.String(), err) } if c, ok := conn.(*net.TCPConn); ok { if err := c.SetKeepAlive(true); err != nil { attrs.DialStatus = telv2.DialTCPError return nil, errtype.NewDialError("failed to set keep-alive", inst.String(), err) } if err := c.SetKeepAlivePeriod(cfg.tcpKeepAlive); err != nil { attrs.DialStatus = telv2.DialTCPError return nil, errtype.NewDialError("failed to set keep-alive period", inst.String(), err) } } c := &tls.Config{ Certificates: []tls.Certificate{ci.ClientCert}, RootCAs: ci.RootCAs, // The PSC, private, and public IP all appear in the certificate as // SAN. Use the server name that corresponds to the requested // connection path. 
ServerName: addr, MinVersion: tls.VersionTLS13, } tlsConn := tls.Client(conn, c) if err := tlsConn.HandshakeContext(ctx); err != nil { d.logger.Debugf(ctx, "[%v] TLS handshake failed: %v", inst.String(), err) // refresh the instance info in case it caused the handshake failure cache.ForceRefresh() _ = tlsConn.Close() // best effort close attempt attrs.DialStatus = telv2.DialTLSError return nil, errtype.NewDialError("handshake failed", inst.String(), err) } if !d.disableMetadataExchange { // The metadata exchange must occur after the TLS connection is established // to avoid leaking sensitive information. err = d.metadataExchange(tlsConn) if err != nil { _ = tlsConn.Close() // best effort close attempt attrs.DialStatus = telv2.DialMDXError return nil, err } } attrs.DialStatus = telv2.DialSuccess latency := time.Since(startTime).Milliseconds() go func() { n := atomic.AddUint64(cache.openConns, 1) tel.RecordOpenConnections(ctx, int64(n), d.dialerID, inst.String()) tel.RecordDialLatency(ctx, instance, d.dialerID, latency) mr.RecordOpenConnection(ctx, attrs) mr.RecordDialLatency(ctx, latency, attrs) }() return newInstrumentedConn(tlsConn, mr, attrs, func() { n := atomic.AddUint64(cache.openConns, ^uint64(0)) tel.RecordOpenConnections(context.Background(), int64(n), d.dialerID, inst.String()) mr.RecordClosedConnection(context.Background(), attrs) }, d.dialerID, inst.String()), nil } // removeCached stops all background refreshes and deletes the connection // info cache from the map of caches. 
func (d *Dialer) removeCached( ctx context.Context, i alloydb.InstanceURI, c connectionInfoCache, err error, ) { d.logger.Debugf( ctx, "[%v] Removing connection info from cache: %v", i.String(), err, ) d.lock.Lock() defer d.lock.Unlock() c.Close() delete(d.cache, i) } func invalidClientCert( ctx context.Context, inst alloydb.InstanceURI, l debug.ContextLogger, expiration time.Time, ) bool { now := time.Now().UTC() notAfter := expiration.UTC() invalid := now.After(notAfter) l.Debugf( ctx, "[%v] Now = %v, Current cert expiration = %v", inst.String(), now.Format(time.RFC3339), notAfter.Format(time.RFC3339), ) l.Debugf(ctx, "[%v] Cert is valid = %v", inst.String(), !invalid) return invalid } // metadataExchange sends metadata about the connection prior to the database // protocol taking over. The exchange consists of four steps: // // 1. Prepare a MetadataExchangeRequest including the IAM Principal's OAuth2 // token, the user agent, and the requested authentication type. // // 2. Write the size of the message as a big endian uint32 (4 bytes) to the // server followed by the marshaled message. The length does not include the // initial four bytes. // // 3. Read a big endian uint32 (4 bytes) from the server. This is the // MetadataExchangeResponse message length and does not include the initial // four bytes. // // 4. Unmarshal the response using the message length in step 3. If the // response is not OK, return the response's error. If there is no error, the // metadata exchange has succeeded and the connection is complete. // // Subsequent interactions with the server use the database protocol. 
func (d *Dialer) metadataExchange(conn net.Conn) error { tok, err := d.iamTokenSource.Token() if err != nil { return err } authType := connectorspb.MetadataExchangeRequest_DB_NATIVE if d.useIAMAuthN { authType = connectorspb.MetadataExchangeRequest_AUTO_IAM } req := &connectorspb.MetadataExchangeRequest{ UserAgent: d.userAgent, AuthType: authType, Oauth2Token: tok.AccessToken, } m, err := proto.Marshal(req) if err != nil { return err } b := d.buffer.get() defer d.buffer.put(b) buf := *b reqSize := proto.Size(req) binary.BigEndian.PutUint32(buf, uint32(reqSize)) buf = append(buf[:4], m...) // Set IO deadline before write err = conn.SetDeadline(time.Now().Add(ioTimeout)) if err != nil { return err } defer conn.SetDeadline(time.Time{}) _, err = conn.Write(buf) if err != nil { return err } // Reset IO deadline before read err = conn.SetDeadline(time.Now().Add(ioTimeout)) if err != nil { return err } defer conn.SetDeadline(time.Time{}) buf = buf[:4] _, err = conn.Read(buf) if err != nil { return err } respSize := binary.BigEndian.Uint32(buf) resp := buf[:respSize] _, err = conn.Read(resp) if err != nil { return err } var mdxResp connectorspb.MetadataExchangeResponse err = proto.Unmarshal(resp, &mdxResp) if err != nil { return err } if mdxResp.GetResponseCode() != connectorspb.MetadataExchangeResponse_OK { return errors.New(mdxResp.GetError()) } return nil } const maxMessageSize = 16 * 1024 // 16 kb type buffer struct { pool sync.Pool } func newBuffer() *buffer { return &buffer{ pool: sync.Pool{ New: func() any { buf := make([]byte, maxMessageSize) return &buf }, }, } } func (b *buffer) get() *[]byte { return b.pool.Get().(*[]byte) } func (b *buffer) put(buf *[]byte) { b.pool.Put(buf) } // newInstrumentedConn initializes an instrumentedConn that on closing will // decrement the number of open connects and record the result. 
func newInstrumentedConn(conn net.Conn, mr telv2.MetricRecorder, a telv2.Attributes, closeFunc func(), dialerID, instance string) *instrumentedConn { return &instrumentedConn{ Conn: conn, closeFunc: closeFunc, dialerID: dialerID, instance: instance, metricRecorder: mr, attrs: a, } } // instrumentedConn wraps a net.Conn and invokes closeFunc when the connection // is closed. type instrumentedConn struct { net.Conn closeFunc func() dialerID string instance string metricRecorder telv2.MetricRecorder attrs telv2.Attributes } // Read delegates to the underlying net.Conn interface and records number of // bytes read. func (i *instrumentedConn) Read(b []byte) (int, error) { bytesRead, err := i.Conn.Read(b) if err == nil { go tel.RecordBytesReceived(context.Background(), int64(bytesRead), i.instance, i.dialerID) go i.metricRecorder.RecordBytesRxCount(context.Background(), int64(bytesRead), i.attrs) } return bytesRead, err } // Write delegates to the underlying net.Conn interface and records number of // bytes written. func (i *instrumentedConn) Write(b []byte) (int, error) { bytesWritten, err := i.Conn.Write(b) if err == nil { go tel.RecordBytesSent(context.Background(), int64(bytesWritten), i.instance, i.dialerID) go i.metricRecorder.RecordBytesTxCount(context.Background(), int64(bytesWritten), i.attrs) } return bytesWritten, err } // Close delegates to the underlying net.Conn interface and reports the close // to the provided closeFunc only when Close returns no error. func (i *instrumentedConn) Close() error { err := i.Conn.Close() if err != nil { return err } go i.closeFunc() return nil } // Close closes the Dialer; it prevents the Dialer from refreshing the information // needed to connect. func (d *Dialer) Close() error { // Check if Close has already been called. 
select { case <-d.closed: return nil default: } close(d.closed) d.lock.Lock() for _, i := range d.cache { _ = i.Close() } d.lock.Unlock() d.metricsMu.Lock() ctx, cancel := context.WithTimeout(context.Background(), metricShutdownTimeout) defer cancel() for _, mr := range d.metricRecorders { // If a metric recorder doesn't shutdown cleanly, log the error and // keep going. An error here isn't actionable and should not be // returned to the caller. if err := mr.Shutdown(ctx); err != nil { d.logger.Debugf(context.Background(), "internal metric exporter failed to shutdown: %v", err) } } d.metricsMu.Unlock() return nil } func (d *Dialer) connectionInfoCache(ctx context.Context, uri alloydb.InstanceURI, mr telv2.MetricRecorder) (monitoredCache, bool, error) { d.lock.RLock() c, ok := d.cache[uri] d.lock.RUnlock() if !ok { d.lock.Lock() defer d.lock.Unlock() // Recheck to ensure instance wasn't created between locks c, ok = d.cache[uri] if !ok { d.logger.Debugf(ctx, "[%v] Connection info added to cache", uri.String()) k, err := d.keyGenerator.rsaKey() if err != nil { return monitoredCache{}, ok, err } var cache connectionInfoCache switch { case d.lazyRefresh: cache = alloydb.NewLazyRefreshCache( uri, d.logger, d.client, k, d.refreshTimeout, d.dialerID, d.disableMetadataExchange, d.userAgent, mr, ) case d.staticConnInfo != nil: var err error cache, err = alloydb.NewStaticConnectionInfoCache( uri, d.logger, d.staticConnInfo, ) if err != nil { return monitoredCache{}, ok, err } default: cache = alloydb.NewRefreshAheadCache( uri, d.logger, d.client, k, d.refreshTimeout, d.dialerID, d.disableMetadataExchange, d.userAgent, mr, ) } var open uint64 c = monitoredCache{openConns: &open, connectionInfoCache: cache} d.cache[uri] = c } } return c, ok, nil }