internal/alloydb/refresh.go (265 lines of code) (raw):
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package alloydb
import (
"bytes"
"context"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"strings"
"time"
alloydbadmin "cloud.google.com/go/alloydb/apiv1alpha"
"cloud.google.com/go/alloydb/apiv1alpha/alloydbpb"
"cloud.google.com/go/alloydbconn/errtype"
"cloud.google.com/go/alloydbconn/internal/tel"
"google.golang.org/protobuf/types/known/durationpb"
)
// Address types reported by fetchInstanceInfo. They are used as the keys of
// the address maps handed out for an instance.
const (
	// PrivateIP identifies the instance's private IP address.
	PrivateIP = "PRIVATE"
	// PublicIP identifies the instance's public IP address.
	PublicIP = "PUBLIC"
	// PSC identifies the instance's Private Service Connect DNS name.
	PSC = "PSC"
)
// instanceInfo is the subset of an instance's connection metadata that
// fetchInstanceInfo extracts from the Admin API response.
type instanceInfo struct {
	// ipAddrs maps an address type (PrivateIP, PublicIP or PSC) to the
	// corresponding address. For PSC the value is a DNS name. Only types the
	// instance reports are present.
	ipAddrs map[string]string
	// uid is the instance UID reported by the Admin API.
	uid string
}
// fetchInstanceInfo calls the AlloyDB Admin API GetConnectionInfo method for
// inst. It returns the addresses that can be used to reach the instance, keyed
// by PrivateIP, PublicIP and PSC, together with the instance UID.
//
// An Admin API failure is reported as a refresh error. A response that
// carries no usable address is reported as a config error.
func fetchInstanceInfo(
	ctx context.Context, cl *alloydbadmin.AlloyDBAdminClient, inst InstanceURI,
) (i instanceInfo, err error) {
	var end tel.EndSpanFunc
	ctx, end = tel.StartSpan(ctx, "cloud.google.com/go/alloydbconn/internal.FetchMetadata")
	defer func() { end(err) }()

	parent := fmt.Sprintf(
		"projects/%s/locations/%s/clusters/%s/instances/%s",
		inst.project, inst.region, inst.cluster, inst.name,
	)
	resp, err := cl.GetConnectionInfo(ctx, &alloydbpb.GetConnectionInfoRequest{Parent: parent})
	if err != nil {
		return instanceInfo{}, errtype.NewRefreshError(
			"failed to get instance metadata", inst.String(), err,
		)
	}

	// Each address type is recorded only when the response actually carries
	// a value for it.
	candidates := [...]struct{ kind, addr string }{
		{PrivateIP, resp.GetIpAddress()},
		{PublicIP, resp.GetPublicIpAddress()},
		{PSC, resp.GetPscDnsName()},
	}
	addrs := make(map[string]string, len(candidates))
	for _, cand := range candidates {
		if cand.addr == "" {
			continue
		}
		addrs[cand.kind] = cand.addr
	}

	if len(addrs) == 0 {
		return instanceInfo{}, errtype.NewConfigError(
			"cannot connect to instance - it has no supported IP addresses",
			inst.String(),
		)
	}
	return instanceInfo{ipAddrs: addrs, uid: resp.GetInstanceUid()}, nil
}
// errInvalidPEM is returned by parseCert when its input contains no PEM block.
var errInvalidPEM = errors.New("certificate is not a valid PEM")

// parseCert decodes the first PEM block found in cert and parses its payload
// as a DER-encoded X.509 certificate. Any data after the first block is
// ignored.
func parseCert(cert string) (*x509.Certificate, error) {
	block, _ := pem.Decode([]byte(cert))
	if block != nil {
		return x509.ParseCertificate(block.Bytes)
	}
	return nil, errInvalidPEM
}
// clientCertificate holds the parsed client certificate material built from a
// GenerateClientCertificate response.
type clientCertificate struct {
	// certChain is the client certificate chained with the intermediate
	// cert(s) and CA cert, paired with the client's private key.
	certChain tls.Certificate
	// caCert is the CA certificate of the cluster.
	caCert *x509.Certificate
	// expiry is the expiration (NotAfter) of the client certificate.
	expiry time.Time
}
// fetchClientCertificate uses the AlloyDB Admin API's
// GenerateClientCertificate method to create a signed TLS certificate that is
// authorized to connect via the AlloyDB instance's server-side proxy. The
// certificate is requested with a one-hour duration.
//
// Only the public half of key is sent to the API. The private key is paired
// locally with the returned certificate chain. UseMetadataExchange is
// requested unless disableMetadataExchange is set. Every failure is returned
// as a refresh error naming inst.
func fetchClientCertificate(
	ctx context.Context,
	cl *alloydbadmin.AlloyDBAdminClient,
	inst InstanceURI,
	key *rsa.PrivateKey,
	disableMetadataExchange bool,
) (cc *clientCertificate, err error) {
	var end tel.EndSpanFunc
	ctx, end = tel.StartSpan(ctx, "cloud.google.com/go/alloydbconn/internal.FetchEphemeralCert")
	defer func() { end(err) }()

	if key == nil {
		// Both key.PublicKey and MarshalPKCS1PrivateKey below would panic on a
		// nil key; surface it as an error instead.
		return nil, errtype.NewRefreshError(
			"create ephemeral cert failed",
			inst.String(),
			errors.New("no RSA private key configured"),
		)
	}

	pubPEM := &bytes.Buffer{}
	err = pem.Encode(pubPEM, &pem.Block{
		Type:  "RSA PUBLIC KEY",
		Bytes: x509.MarshalPKCS1PublicKey(&key.PublicKey),
	})
	if err != nil {
		// Report this like every other failure path: a refresh error that
		// names the instance.
		return nil, errtype.NewRefreshError(
			"create ephemeral cert failed",
			inst.String(),
			err,
		)
	}

	req := &alloydbpb.GenerateClientCertificateRequest{
		Parent: fmt.Sprintf(
			"projects/%s/locations/%s/clusters/%s", inst.project, inst.region, inst.cluster,
		),
		PublicKey:           pubPEM.String(),
		CertDuration:        durationpb.New(time.Hour),
		UseMetadataExchange: !disableMetadataExchange,
	}
	resp, err := cl.GenerateClientCertificate(ctx, req)
	if err != nil {
		return nil, errtype.NewRefreshError(
			"create ephemeral cert failed",
			inst.String(),
			err,
		)
	}

	keyPEM := pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(key),
	})
	return newClientCertificate(
		inst, keyPEM, resp.PemCertificateChain, resp.CaCert,
	)
}
// newClientCertificate assembles a clientCertificate from a PEM-encoded
// private key, the PEM certificate chain returned by the Admin API (client
// certificate first) and the PEM-encoded cluster CA certificate.
//
// Every failure is returned as a refresh error naming inst.
func newClientCertificate(
	inst InstanceURI,
	keyPEM []byte,
	chain []string,
	caCertRaw string,
) (cc *clientCertificate, err error) {
	certPEMBlock := []byte(strings.Join(chain, "\n"))
	// X509KeyPair fails unless the chain holds at least one CERTIFICATE block
	// whose public key matches keyPEM, so cert.Certificate[0] exists below.
	cert, err := tls.X509KeyPair(certPEMBlock, keyPEM)
	if err != nil {
		return nil, errtype.NewRefreshError(
			"create ephemeral cert failed",
			inst.String(),
			err,
		)
	}

	caCertPEMBlock, _ := pem.Decode([]byte(caCertRaw))
	if caCertPEMBlock == nil {
		return nil, errtype.NewRefreshError(
			"create ephemeral cert failed",
			inst.String(),
			errors.New("no PEM data found in the ca cert"),
		)
	}
	caCert, err := x509.ParseCertificate(caCertPEMBlock.Bytes)
	if err != nil {
		return nil, errtype.NewRefreshError(
			"create ephemeral cert failed",
			inst.String(),
			err,
		)
	}

	// Take the leaf from cert.Certificate[0], the certificate X509KeyPair
	// matched against keyPEM and the one presented during the handshake. Do
	// not re-decode chain[0]. This keeps Leaf consistent with Certificate[0],
	// and it avoids a second parse when the runtime has already populated
	// Leaf (Go 1.23+ does so by default).
	leaf := cert.Leaf
	if leaf == nil {
		leaf, err = x509.ParseCertificate(cert.Certificate[0])
		if err != nil {
			return nil, errtype.NewRefreshError(
				"create ephemeral cert failed",
				inst.String(),
				err,
			)
		}
	}
	// Caching the parsed leaf spares the TLS stack from parsing it again on
	// each connection.
	cert.Leaf = leaf

	return &clientCertificate{
		certChain: cert,
		caCert:    caCert,
		expiry:    leaf.NotAfter,
	}, nil
}
// newAdminAPIClient returns an adminAPIClient that holds the given Admin API
// client, the RSA key used when requesting client certificates, the ID of
// the owning dialer and the metadata-exchange opt-out flag.
func newAdminAPIClient(
	client *alloydbadmin.AlloyDBAdminClient,
	key *rsa.PrivateKey,
	dialerID string,
	disableMetadataExchange bool,
) adminAPIClient {
	var ac adminAPIClient
	ac.client = client
	ac.key = key
	ac.dialerID = dialerID
	ac.disableMetadataExchange = disableMetadataExchange
	return ac
}
// adminAPIClient manages the AlloyDB Admin API access to instance metadata and
// to ephemeral certificates.
type adminAPIClient struct {
	// client provides access to the AlloyDB Admin API.
	client *alloydbadmin.AlloyDBAdminClient
	// key is used to request client certificates. Its public half is sent to
	// the API, and the key is paired with the returned certificate chain.
	key *rsa.PrivateKey
	// dialerID is the unique ID of the associated dialer. It is included
	// when refresh results are recorded.
	dialerID string
	// disableMetadataExchange is a temporary setting that eases the migration
	// to a world where the metadata exchange is required. When true,
	// certificates are requested with UseMetadataExchange set to false.
	disableMetadataExchange bool
}
// ConnectionInfo holds all the data necessary to connect to an instance.
type ConnectionInfo struct {
	// Instance identifies the instance this information was fetched for.
	Instance InstanceURI
	// IPAddrs maps an address type (PublicIP, PrivateIP or PSC) to the
	// corresponding address. Only types the instance reports are present,
	// and for PSC the value is a DNS name.
	IPAddrs map[string]string
	// ClientCert is the ephemeral client certificate chain together with its
	// private key.
	ClientCert tls.Certificate
	// RootCAs is a pool containing only the cluster's CA certificate.
	// NOTE(review): presumably used as the trust root when verifying the
	// server; confirm at the dial site.
	RootCAs *x509.CertPool
	// Expiration is when the client certificate expires (its NotAfter).
	Expiration time.Time
}
// connectionInfo refreshes the data needed to connect to instance i. It asks
// the AlloyDB Admin API for the instance metadata and for an ephemeral client
// certificate concurrently, then combines both into a ConnectionInfo.
//
// The refresh runs inside a tracing span, and its outcome is reported through
// tel.RecordRefreshResult together with the dialer ID. An error is returned if
// either request fails, or if ctx is done before a result arrives. In that
// case the zero ConnectionInfo is returned.
func (c adminAPIClient) connectionInfo(
	ctx context.Context, i InstanceURI,
) (res ConnectionInfo, err error) {
	var refreshEnd tel.EndSpanFunc
	ctx, refreshEnd = tel.StartSpan(ctx, "cloud.google.com/go/alloydbconn/internal.RefreshConnection",
		tel.AddInstanceName(i.String()),
	)
	defer func() {
		// Recording runs on its own goroutine with a background context,
		// independent of ctx.
		go tel.RecordRefreshResult(
			context.Background(), i.String(), c.dialerID, err,
		)
		refreshEnd(err)
	}()

	// Both requests share a context that is canceled when this function
	// returns. If we bail out early because one request failed (or ctx is
	// done), the other in-flight request is canceled. Otherwise it would run
	// to completion for a result nobody will read.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Both result channels are buffered (capacity 1) so each fetch goroutine
	// can deliver its result and exit even after we stop listening.
	type mdRes struct {
		info instanceInfo
		err  error
	}
	mdCh := make(chan mdRes, 1)
	go func() {
		defer close(mdCh)
		info, fetchErr := fetchInstanceInfo(ctx, c.client, i)
		mdCh <- mdRes{info: info, err: fetchErr}
	}()

	type certRes struct {
		cc  *clientCertificate
		err error
	}
	certCh := make(chan certRes, 1)
	go func() {
		defer close(certCh)
		cert, fetchErr := fetchClientCertificate(ctx, c.client, i, c.key, c.disableMetadataExchange)
		certCh <- certRes{cc: cert, err: fetchErr}
	}()

	var info instanceInfo
	select {
	case r := <-mdCh:
		if r.err != nil {
			return ConnectionInfo{}, fmt.Errorf(
				"failed to get instance IP address: %w", r.err,
			)
		}
		info = r.info
	case <-ctx.Done():
		return ConnectionInfo{}, fmt.Errorf("refresh failed: %w", ctx.Err())
	}

	var cc *clientCertificate
	select {
	case r := <-certCh:
		if r.err != nil {
			return ConnectionInfo{}, fmt.Errorf(
				"fetch ephemeral cert failed: %w", r.err,
			)
		}
		cc = r.cc
	case <-ctx.Done():
		return ConnectionInfo{}, fmt.Errorf("refresh failed: %w", ctx.Err())
	}

	caCerts := x509.NewCertPool()
	caCerts.AddCert(cc.caCert)
	return ConnectionInfo{
		Instance:   i,
		IPAddrs:    info.ipAddrs,
		ClientCert: cc.certChain,
		RootCAs:    caCerts,
		Expiration: cc.expiry,
	}, nil
}