dialer.go (525 lines of code) (raw):
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudsqlconn
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
_ "embed"
"errors"
"fmt"
"io"
"net"
"net/http"
"os"
"strings"
"sync"
"sync/atomic"
"time"
"cloud.google.com/go/auth"
"cloud.google.com/go/auth/credentials"
"cloud.google.com/go/auth/httptransport"
"cloud.google.com/go/cloudsqlconn/debug"
"cloud.google.com/go/cloudsqlconn/errtype"
"cloud.google.com/go/cloudsqlconn/instance"
"cloud.google.com/go/cloudsqlconn/internal/cloudsql"
"cloud.google.com/go/cloudsqlconn/internal/trace"
"github.com/google/uuid"
"golang.org/x/net/proxy"
"google.golang.org/api/option"
sqladmin "google.golang.org/api/sqladmin/v1beta4"
)
// Package-level defaults used when constructing a Dialer and when dialing.
const (
// defaultTCPKeepAlive is the default keep-alive period applied to TCP
// connections to a Cloud SQL instance when no DialOption overrides it.
defaultTCPKeepAlive = 30 * time.Second
// serverProxyPort is the port the server-side proxy receives connections on.
// Dial joins it with the instance's IP address to form the dial target.
serverProxyPort = "3307"
// iamLoginScope is the OAuth2 scope used for tokens embedded in the ephemeral
// certificate. NewDialer requests default credentials with only this scope
// for IAM AuthN login.
iamLoginScope = "https://www.googleapis.com/auth/sqlservice.login"
// universeDomainEnvVar is the environment variable for setting the default
// service domain for a given Cloud universe.
universeDomainEnvVar = "GOOGLE_CLOUD_UNIVERSE_DOMAIN"
// defaultUniverseDomain is the default value for universe domain.
// Universe domain is the default service domain for a given Cloud universe.
defaultUniverseDomain = "googleapis.com"
)
var (
// ErrDialerClosed is used when a caller invokes Dial after closing the
// Dialer.
ErrDialerClosed = errors.New("cloudsqlconn: dialer is closed")
// versionString indicates the version of this library. Its value is the
// content of version.txt, embedded at build time by the directive below.
//go:embed version.txt
versionString string
// userAgent is the user agent entry identifying this library and its
// version; surrounding whitespace in versionString is trimmed.
userAgent = "cloud-sql-go-connector/" + strings.TrimSpace(versionString)
)
// keyGenerator hides how the RSA key is obtained. The key may be supplied by
// the caller, generated eagerly at construction, or generated on first use.
// Whichever path is taken, the key (or the generation error) is resolved at
// most once and then cached.
type keyGenerator struct {
	// once guarantees key and err are populated at most one time.
	once sync.Once
	// key is the cached RSA key, valid after once has fired without error.
	key *rsa.PrivateKey
	// err is the cached error from genFunc, if generation failed.
	err error
	// genFunc produces a new RSA key when none was supplied.
	genFunc func() (*rsa.PrivateKey, error)
}

// newKeyGenerator builds a keyGenerator using the first rule that applies:
//   - k is non-nil: k is cached and genFunc is never called;
//   - lazy is true: nothing is generated until rsaKey is first called;
//   - otherwise: genFunc runs immediately and its result is cached.
//
// The returned error is the result of eager generation, and is nil in the
// other two cases. The generator is returned even when the error is non-nil.
func newKeyGenerator(
	k *rsa.PrivateKey, lazy bool, genFunc func() (*rsa.PrivateKey, error),
) (*keyGenerator, error) {
	kg := &keyGenerator{genFunc: genFunc}

	if k != nil {
		// A caller-provided key wins: store it and spend the sync.Once so
		// that rsaKey never invokes genFunc.
		kg.once.Do(func() { kg.key, kg.err = k, nil })
		return kg, nil
	}

	if lazy {
		// Defer all work until the first rsaKey call.
		return kg, nil
	}

	// Eager default: generate now and spend the sync.Once.
	kg.once.Do(func() { kg.key, kg.err = kg.genFunc() })
	return kg, kg.err
}

// rsaKey returns the cached RSA key, generating it first if no key has been
// resolved yet. A failed generation is cached too and returned on every call.
func (g *keyGenerator) rsaKey() (*rsa.PrivateKey, error) {
	generate := func() {
		k, err := g.genFunc()
		g.key, g.err = k, err
	}
	g.once.Do(generate)
	return g.key, g.err
}
// connectionInfoCache is the behavior the Dialer needs from a per-instance
// connection info cache. The values returned by cloudsql.NewLazyRefreshCache
// and cloudsql.NewRefreshAheadCache are both assigned to this interface.
type connectionInfoCache interface {
// ConnectionInfo returns the connection info for the instance, or an error
// if it could not be obtained.
ConnectionInfo(context.Context) (cloudsql.ConnectionInfo, error)
// UpdateRefresh receives the caller's IAM AuthN setting (possibly nil).
// NOTE(review): presumably triggers a refresh when the setting changes —
// confirm against the cache implementations.
UpdateRefresh(*bool)
// ForceRefresh requests a refresh of the cached connection info. Dial calls
// it when the client certificate has expired or when dialing fails.
ForceRefresh()
io.Closer
}
// cacheKey identifies an entry in the Dialer's cache. createKey populates
// either domainName alone, or project, region and name together, never both.
type cacheKey struct {
// domainName is set when the instance was addressed by a domain name.
domainName string
// project, region and name are set when the instance was addressed by its
// project:region:instance connection name.
project string
region string
name string
}
// getClientUniverseDomain resolves the universe domain (the default service
// domain of a Cloud universe) configured for the client. The first non-empty
// value from this list wins:
//
//  1. The value set by option.WithUniverseDomain or a similar client option.
//  2. The GOOGLE_CLOUD_UNIVERSE_DOMAIN environment variable.
//  3. The default value "googleapis.com".
//
// The result is the client-side universe domain, which will be compared to
// the universe domain that is separately configured for the credentials.
func (c *dialerConfig) getClientUniverseDomain() string {
	domain := c.clientUniverseDomain
	if domain == "" {
		domain = os.Getenv(universeDomainEnvVar)
	}
	if domain == "" {
		domain = defaultUniverseDomain
	}
	return domain
}
// A Dialer is used to create connections to Cloud SQL instances.
//
// Use NewDialer to initialize a Dialer.
type Dialer struct {
// lock guards cache.
lock sync.RWMutex
// cache holds one monitoredCache per cacheKey. Entries are added by
// connectionInfoCache and deleted by removeCached.
cache map[cacheKey]*monitoredCache
// keyGenerator supplies the RSA key handed to each newly created
// connection info cache.
keyGenerator *keyGenerator
// refreshTimeout is passed through to each newly created connection info
// cache.
refreshTimeout time.Duration
// closed reports if the dialer has been closed. Close closes this channel,
// after which Dial returns ErrDialerClosed.
closed chan struct{}
// sqladmin is the SQL Admin API client passed to each newly created
// connection info cache.
sqladmin *sqladmin.Service
// logger receives the dialer's debug messages.
logger debug.ContextLogger
// lazyRefresh determines what kind of caching is used for ephemeral
// certificates. When lazyRefresh is true, the dialer will use a lazy
// cache, refresh certificates only when a connection attempt needs a fresh
// certificate. Otherwise, a refresh ahead cache will be used. The refresh
// ahead cache assumes a background goroutine may run consistently.
lazyRefresh bool
// defaultDialConfig holds the constructor level DialOptions, so that it
// can be copied and mutated by the Dial function.
defaultDialConfig dialConfig
// dialerID uniquely identifies a Dialer. Used for monitoring purposes,
// *only* when a client has configured OpenCensus exporters.
dialerID string
// dialFunc is the function used to connect to the address on the named
// network. By default, it is golang.org/x/net/proxy#Dial.
dialFunc func(cxt context.Context, network, addr string) (net.Conn, error)
// iamTokenProvider supplies the OAuth2 token used for IAM DB Authn.
iamTokenProvider auth.TokenProvider
// resolver converts instance names into DNS names.
resolver instance.ConnectionNameResolver
// failoverPeriod is passed to newMonitoredCache for every cache entry.
// NOTE(review): presumably the interval at which domain names are
// re-resolved to detect failover — confirm in monitoredCache.
failoverPeriod time.Duration
}
var (
	// errUseTokenSource is returned by NewDialer when an IAM AuthN token
	// source was configured but IAM AuthN itself was not enabled.
	errUseTokenSource = errors.New("use WithTokenSource when IAM AuthN is not enabled")
	// errUseIAMTokenSource is returned by NewDialer when IAM AuthN is enabled
	// and only a plain token source, not an IAM AuthN token source, was
	// configured.
	errUseIAMTokenSource = errors.New("use WithIAMAuthNTokenSources instead of WithTokenSource when IAM AuthN is enabled")
)
// nullLogger is the logger NewDialer installs by default; it discards every
// message.
type nullLogger struct{}

// Debugf drops the message without formatting or recording it.
func (nullLogger) Debugf(context.Context, string, ...any) {}
// NewDialer creates a new Dialer.
//
// Initial calls to NewDialer may take longer than normal because generation of
// an RSA keypair is performed. Calls with a WithRSAKeyPair DialOption or after
// a default RSA keypair is generated will be faster. When lazy refresh is
// configured, key generation is deferred until a key is first needed.
//
// The supplied options are applied in order; the first option that records an
// error aborts construction with that error. NewDialer also fails when the
// token source options conflict with the IAM AuthN setting, when default
// credentials cannot be detected, or when the auth HTTP client, the SQL Admin
// client, metrics, or the RSA key cannot be initialized. Errors from
// dependencies are wrapped, so callers may inspect them with errors.Is and
// errors.As.
func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
	cfg := &dialerConfig{
		refreshTimeout: cloudsql.RefreshTimeout,
		dialFunc:       proxy.Dial,
		logger:         nullLogger{},
		useragents:     []string{userAgent},
		failoverPeriod: cloudsql.FailoverPeriod,
	}
	for _, opt := range opts {
		opt(cfg)
		if cfg.err != nil {
			return nil, cfg.err
		}
	}
	if cfg.useIAMAuthN && cfg.setTokenSource && !cfg.setIAMAuthNTokenSource {
		return nil, errUseIAMTokenSource
	}
	if cfg.setIAMAuthNTokenSource && !cfg.useIAMAuthN {
		return nil, errUseTokenSource
	}
	// If callers have not provided a credential source, either explicitly with
	// WithTokenSource or implicitly with WithCredentialsJSON etc., then use
	// default credentials
	if !cfg.setCredentials {
		c, err := credentials.DetectDefault(&credentials.DetectOptions{
			Scopes: []string{sqladmin.SqlserviceAdminScope},
		})
		if err != nil {
			return nil, fmt.Errorf("failed to create default credentials: %w", err)
		}
		cfg.authCredentials = c
		// create second set of credentials, scoped for IAM AuthN login only
		scoped, err := credentials.DetectDefault(&credentials.DetectOptions{
			Scopes: []string{iamLoginScope},
		})
		if err != nil {
			return nil, fmt.Errorf("failed to create scoped credentials: %w", err)
		}
		cfg.iamLoginTokenProvider = scoped.TokenProvider
	}
	// For all credential paths, use auth library's built-in
	// httptransport.NewClient
	if cfg.authCredentials != nil {
		// Set headers for auth client as below WithHTTPClient will ignore
		// WithQuotaProject and WithUserAgent Options
		headers := http.Header{}
		headers.Set("User-Agent", strings.Join(cfg.useragents, " "))
		if cfg.quotaProject != "" {
			headers.Set("X-Goog-User-Project", cfg.quotaProject)
		}
		authClient, err := httptransport.NewClient(&httptransport.Options{
			Headers:        headers,
			Credentials:    cfg.authCredentials,
			UniverseDomain: cfg.getClientUniverseDomain(),
		})
		if err != nil {
			return nil, fmt.Errorf("failed to create auth client: %w", err)
		}
		// If callers have not provided an HTTPClient explicitly with
		// WithHTTPClient, then use auth client
		if !cfg.setHTTPClient {
			cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithHTTPClient(authClient))
		}
	} else {
		// Add this to the end to make sure it's not overridden
		cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithUserAgent(strings.Join(cfg.useragents, " ")))
		if cfg.quotaProject != "" {
			cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithQuotaProject(cfg.quotaProject))
		}
	}
	client, err := sqladmin.NewService(ctx, cfg.sqladminOpts...)
	if err != nil {
		return nil, fmt.Errorf("failed to create sqladmin client: %w", err)
	}
	// Build the dialer-wide dial defaults; per-call DialOptions passed to Dial
	// are applied on top of a copy of this value.
	dc := dialConfig{
		ipType:       cloudsql.PublicIP,
		tcpKeepAlive: defaultTCPKeepAlive,
		useIAMAuthN:  cfg.useIAMAuthN,
	}
	for _, opt := range cfg.dialOpts {
		opt(&dc)
	}
	if err := trace.InitMetrics(); err != nil {
		return nil, err
	}
	// Use the caller's RSA key if one was supplied; otherwise generate a
	// 2048-bit key now, or on first use when lazy refresh is enabled.
	g, err := newKeyGenerator(cfg.rsaKey, cfg.lazyRefresh,
		func() (*rsa.PrivateKey, error) {
			return rsa.GenerateKey(rand.Reader, 2048)
		})
	if err != nil {
		return nil, err
	}
	var r instance.ConnectionNameResolver = cloudsql.DefaultResolver
	if cfg.resolver != nil {
		r = cfg.resolver
	}
	d := &Dialer{
		closed:            make(chan struct{}),
		cache:             make(map[cacheKey]*monitoredCache),
		lazyRefresh:       cfg.lazyRefresh,
		keyGenerator:      g,
		refreshTimeout:    cfg.refreshTimeout,
		sqladmin:          client,
		logger:            cfg.logger,
		defaultDialConfig: dc,
		dialerID:          uuid.New().String(),
		iamTokenProvider:  cfg.iamLoginTokenProvider,
		dialFunc:          cfg.dialFunc,
		resolver:          r,
		failoverPeriod:    cfg.failoverPeriod,
	}
	return d, nil
}
// Dial returns a net.Conn connected to the specified Cloud SQL instance. The
// icn argument must be the instance's connection name, which is in the format
// "project-name:region:instance-name".
//
// Per-call opts are applied on top of a copy of the Dialer's default dial
// configuration. Dial returns ErrDialerClosed if the Dialer has been closed.
// On failure to obtain connection info, to pick an address, or to complete
// the TLS handshake, the instance's entry is removed from the cache so that
// the next Dial starts from fresh connection info. Any socket opened during a
// failed attempt is closed before Dial returns.
func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn net.Conn, err error) {
	select {
	case <-d.closed:
		return nil, ErrDialerClosed
	default:
	}
	startTime := time.Now()
	var endDial trace.EndSpanFunc
	ctx, endDial = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn.Dial",
		trace.AddInstanceName(icn),
		trace.AddDialerID(d.dialerID),
	)
	// err is the named result, so the deferred closure observes the final
	// error returned by Dial.
	defer func() {
		go trace.RecordDialError(context.Background(), icn, d.dialerID, err)
		endDial(err)
	}()
	cn, err := d.resolver.Resolve(ctx, icn)
	if err != nil {
		return nil, err
	}
	// Log if resolver changed the instance name input string.
	if cn.String() != icn {
		d.logger.Debugf(ctx, "resolved instance %s to %s", icn, cn)
	}
	cfg := d.defaultDialConfig
	for _, opt := range opts {
		opt(&cfg)
	}
	var endInfo trace.EndSpanFunc
	ctx, endInfo = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.InstanceInfo")
	c, err := d.connectionInfoCache(ctx, cn, &cfg.useIAMAuthN)
	if err != nil {
		endInfo(err)
		return nil, err
	}
	ci, err := c.ConnectionInfo(ctx)
	if err != nil {
		d.removeCached(ctx, cn, c, err)
		endInfo(err)
		return nil, err
	}
	endInfo(err)
	// If the client certificate has expired (as when the computer goes to
	// sleep, and the refresh cycle cannot run), force a refresh immediately.
	// The TLS handshake will not fail on an expired client certificate. It's
	// not until the first read where the client cert error will be surfaced.
	// So check that the certificate is valid before proceeding.
	if !validClientCert(ctx, cn, d.logger, ci.Expiration) {
		d.logger.Debugf(ctx, "[%v] Refreshing certificate now", cn.String())
		c.ForceRefresh()
		// Block on refreshed connection info
		ci, err = c.ConnectionInfo(ctx)
		if err != nil {
			d.removeCached(ctx, cn, c, err)
			return nil, err
		}
	}
	var connectEnd trace.EndSpanFunc
	ctx, connectEnd = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.Connect")
	defer func() { connectEnd(err) }()
	addr, err := ci.Addr(cfg.ipType)
	if err != nil {
		d.removeCached(ctx, cn, c, err)
		return nil, err
	}
	addr = net.JoinHostPort(addr, serverProxyPort)
	// A per-call dial function takes precedence over the Dialer's.
	f := d.dialFunc
	if cfg.dialFunc != nil {
		f = cfg.dialFunc
	}
	d.logger.Debugf(ctx, "[%v] Dialing %v", cn.String(), addr)
	conn, err = f(ctx, "tcp", addr)
	if err != nil {
		d.logger.Debugf(ctx, "[%v] Dialing %v failed: %v", cn.String(), addr, err)
		// refresh the instance info in case it caused the connection failure
		c.ForceRefresh()
		return nil, errtype.NewDialError("failed to dial", cn.String(), err)
	}
	if tcpConn, ok := conn.(*net.TCPConn); ok {
		if err := tcpConn.SetKeepAlive(true); err != nil {
			// The socket is already open; close it so a failed dial does
			// not leak the connection.
			_ = conn.Close() // best effort close attempt
			return nil, errtype.NewDialError("failed to set keep-alive", cn.String(), err)
		}
		if err := tcpConn.SetKeepAlivePeriod(cfg.tcpKeepAlive); err != nil {
			_ = conn.Close() // best effort close attempt
			return nil, errtype.NewDialError("failed to set keep-alive period", cn.String(), err)
		}
	}
	tlsConn := tls.Client(conn, ci.TLSConfig())
	err = tlsConn.HandshakeContext(ctx)
	if err != nil {
		// TLS handshake errors are fatal and require a refresh. Remove the instance
		// from the cache so that future calls to Dial() will block until the
		// certificate is refreshed successfully.
		d.logger.Debugf(ctx, "[%v] TLS handshake failed: %v", cn.String(), err)
		d.removeCached(ctx, cn, c, err)
		_ = tlsConn.Close() // best effort close attempt
		return nil, errtype.NewDialError("handshake failed", cn.String(), err)
	}
	latency := time.Since(startTime).Milliseconds()
	// Record metrics off the caller's goroutine so Dial is not delayed.
	go func() {
		n := atomic.AddUint64(c.openConnsCount, 1)
		trace.RecordOpenConnections(ctx, int64(n), d.dialerID, cn.String())
		trace.RecordDialLatency(ctx, icn, d.dialerID, latency)
	}()
	closeFunc := func() {
		n := atomic.AddUint64(c.openConnsCount, ^uint64(0)) // c.openConnsCount = c.openConnsCount - 1
		trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, cn.String())
	}
	errFunc := func(err error) {
		// io.EOF occurs when the server closes the connection. This is safe to
		// ignore.
		if err == io.EOF {
			return
		}
		d.logger.Debugf(ctx, "[%v] IO Error on Read or Write: %v", cn.String(), err)
		if d.isTLSError(err) {
			// TLS handshake errors are fatal. Remove the instance from the cache
			// so that future calls to Dial() will block until the certificate
			// is refreshed successfully.
			d.removeCached(ctx, cn, c, err)
			_ = tlsConn.Close() // best effort close attempt
		}
	}
	iConn := newInstrumentedConn(tlsConn, closeFunc, errFunc, d.dialerID, cn.String())
	// If this connection was opened using a Domain Name, then store it for later
	// in case it needs to be forcibly closed.
	if cn.HasDomainName() {
		c.mu.Lock()
		c.openConns = append(c.openConns, iConn)
		c.mu.Unlock()
	}
	return iConn, nil
}
// isTLSError reports whether err is a net.Error that is not a timeout and
// whose message contains "tls". Any other error, including nil, yields false.
func (d *Dialer) isTLSError(err error) bool {
	nErr, ok := err.(net.Error)
	if !ok || nErr.Timeout() {
		// Not a network error at all, or only a timeout.
		return false
	}
	return strings.Contains(nErr.Error(), "tls")
}
// removeCached evicts c from the dialer's cache, provided it is still the
// entry registered for instance i, and then closes c. The err argument is
// only used in the debug log message.
//
// NOTE(review): per the original comment, closing the monitoredCache stops
// background refreshes and closes open sockets, and is idempotent — confirm
// in monitoredCache.Close.
func (d *Dialer) removeCached(
	ctx context.Context,
	i instance.ConnName, c *monitoredCache, err error,
) {
	d.logger.Debugf(ctx, "[%v] Removing connection info from cache: %v", i.String(), err)

	k := createKey(i)
	func() {
		d.lock.Lock()
		defer d.lock.Unlock()
		// Only evict when the map still points at this exact instance. If it
		// was already removed, or another goroutine replaced it with a new
		// monitoredCache, leave the map untouched.
		if d.cache[k] == c {
			delete(d.cache, k)
		}
	}()

	// Close outside the lock.
	c.Close()
}
// validClientCert reports whether the ephemeral client certificate's
// expiration lies strictly after the current time. It logs both timestamps
// and the verdict at debug level.
//
// Both instants are converted with UTC(), which strips the monotonic clock
// reading, so the comparison uses wall-clock time and stays accurate after a
// laptop sleep. See https://pkg.go.dev/time#hdr-Monotonic_Clocks for details.
func validClientCert(
	ctx context.Context, cn instance.ConnName,
	l debug.ContextLogger, expiration time.Time,
) bool {
	current := time.Now().UTC()
	expiry := expiration.UTC()
	unexpired := current.Before(expiry)

	l.Debugf(
		ctx,
		"[%v] Now = %v, Current cert expiration = %v",
		cn.String(),
		current.Format(time.RFC3339),
		expiry.Format(time.RFC3339),
	)
	l.Debugf(ctx, "[%v] Cert is valid = %v", cn.String(), unexpired)
	return unexpired
}
// EngineVersion reports the database engine type and version of the instance
// identified by icn. The value corresponds to one of the types listed at
// https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/SqlDatabaseVersion
//
// It resolves icn, obtains (or creates) the instance's connection info cache
// using the Dialer's default IAM AuthN setting, and reads the version from
// the connection info. If the connection info cannot be obtained, the cache
// entry is removed and the error is returned.
func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error) {
	cn, resolveErr := d.resolver.Resolve(ctx, icn)
	if resolveErr != nil {
		return "", resolveErr
	}

	cache, cacheErr := d.connectionInfoCache(ctx, cn, &d.defaultDialConfig.useIAMAuthN)
	if cacheErr != nil {
		return "", cacheErr
	}

	info, infoErr := cache.ConnectionInfo(ctx)
	if infoErr != nil {
		d.removeCached(ctx, cn, cache, infoErr)
		return "", infoErr
	}
	return info.DBVersion, nil
}
// Warmup starts the refresh process needed to connect to the instance. Use
// Warmup to begin early if you don't know when you'll need to call "Dial".
//
// The opts are applied on top of a copy of the Dialer's default dial
// configuration. Warmup waits for connection info to be retrieved; if that
// fails, the instance's cache entry is removed and the error is returned.
func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) error {
	cn, err := d.resolver.Resolve(ctx, icn)
	if err != nil {
		return err
	}

	dialCfg := d.defaultDialConfig
	for idx := range opts {
		opts[idx](&dialCfg)
	}

	cache, err := d.connectionInfoCache(ctx, cn, &dialCfg.useIAMAuthN)
	if err != nil {
		return err
	}

	if _, err = cache.ConnectionInfo(ctx); err != nil {
		d.removeCached(ctx, cn, cache, err)
		return err
	}
	return nil
}
// newInstrumentedConn wraps conn in an instrumentedConn. closeFunc runs after
// a successful Close, and errFunc receives any error returned by Read or
// Write. dialerID and connName are used to label the recorded byte metrics.
func newInstrumentedConn(conn net.Conn, closeFunc func(), errFunc func(error), dialerID, connName string) *instrumentedConn {
	ic := new(instrumentedConn)
	ic.Conn = conn
	ic.closeFunc = closeFunc
	ic.errFunc = errFunc
	ic.dialerID = dialerID
	ic.connName = connName
	return ic
}
// instrumentedConn wraps a net.Conn and invokes closeFunc when the connection
// is closed. It also records bytes read and written, and reports Read/Write
// errors to errFunc.
type instrumentedConn struct {
net.Conn
// closeFunc is run in its own goroutine after the underlying connection
// closes without error.
closeFunc func()
// errFunc is called with every non-nil error returned by Read or Write.
errFunc func(error)
// mu guards closed.
mu sync.RWMutex
// closed is set to true once Close has been called, even if closing the
// underlying connection fails.
closed bool
// dialerID and connName label the byte metrics recorded by Read and Write.
dialerID string
connName string
}
// Read reads from the underlying net.Conn. On success the number of bytes
// read is recorded asynchronously; on failure the error is passed to errFunc.
// The byte count and error from the underlying Read are returned unchanged.
func (i *instrumentedConn) Read(b []byte) (int, error) {
	n, err := i.Conn.Read(b)
	if err != nil {
		i.errFunc(err)
		return n, err
	}
	go trace.RecordBytesReceived(context.Background(), int64(n), i.connName, i.dialerID)
	return n, nil
}
// Write writes to the underlying net.Conn. On success the number of bytes
// written is recorded asynchronously; on failure the error is passed to
// errFunc. The byte count and error from the underlying Write are returned
// unchanged.
func (i *instrumentedConn) Write(b []byte) (int, error) {
	n, err := i.Conn.Write(b)
	if err != nil {
		i.errFunc(err)
		return n, err
	}
	go trace.RecordBytesSent(context.Background(), int64(n), i.connName, i.dialerID)
	return n, nil
}
// isClosed reports whether Close has been called on this connection, that is,
// whether it is closing or already closed. The flag is read under the read
// lock.
func (i *instrumentedConn) isClosed() bool {
	i.mu.RLock()
	state := i.closed
	i.mu.RUnlock()
	return state
}
// Close marks the connection as closed and closes the underlying net.Conn.
// closeFunc is started in a new goroutine only when the underlying Close
// returns no error; otherwise that error is returned and closeFunc is not
// run. The closed flag is set in both cases.
func (i *instrumentedConn) Close() error {
	i.mu.Lock()
	defer i.mu.Unlock()

	i.closed = true
	closeErr := i.Conn.Close()
	if closeErr == nil {
		go i.closeFunc()
	}
	return closeErr
}
// Close closes the Dialer; it prevents the Dialer from refreshing the information
// needed to connect. Every cache entry held by the Dialer is closed.
//
// Close is safe to call more than once and from multiple goroutines; only the
// first call has any effect. It always returns nil.
func (d *Dialer) Close() error {
	// Take the lock before inspecting d.closed so that the check and the
	// close happen atomically. Otherwise two concurrent callers could both
	// observe an open channel, and the second close would panic.
	d.lock.Lock()
	defer d.lock.Unlock()
	// Check if Close has already been called.
	select {
	case <-d.closed:
		return nil
	default:
	}
	close(d.closed)
	for _, c := range d.cache {
		c.Close()
	}
	return nil
}
// createKey derives the dialer cache key for an instance.ConnName.
// An instance.ConnName uniquely identifies a connection using
// project:region:instance + domainName. In the dialer cache, however, entries
// are identified either by domainName alone or by project:region:instance,
// never by the combination of the two.
func createKey(cn instance.ConnName) cacheKey {
	var k cacheKey
	if cn.HasDomainName() {
		k.domainName = cn.DomainName()
		return k
	}
	k.name, k.project, k.region = cn.Name(), cn.Project(), cn.Region()
	return k
}
// connectionInfoCache returns the monitoredCache for cn, safely under
// concurrent use. An open cache already registered for the instance is
// reused, after passing it the caller's IAM AuthN setting. Otherwise a new
// cache is created (lazy or refresh-ahead, depending on d.lazyRefresh),
// stored in d.cache, and returned.
//
// The only error returned is a failure to obtain the RSA key.
func (d *Dialer) connectionInfoCache(
	ctx context.Context, cn instance.ConnName, useIAMAuthN *bool,
) (*monitoredCache, error) {
	key := createKey(cn)

	// Fast path: look up under the read lock only.
	d.lock.RLock()
	existing, found := d.cache[key]
	d.lock.RUnlock()
	if found && !existing.isClosed() {
		existing.UpdateRefresh(useIAMAuthN)
		return existing, nil
	}

	d.lock.Lock()
	defer d.lock.Unlock()

	// Look again now that the write lock is held: another goroutine may have
	// created or replaced the entry between the two locks.
	if existing, found = d.cache[key]; found && !existing.isClosed() {
		existing.UpdateRefresh(useIAMAuthN)
		return existing, nil
	}

	// No usable entry: build a fresh monitoredCache. A nil useIAMAuthN is
	// treated as false.
	iamAuthN := useIAMAuthN != nil && *useIAMAuthN
	d.logger.Debugf(ctx, "[%v] Connection info added to cache", cn.String())

	privateKey, err := d.keyGenerator.rsaKey()
	if err != nil {
		return nil, err
	}

	var inner connectionInfoCache
	switch {
	case d.lazyRefresh:
		inner = cloudsql.NewLazyRefreshCache(
			cn,
			d.logger,
			d.sqladmin, privateKey,
			d.refreshTimeout, d.iamTokenProvider,
			d.dialerID, iamAuthN,
		)
	default:
		inner = cloudsql.NewRefreshAheadCache(
			cn,
			d.logger,
			d.sqladmin, privateKey,
			d.refreshTimeout, d.iamTokenProvider,
			d.dialerID, iamAuthN,
		)
	}

	created := newMonitoredCache(ctx, inner, cn, d.failoverPeriod, d.resolver, d.logger)
	d.cache[key] = created
	return created, nil
}