internal/cloudsql/refresh.go (284 lines of code) (raw):
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudsql
import (
"context"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"strings"
"cloud.google.com/go/auth"
"cloud.google.com/go/cloudsqlconn/debug"
"cloud.google.com/go/cloudsqlconn/errtype"
"cloud.google.com/go/cloudsqlconn/instance"
"cloud.google.com/go/cloudsqlconn/internal/trace"
sqladmin "google.golang.org/api/sqladmin/v1beta4"
)
// IP type identifiers. fetchMetadata uses PublicIP, PrivateIP, and PSC as the
// keys of the address map it builds for an instance.
const (
// PublicIP is the value for public IP Cloud SQL instances. fetchMetadata
// stores the address the Admin API reports with type "PRIMARY" under this key.
PublicIP = "PUBLIC"
// PrivateIP is the value for private IP Cloud SQL instances. fetchMetadata
// stores the address the Admin API reports with type "PRIVATE" under this key.
PrivateIP = "PRIVATE"
// PSC is the value for private service connect Cloud SQL instances.
// fetchMetadata stores the instance's PSC DNS name (not an IP address) under
// this key when PSC is enabled.
PSC = "PSC"
// AutoIP selects public IP if available and otherwise selects private
// IP.
AutoIP = "AutoIP"
)
// metadata contains information about a Cloud SQL instance needed to create
// connections. It is produced by fetchMetadata from the Admin API's connect
// settings response.
type metadata struct {
// ipAddrs maps an IP type (PublicIP, PrivateIP, or PSC) to the address
// used to reach the instance. For PSC the value is a DNS name.
ipAddrs map[string]string
// serverCACert holds every certificate parsed from the PEM data in the
// response's server CA certificate.
serverCACert []*x509.Certificate
// serverCAMode is copied verbatim from the response's ServerCaMode field.
serverCAMode string
// dnsName is the first entry of the response's DnsNames list, or the legacy
// DnsName field when that entry is absent or empty.
dnsName string
// version is copied verbatim from the response's DatabaseVersion field.
version string
}
// fetchMetadata uses the Cloud SQL Admin API's connect settings get method to
// retrieve the information about a Cloud SQL instance that is used to create
// secure connections.
//
// The request is issued through retry50x with exponentialBackoff. A non-nil
// error is returned when:
//   - the API call fails (refresh error wrapping the cause),
//   - the region in inst differs from the region reported by the API
//     (config error),
//   - the backend type is not "SECOND_GEN" (config error),
//   - the instance exposes no public, private, or PSC address (config error),
//   - the server CA certificate is absent, contains no PEM block, or contains
//     a PEM block that is not a valid X.509 certificate (refresh error).
//
// The named result err is read by the deferred call that ends the trace span,
// so every return path reports its error to tracing.
func fetchMetadata(
	ctx context.Context, client *sqladmin.Service, inst instance.ConnName,
) (m metadata, err error) {
	var end trace.EndSpanFunc
	ctx, end = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.FetchMetadata")
	defer func() { end(err) }()

	db, err := retry50x(ctx, func(ctx2 context.Context) (*sqladmin.ConnectSettings, error) {
		return client.Connect.Get(
			inst.Project(), inst.Name(),
		).Context(ctx2).Do()
	}, exponentialBackoff)
	if err != nil {
		return metadata{}, errtype.NewRefreshError("failed to get instance metadata", inst.String(), err)
	}

	// validate the instance is supported for authenticated connections
	if db.Region != inst.Region() {
		msg := fmt.Sprintf(
			"provided region was mismatched - got %s, want %s",
			inst.Region(), db.Region,
		)
		return metadata{}, errtype.NewConfigError(msg, inst.String())
	}
	if db.BackendType != "SECOND_GEN" {
		return metadata{}, errtype.NewConfigError(
			"unsupported instance - only Second Generation instances are supported",
			inst.String(),
		)
	}

	// parse any ip addresses that might be used to connect
	ipAddrs := make(map[string]string)
	for _, ip := range db.IpAddresses {
		switch ip.Type {
		case "PRIMARY":
			ipAddrs[PublicIP] = ip.IpAddress
		case "PRIVATE":
			ipAddrs[PrivateIP] = ip.IpAddress
		}
	}

	// Record the DNS name used for PSC connections.
	// Note that we have to check for PSC enablement first because CAS
	// instances also set the DnsName.
	if db.PscEnabled {
		// Search the dns_names field for the instance-scoped PSC DNS name.
		pscDNSName := ""
		for _, dnm := range db.DnsNames {
			if dnm.Name != "" &&
				dnm.ConnectionType == "PRIVATE_SERVICE_CONNECT" && dnm.DnsScope == "INSTANCE" {
				pscDNSName = dnm.Name
				break
			}
		}
		// If the PSC DNS name was not found, use the legacy dns_name field.
		if pscDNSName == "" {
			pscDNSName = db.DnsName
		}
		// If a PSC DNS name was found, add it to the ipAddrs map.
		if pscDNSName != "" {
			ipAddrs[PSC] = pscDNSName
		}
	}

	if len(ipAddrs) == 0 {
		return metadata{}, errtype.NewConfigError(
			"cannot connect to instance - it has no supported IP addresses",
			inst.String(),
		)
	}

	// Parse the server-side CA certificate(s). The PEM data may contain
	// several concatenated blocks; pem.Decode returns a nil block once no
	// further PEM data is found, which ends the loop.
	var pemData []byte
	if db.ServerCaCert != nil {
		pemData = []byte(db.ServerCaCert.Cert)
	}
	var caCerts []*x509.Certificate
	for b, rest := pem.Decode(pemData); b != nil; b, rest = pem.Decode(rest) {
		caCert, parseErr := x509.ParseCertificate(b.Bytes)
		if parseErr != nil {
			return metadata{}, errtype.NewRefreshError(
				fmt.Sprintf("failed to parse as X.509 certificate: %v", parseErr),
				inst.String(),
				nil,
			)
		}
		caCerts = append(caCerts, caCert)
	}
	// A response without a single decodable PEM block must fail here rather
	// than produce metadata with an empty CA list.
	if len(caCerts) == 0 {
		return metadata{}, errtype.NewRefreshError("failed to decode valid PEM cert", inst.String(), nil)
	}

	// Find a DNS name to use to validate the certificate from the dns_names
	// field. Any name in the list may be used to validate the server TLS
	// certificate. Fall back to legacy dns_name field if necessary.
	var serverName string
	if len(db.DnsNames) > 0 {
		serverName = db.DnsNames[0].Name
	}
	if serverName == "" {
		serverName = db.DnsName
	}

	m = metadata{
		ipAddrs:      ipAddrs,
		serverCACert: caCerts,
		version:      db.DatabaseVersion,
		dnsName:      serverName,
		serverCAMode: db.ServerCaMode,
	}
	return m, nil
}
// fetchEphemeralCert asks the Cloud SQL Admin API to sign the public half of
// key and returns the resulting client certificate paired with key.
//
// When tp is non-nil, a token is obtained from it and sent with the request,
// and the returned certificate's NotAfter is lowered to the token's expiry if
// the token expires first. The certificate lifetime itself is chosen by the
// API (approximately one hour according to the prior documentation of this
// function).
//
// The request is issued through retry50x with exponentialBackoff. Failures to
// obtain the token, to create the certificate, or to decode/parse the returned
// PEM are reported as refresh errors; a failure to marshal the public key is
// returned unchanged. The named result err is read by the deferred call that
// ends the trace span.
func fetchEphemeralCert(
	ctx context.Context,
	client *sqladmin.Service,
	inst instance.ConnName,
	key *rsa.PrivateKey,
	tp auth.TokenProvider,
) (c tls.Certificate, err error) {
	var end trace.EndSpanFunc
	ctx, end = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.FetchEphemeralCert")
	defer func() { end(err) }()

	// Encode the client's public key as PEM for the request body.
	pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
	if err != nil {
		return tls.Certificate{}, err
	}
	pubPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PUBLIC KEY", Bytes: pubDER})
	req := sqladmin.GenerateEphemeralCertRequest{PublicKey: string(pubPEM)}

	// Attach an access token only when a token provider was supplied.
	var tok *auth.Token
	if tp != nil {
		var tokErr error
		if tok, tokErr = tp.Token(ctx); tokErr != nil {
			return tls.Certificate{}, errtype.NewRefreshError(
				"failed to retrieve Oauth2 token",
				inst.String(),
				tokErr,
			)
		}
		req.AccessToken = tok.Value
	}

	resp, err := retry50x(ctx, func(callCtx context.Context) (*sqladmin.GenerateEphemeralCertResponse, error) {
		call := client.Connect.GenerateEphemeralCert(inst.Project(), inst.Name(), &req)
		return call.Context(callCtx).Do()
	}, exponentialBackoff)
	if err != nil {
		return tls.Certificate{}, errtype.NewRefreshError(
			"create ephemeral cert failed",
			inst.String(),
			err,
		)
	}

	// Decode and parse the signed client certificate.
	block, _ := pem.Decode([]byte(resp.EphemeralCert.Cert))
	if block == nil {
		return tls.Certificate{}, errtype.NewRefreshError(
			"failed to decode valid PEM cert",
			inst.String(),
			nil,
		)
	}
	leaf, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		return tls.Certificate{}, errtype.NewRefreshError(
			fmt.Sprintf("failed to parse as X.509 certificate: %v", err),
			inst.String(),
			nil,
		)
	}

	// The certificate must not outlive the token it was issued with, so use
	// the earlier of the two expirations.
	if tp != nil && tok.Expiry.Before(leaf.NotAfter) {
		leaf.NotAfter = tok.Expiry
	}

	return tls.Certificate{
		Certificate: [][]byte{leaf.Raw},
		PrivateKey:  key,
		Leaf:        leaf,
	}, nil
}
// newAdminAPIClient returns an adminAPIClient that uses svc for SQL Admin API
// calls, key for client certificates, tp as the token provider for IAM DB
// AuthN, l as its logger, and dialerID as the ID of the owning dialer.
func newAdminAPIClient(
	l debug.ContextLogger,
	svc *sqladmin.Service,
	key *rsa.PrivateKey,
	tp auth.TokenProvider,
	dialerID string,
) adminAPIClient {
	var c adminAPIClient
	c.dialerID = dialerID
	c.logger = l
	c.key = key
	c.client = svc
	c.tp = tp
	return c
}
// adminAPIClient manages the SQL Admin API access to instance metadata and to
// ephemeral certificates. Values are created with newAdminAPIClient.
type adminAPIClient struct {
// dialerID is the unique ID of the associated dialer. ConnectionInfo passes
// it to trace.RecordRefreshResult when a refresh finishes.
dialerID string
// logger is the logger supplied to newAdminAPIClient.
logger debug.ContextLogger
// key is used to generate the client certificate
key *rsa.PrivateKey
// client is the SQL Admin API service used for both the metadata and the
// ephemeral certificate requests.
client *sqladmin.Service
// tp is the TokenProvider used for IAM DB AuthN. ConnectionInfo only uses it
// when called with iamAuthNDial set to true.
tp auth.TokenProvider
}
// ConnectionInfo immediately performs a full refresh operation using the Cloud
// SQL Admin API: it fetches the instance metadata and a fresh ephemeral client
// certificate concurrently and combines them with NewConnectionInfo.
//
// When iamAuthNDial is true the client's token provider is used for the
// certificate request, and the instance's database version must pass
// supportsAutoIAMAuthN. If ctx is done before a result arrives, an error
// wrapping ctx.Err() is returned.
//
// The named result err is read by the deferred function, which records the
// refresh result and ends the trace span.
func (c adminAPIClient) ConnectionInfo(
	ctx context.Context, cn instance.ConnName, iamAuthNDial bool,
) (ci ConnectionInfo, err error) {
	var refreshEnd trace.EndSpanFunc
	ctx, refreshEnd = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.RefreshConnection",
		trace.AddInstanceName(cn.String()),
	)
	defer func() {
		go trace.RecordRefreshResult(context.Background(), cn.String(), c.dialerID, err)
		refreshEnd(err)
	}()

	// Both result channels have capacity one, so each worker goroutine can
	// deliver its single result and exit even if this method has already
	// returned.
	type metadataResult struct {
		md  metadata
		err error
	}
	metadataCh := make(chan metadataResult, 1)
	go func() {
		defer close(metadataCh)
		md, mdErr := fetchMetadata(ctx, c.client, cn)
		metadataCh <- metadataResult{md: md, err: mdErr}
	}()

	// The token provider is only handed to the certificate request when IAM
	// authentication was asked for.
	var iamTP auth.TokenProvider
	if iamAuthNDial {
		iamTP = c.tp
	}
	type certResult struct {
		ec  tls.Certificate
		err error
	}
	certCh := make(chan certResult, 1)
	go func() {
		defer close(certCh)
		ec, ecErr := fetchEphemeralCert(ctx, c.client, cn, c.key, iamTP)
		certCh <- certResult{ec: ec, err: ecErr}
	}()

	// Wait for the metadata first; the engine check below depends on it.
	var md metadata
	select {
	case <-ctx.Done():
		return ConnectionInfo{}, fmt.Errorf("refresh failed: %w", ctx.Err())
	case res := <-metadataCh:
		if res.err != nil {
			return ConnectionInfo{}, fmt.Errorf("failed to get instance: %w", res.err)
		}
		md = res.md
	}

	if iamAuthNDial {
		if vErr := supportsAutoIAMAuthN(md.version); vErr != nil {
			return ConnectionInfo{}, vErr
		}
	}

	// Then wait for the ephemeral certificate.
	var clientCert tls.Certificate
	select {
	case <-ctx.Done():
		return ConnectionInfo{}, fmt.Errorf("refresh failed: %w", ctx.Err())
	case res := <-certCh:
		if res.err != nil {
			return ConnectionInfo{}, fmt.Errorf("fetch ephemeral cert failed: %w", res.err)
		}
		clientCert = res.ec
	}

	return NewConnectionInfo(
		cn, md.dnsName, md.serverCAMode, md.version, md.ipAddrs, md.serverCACert, clientCert,
	), nil
}
// supportsAutoIAMAuthN reports whether the database engine identified by
// version supports automatic IAM authentication. It returns nil when version
// starts with "POSTGRES" or "MYSQL" (case-sensitive) and a descriptive error
// for any other value, including the empty string.
func supportsAutoIAMAuthN(version string) error {
	for _, engine := range []string{"POSTGRES", "MYSQL"} {
		if strings.HasPrefix(version, engine) {
			return nil
		}
	}
	return fmt.Errorf("%s does not support Auto IAM DB Authentication", version)
}