internal/mock/alloydb.go (294 lines of code) (raw):
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mock
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/binary"
"encoding/pem"
"fmt"
"math/big"
"net"
"testing"
"time"
"cloud.google.com/go/alloydb/connectors/apiv1alpha/connectorspb"
"google.golang.org/protobuf/proto"
)
// Option configures a FakeAlloyDBInstance
type Option func(*FakeAlloyDBInstance)
// WithPublicIP sets the public IP address to addr.
func WithPublicIP(addr string) Option {
return func(f *FakeAlloyDBInstance) {
f.ipAddrs["PUBLIC"] = addr
}
}
// WithPrivateIP sets the private IP address to addr.
func WithPrivateIP(addr string) Option {
return func(f *FakeAlloyDBInstance) {
f.ipAddrs["PRIVATE"] = addr
}
}
// WithPSC sets the PSC address to addr.
func WithPSC(addr string) Option {
return func(f *FakeAlloyDBInstance) {
f.ipAddrs["PSC"] = addr
}
}
// WithServerName sets the name that server uses to identify itself in the TLS
// handshake.
func WithServerName(name string) Option {
return func(f *FakeAlloyDBInstance) {
f.serverName = name
}
}
// WithCertExpiry sets the expiration time of the fake instance
func WithCertExpiry(expiry time.Time) Option {
return func(f *FakeAlloyDBInstance) {
f.certExpiry = expiry
}
}
// FakeAlloyDBInstance represents the server side proxy.
type FakeAlloyDBInstance struct {
project string
region string
cluster string
name string
// ipAddrs is a map of IP type (PUBLIC or PRIVATE) to IP address.
ipAddrs map[string]string
uid string
serverName string
certExpiry time.Time
rootCACert *x509.Certificate
rootKey *rsa.PrivateKey
intermedCert *x509.Certificate
intermedKey *rsa.PrivateKey
serverCert *x509.Certificate
serverKey *rsa.PrivateKey
}
// String returns the URI of the instance.
func (f FakeAlloyDBInstance) String() string {
return fmt.Sprintf(
"projects/%v/locations/%v/clusters/%v/instances/%v",
f.project,
f.region,
f.cluster,
f.name,
)
}
func mustGenerateKey() *rsa.PrivateKey {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
panic(err)
}
return key
}
var (
rootCAKey = mustGenerateKey()
intermedCAKey = mustGenerateKey()
serverKey = mustGenerateKey()
)
// NewFakeInstance creates a Fake AlloyDB instance.
func NewFakeInstance(proj, reg, clust, name string, opts ...Option) FakeAlloyDBInstance {
f := FakeAlloyDBInstance{
project: proj,
region: reg,
cluster: clust,
name: name,
ipAddrs: map[string]string{"PRIVATE": "127.0.0.1"},
uid: "00000000-0000-0000-0000-000000000000",
serverName: "00000000-0000-0000-0000-000000000000.server.alloydb",
certExpiry: time.Now().Add(24 * time.Hour),
}
for _, o := range opts {
o(&f)
}
rootTemplate := &x509.Certificate{
SerialNumber: &big.Int{},
Subject: pkix.Name{
CommonName: "root.alloydb",
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(0, 0, 1),
IsCA: true,
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
}
// create a self-signed root certificate
signedRoot, err := x509.CreateCertificate(
rand.Reader, rootTemplate, rootTemplate, &rootCAKey.PublicKey, rootCAKey)
if err != nil {
panic(err)
}
rootCert, err := x509.ParseCertificate(signedRoot)
if err != nil {
panic(err)
}
// create an intermediate CA, signed by the root
// This CA signs all client certs.
intermedTemplate := &x509.Certificate{
SerialNumber: &big.Int{},
Subject: pkix.Name{
CommonName: "client.alloydb",
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(0, 0, 1),
IsCA: true,
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
}
signedIntermed, err := x509.CreateCertificate(
rand.Reader, intermedTemplate, rootCert, &intermedCAKey.PublicKey, rootCAKey)
if err != nil {
panic(err)
}
intermedCert, err := x509.ParseCertificate(signedIntermed)
if err != nil {
panic(err)
}
// create a server certificate, signed by the root
// This is what the server side proxy uses.
serverTemplate := &x509.Certificate{
SerialNumber: &big.Int{},
Subject: pkix.Name{
CommonName: f.serverName,
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(0, 0, 1),
IsCA: true,
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)},
}
signedServer, err := x509.CreateCertificate(
rand.Reader, serverTemplate, rootCert, &serverKey.PublicKey, rootCAKey)
if err != nil {
panic(err)
}
serverCert, err := x509.ParseCertificate(signedServer)
if err != nil {
panic(err)
}
// save all TLS certificates for later use.
f.rootCACert = rootCert
f.rootKey = rootCAKey
f.intermedCert = intermedCert
f.intermedKey = intermedCAKey
f.serverCert = serverCert
f.serverKey = serverKey
return f
}
// GeneratePEMCertificateChain produces the certificate chain including an
// ephemeral client certificate.
func (f *FakeAlloyDBInstance) GeneratePEMCertificateChain(
pub *rsa.PublicKey,
) ([]string, error) {
template := &x509.Certificate{
PublicKey: pub,
SerialNumber: &big.Int{},
Issuer: f.intermedCert.Subject,
NotBefore: time.Now(),
NotAfter: f.certExpiry,
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}
cert, err := x509.CreateCertificate(
rand.Reader, template, f.intermedCert,
template.PublicKey, f.intermedKey,
)
if err != nil {
return nil, err
}
certPEM := &bytes.Buffer{}
pem.Encode(certPEM, &pem.Block{Type: "CERTIFICATE", Bytes: cert})
instancePEM := &bytes.Buffer{}
pem.Encode(
instancePEM, &pem.Block{Type: "CERTIFICATE", Bytes: f.intermedCert.Raw},
)
caPEM := &bytes.Buffer{}
pem.Encode(caPEM, &pem.Block{Type: "CERTIFICATE", Bytes: f.rootCACert.Raw})
return []string{certPEM.String(), instancePEM.String(), caPEM.String()}, nil
}
// StartServerProxy starts a fake server proxy and listens on the provided port
// on all interfaces, configured with TLS as specified by the
// FakeAlloyDBInstance. Callers should invoke the returned function to clean up
// all resources.
func StartServerProxy(t *testing.T, inst FakeAlloyDBInstance) func() {
pool := x509.NewCertPool()
pool.AddCert(inst.rootCACert)
tryListen := func(t *testing.T, attempts int) net.Listener {
var (
ln net.Listener
err error
)
for i := 0; i < attempts; i++ {
ln, err = tls.Listen("tcp", ":5433", &tls.Config{
Certificates: []tls.Certificate{
{
Certificate: [][]byte{inst.serverCert.Raw, inst.rootCACert.Raw},
PrivateKey: inst.serverKey,
Leaf: inst.serverCert,
},
},
ServerName: "127.0.0.1",
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: pool,
})
if err != nil {
t.Log("listener failed to start, waiting 100ms")
time.Sleep(500 * time.Millisecond)
continue
}
return ln
}
t.Fatalf("failed to start listener: %v", err)
return nil
}
ln := tryListen(t, 10)
ctx, cancel := context.WithCancel(context.Background())
go func() {
for {
select {
case <-ctx.Done():
return
default:
conn, err := ln.Accept()
if err != nil {
return
}
if err := metadataExchange(conn); err != nil {
conn.Close()
return
}
// Database protocol takes over from here.
conn.Write([]byte(inst.name))
conn.Close()
}
}
}()
return func() {
cancel()
ln.Close()
}
}
// metadataExchange mimics server side behavior in four steps:
//
// 1. Read a big endian uint32 (4 bytes) from the client. This is the number of
// bytes the message consumes. The length does not include the initial four
// bytes.
//
// 2. Read the message from the client using the message length and unmarshal
// it into a MetadataExchangeResponse message.
//
// The real server implementation will then validate the client has connection
// permissions using the provided OAuth2 token based on the auth type. Here in
// the test implementation, the server does nothing.
//
// 3. Prepare a response and write the size of the response as a uint32 (4
// bytes)
//
// 4. Marshal the response to bytes and write those to the client as well.
//
// Subsequent interactions with the test server use the database protocol.
func metadataExchange(conn net.Conn) error {
msgSize := make([]byte, 4)
n, err := conn.Read(msgSize)
if err != nil {
return err
}
if n != 4 {
return fmt.Errorf("read %d bytes, want = 4", n)
}
size := binary.BigEndian.Uint32(msgSize)
buf := make([]byte, size)
n, err = conn.Read(buf)
if err != nil {
return err
}
if n != int(size) {
return fmt.Errorf("read %d bytes, want = %d", n, size)
}
m := &connectorspb.MetadataExchangeRequest{}
err = proto.Unmarshal(buf, m)
if err != nil {
return err
}
resp := &connectorspb.MetadataExchangeResponse{
ResponseCode: connectorspb.MetadataExchangeResponse_OK,
}
data, err := proto.Marshal(resp)
if err != nil {
return err
}
respSize := proto.Size(resp)
buf = make([]byte, 4)
binary.BigEndian.PutUint32(buf, uint32(respSize))
buf = append(buf, data...)
n, err = conn.Write(buf)
if err != nil {
return err
}
if n != len(buf) {
return fmt.Errorf("write %d bytes, want = %d", n, len(buf))
}
return nil
}