internal/mock/cloudsql.go (238 lines of code) (raw):
// Copyright 2021 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// https://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mock
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"io"
"net"
"testing"
"time"
"golang.org/x/oauth2"
sqladmin "google.golang.org/api/sqladmin/v1beta4"
)
// EmptyTokenSource is an Oauth2.TokenSource that returns empty tokens.
type EmptyTokenSource struct{}
// Token provides an empty oauth2.Token.
func (EmptyTokenSource) Token() (*oauth2.Token, error) {
return &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, nil
}
// FakeCSQLInstance represents settings for a specific Cloud SQL instance.
//
// Use NewFakeCSQLInstance to instantiate.
type FakeCSQLInstance struct {
project string
region string
name string
dbVersion string
// ipAddrs is a map of IP type (PUBLIC or PRIVATE) to IP address.
ipAddrs map[string]string
backendType string
// DNSName is the legacy field
// DNSNames supersedes DNSName.
DNSName string
MissingSAN string
DNSNames []*sqladmin.DnsNameMapping
useStandardTLSValidation bool
serverCAMode string
pscEnabled bool
signer SignFunc
clientSigner ClientSignFunc
certExpiry time.Time
// Key is the server's private key
Key *rsa.PrivateKey
// Cert is the server's certificate
Cert *x509.Certificate
// certs holds all of the certificates for this instance
certs *TLSCertificates
}
// String returns the instance connection name for the
// instance.
func (f FakeCSQLInstance) String() string {
return fmt.Sprintf("%v:%v:%v", f.project, f.region, f.name)
}
// serverCACert returns the current server CA cert.
func (f FakeCSQLInstance) serverCACert() ([]byte, error) {
if f.signer != nil {
return f.signer(f.Cert, f.Key)
}
if !f.useStandardTLSValidation {
// legacy server mode, return only the server cert
return toPEMFormat(f.certs.serverCert)
}
return toPEMFormat(f.certs.casServerCertificate, f.certs.serverIntermediateCaCert, f.certs.serverCaCert)
}
// ClientCert creates an ephemeral client certificate signed with the Cloud SQL
// instance's private key. The return value is PEM encoded.
func (f FakeCSQLInstance) ClientCert(pubKey *rsa.PublicKey) ([]byte, error) {
if f.clientSigner != nil {
c, err := f.clientSigner(f.Cert, f.Key, pubKey)
if err != nil {
return c, err
}
return c, nil
}
return f.certs.signWithClientKey(pubKey)
}
// FakeCSQLInstanceOption is a function that configures a FakeCSQLInstance.
type FakeCSQLInstanceOption func(f *FakeCSQLInstance)
// WithPublicIP sets the public IP address to addr.
func WithPublicIP(addr string) FakeCSQLInstanceOption {
return func(f *FakeCSQLInstance) {
f.ipAddrs["PUBLIC"] = addr
}
}
// WithPrivateIP sets the private IP address to addr.
func WithPrivateIP(addr string) FakeCSQLInstanceOption {
return func(f *FakeCSQLInstance) {
f.ipAddrs["PRIVATE"] = addr
}
}
// WithPSC sets the PSC enabled.
func WithPSC(enabled bool) FakeCSQLInstanceOption {
return func(f *FakeCSQLInstance) {
f.pscEnabled = enabled
}
}
// WithDNS sets the DnsName to addr.
func WithDNS(dns string) FakeCSQLInstanceOption {
return func(f *FakeCSQLInstance) {
f.DNSName = dns
}
}
// WithMissingSAN will cause the omit this dns name
// from the server cert, even though it is in the metadata.
func WithMissingSAN(dns string) FakeCSQLInstanceOption {
return func(f *FakeCSQLInstance) {
f.MissingSAN = dns
}
}
// WithDNSMapping adds the DnsNames records
func WithDNSMapping(name, dnsScope, connectionType string) FakeCSQLInstanceOption {
return func(f *FakeCSQLInstance) {
f.DNSNames = append(f.DNSNames,
&sqladmin.DnsNameMapping{
Name: name,
DnsScope: dnsScope,
ConnectionType: connectionType})
}
}
// WithCertExpiry sets the server certificate's expiration to t.
func WithCertExpiry(t time.Time) FakeCSQLInstanceOption {
return func(f *FakeCSQLInstance) {
f.certExpiry = t
}
}
// WithRegion sets the server's region to the provided value.
func WithRegion(region string) FakeCSQLInstanceOption {
return func(f *FakeCSQLInstance) {
f.region = region
}
}
// WithFirstGenBackend sets the server backend type to FIRST_GEN.
func WithFirstGenBackend() FakeCSQLInstanceOption {
return func(f *FakeCSQLInstance) {
f.backendType = "FIRST_GEN"
}
}
// WithEngineVersion sets the "DB Version"
func WithEngineVersion(s string) FakeCSQLInstanceOption {
return func(f *FakeCSQLInstance) {
f.dbVersion = s
}
}
// SignFunc is a function that signs the certificate using the provided key. The
// result should be PEM-encoded.
type SignFunc = func(*x509.Certificate, *rsa.PrivateKey) ([]byte, error)
// WithCertSigner configures the signing function used to generate a signed
// certificate.
func WithCertSigner(s SignFunc) FakeCSQLInstanceOption {
return func(f *FakeCSQLInstance) {
f.signer = s
}
}
// ClientSignFunc is a function that produces a certificate signed using the
// provided certificate, using the server's private key and the client's public
// key. The result should be PEM-encoded.
type ClientSignFunc = func(*x509.Certificate, *rsa.PrivateKey, *rsa.PublicKey) ([]byte, error)
// WithClientCertSigner configures the signing function used to generate a
// certificate signed with the client's public key.
func WithClientCertSigner(s ClientSignFunc) FakeCSQLInstanceOption {
return func(f *FakeCSQLInstance) {
f.clientSigner = s
}
}
// WithNoIPAddrs configures a Fake Cloud SQL instance to have no IP
// addresses.
func WithNoIPAddrs() FakeCSQLInstanceOption {
return func(f *FakeCSQLInstance) {
f.ipAddrs = map[string]string{}
}
}
// WithServerCAMode sets the ServerCaMode of the instance.
func WithServerCAMode(serverCAMode string) FakeCSQLInstanceOption {
return func(f *FakeCSQLInstance) {
f.serverCAMode = serverCAMode
}
}
// NewFakeCSQLInstance returns a CloudSQLInst object for configuring mocks.
func NewFakeCSQLInstance(project, region, name string, opts ...FakeCSQLInstanceOption) FakeCSQLInstance {
f := FakeCSQLInstance{
project: project,
region: region,
name: name,
ipAddrs: map[string]string{"PUBLIC": "0.0.0.0"},
DNSName: "",
dbVersion: "POSTGRES_12", // default of no particular importance
backendType: "SECOND_GEN",
}
for _, o := range opts {
o(&f)
}
sanNames := make([]string, 0, 5)
if f.DNSName != "" && f.DNSName != f.MissingSAN {
sanNames = append(sanNames, f.DNSName)
}
for _, dnm := range f.DNSNames {
if dnm.Name != f.MissingSAN {
sanNames = append(sanNames, dnm.Name)
}
}
if len(sanNames) > 0 {
f.useStandardTLSValidation = true
}
certs := NewTLSCertificates(project, name, sanNames, f.certExpiry)
f.Key = certs.serverKey
f.Cert = certs.serverCert
f.certs = certs
return f
}
// SelfSign produces a PEM encoded certificate that is self-signed.
func SelfSign(c *x509.Certificate, k *rsa.PrivateKey) ([]byte, error) {
certBytes, err := x509.CreateCertificate(rand.Reader, c, c, &k.PublicKey, k)
if err != nil {
return nil, err
}
certPEM := new(bytes.Buffer)
err = pem.Encode(certPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
})
if err != nil {
return nil, err
}
return certPEM.Bytes(), nil
}
// GenerateCertWithCommonName produces a certificate signed by the Fake Cloud
// SQL instance's CA with the specified common name cn.
func GenerateCertWithCommonName(i FakeCSQLInstance, cn string) []byte {
return i.certs.generateServerCertWithCn(cn).Raw
}
// StartServerProxy starts a fake server proxy and listens on the provided port
// on all interfaces, configured with TLS as specified by the FakeCSQLInstance.
// Callers should invoke the returned function to clean up all resources.
func StartServerProxy(t *testing.T, i FakeCSQLInstance) func() {
ln, err := tls.Listen("tcp", ":3307", &tls.Config{
Certificates: i.certs.serverChain(i.useStandardTLSValidation),
ClientCAs: i.certs.clientCAPool(),
ClientAuth: tls.RequireAndVerifyClientCert,
})
if err != nil {
t.Fatalf("failed to start listener: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
go func() {
for {
select {
case <-ctx.Done():
return
default:
conn, aErr := ln.Accept()
if opErr, ok := aErr.(net.Error); ok {
if opErr.Timeout() {
continue
}
return
}
if aErr == io.EOF {
return
}
if aErr != nil {
t.Logf("Fake server accept error: %v", aErr)
return
}
_, wErr := conn.Write([]byte(i.name))
if wErr != nil {
t.Logf("Fake server write error: %v", wErr)
}
_ = conn.Close()
}
}
}()
return func() {
cancel()
_ = ln.Close()
}
}
// RotateCA rotates all CA certificates and keys.
func RotateCA(inst FakeCSQLInstance) {
inst.certs.rotateCA()
}
// RotateClientCA rotates only client CA certificates and keys.
func RotateClientCA(inst FakeCSQLInstance) {
inst.certs.rotateClientCA()
}