internal/proxy/proxy.go (669 lines of code) (raw):
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"context"
"fmt"
"io"
"net"
"os"
"path"
"regexp"
"strings"
"sync"
"sync/atomic"
"time"
"cloud.google.com/go/cloudsqlconn"
"github.com/GoogleCloudPlatform/cloud-sql-proxy/v2/cloudsql"
"github.com/GoogleCloudPlatform/cloud-sql-proxy/v2/internal/gcloud"
"golang.org/x/oauth2"
"google.golang.org/api/impersonate"
"google.golang.org/api/option"
"google.golang.org/api/sqladmin/v1"
)
var (
// Instance connection name is the format <PROJECT>:<REGION>:<INSTANCE>
// Additionally, we have to support legacy "domain-scoped" projects (e.g. "google.com:PROJECT")
connNameRegex = regexp.MustCompile("([^:]+(:[^:]+)?):([^:]+):([^:]+)")
)
// connName represents the "instance connection name", in the format
// "project:region:name". Use the "parseConnName" method to initialize this
// struct.
type connName struct {
project string
region string
name string
}
func (c *connName) String() string {
return fmt.Sprintf("%s:%s:%s", c.project, c.region, c.name)
}
// parseConnName initializes a new connName struct.
func parseConnName(cn string) (connName, error) {
b := []byte(cn)
m := connNameRegex.FindSubmatch(b)
if m == nil {
return connName{}, fmt.Errorf(
"invalid instance connection name, want = PROJECT:REGION:INSTANCE, got = %v",
cn,
)
}
c := connName{
project: string(m[1]),
region: string(m[3]),
name: string(m[4]),
}
return c, nil
}
// InstanceConnConfig holds the configuration for an individual instance
// connection.
type InstanceConnConfig struct {
// Name is the instance connection name.
Name string
// Addr is the address on which to bind a listener for the instance.
Addr string
// Port is the port on which to bind a listener for the instance.
Port int
// UnixSocket is the directory where a Unix socket will be created,
// connected to the Cloud SQL instance. The full path to the socket will be
// UnixSocket + os.PathSeparator + Name. If set, takes precedence over Addr
// and Port.
UnixSocket string
// UnixSocketPath is the path where a Unix socket will be created,
// connected to the Cloud SQL instance. The full path to the socket will be
// UnixSocketPath. If this is a Postgres database, the proxy will ensure that
// the last path element is `.s.PGSQL.5432`, appending this path element if
// necessary. If set, UnixSocketPath takes precedence over UnixSocket, Addr
// and Port.
UnixSocketPath string
// IAMAuthN enables automatic IAM DB Authentication for the instance.
// MySQL and Postgres only. If it is nil, the value was not specified.
IAMAuthN *bool
// PrivateIP tells the proxy to attempt to connect to the db instance's
// private IP address instead of the public IP address
PrivateIP *bool
// PSC tells the proxy to attempt to connect to the db instance's
// private service connect endpoint
PSC *bool
}
// Config contains all the configuration provided by the caller.
type Config struct {
// Filepath is the path to a configuration file.
Filepath string
// UserAgent is the user agent to use when connecting to the cloudsql instance
UserAgent string
// Token is the Bearer token used for authorization.
Token string
// LoginToken is the Bearer token used for Auto IAM AuthN. Used only in
// conjunction with Token.
LoginToken string
// CredentialsFile is the path to a service account key.
CredentialsFile string
// CredentialsJSON is a JSON representation of the service account key.
CredentialsJSON string
// GcloudAuth set whether to use gcloud's config helper to retrieve a
// token for authentication.
GcloudAuth bool
// Addr is the address on which to bind all instances.
Addr string
// Port is the initial port to bind to. Subsequent instances bind to
// increments from this value.
Port int
// APIEndpointURL is the URL of the Google Cloud SQL Admin API. When left blank,
// the proxy will use the main public api: https://sqladmin.googleapis.com/
APIEndpointURL string
// UniverseDomain is the universe domain for the TPC environment. When left
// blank, the proxy will use the Google Default Universe (GDU): googleapis.com
UniverseDomain string
// UnixSocket is the directory where Unix sockets will be created,
// connected to any Instances. If set, takes precedence over Addr and Port.
UnixSocket string
// FUSEDir enables a file system in user space at the provided path that
// connects to the requested instance only when a client requests it.
FUSEDir string
// FUSETempDir sets the temporary directory where the FUSE mount will place
// Unix domain sockets connected to Cloud SQL instances. The temp directory
// is not accessed directly.
FUSETempDir string
// IAMAuthN enables automatic IAM DB Authentication for all instances.
// MySQL and Postgres only.
IAMAuthN bool
// MaxConnections are the maximum number of connections the Client may
// establish to the Cloud SQL server side proxy before refusing additional
// connections. A zero-value indicates no limit.
MaxConnections uint64
// WaitBeforeClose sets the duration to wait after receiving a shutdown signal
// but before closing the process. Not setting this field means to initiate
// the shutdown process immediately.
WaitBeforeClose time.Duration
// WaitOnClose sets the duration to wait for connections to close before
// shutting down. Not setting this field means to close immediately
// regardless of any open connections.
WaitOnClose time.Duration
// PrivateIP enables connections via the database server's private IP address
// for all instances.
PrivateIP bool
// PSC enables connections via the database server's private service connect
// endpoint for all instances
PSC bool
// AutoIP supports a legacy behavior where the Proxy will connect to
// the first IP address returned from the SQL ADmin API response. This
// setting should be avoided and used only to support legacy Proxy
// users.
AutoIP bool
// LazyRefresh configures the Go Connector to retrieve connection info
// lazily and as-needed. Otherwise, no background refresh cycle runs. This
// setting is useful in environments where the CPU may be throttled outside
// of a request context, e.g., Cloud Run.
LazyRefresh bool
// Instances are configuration for individual instances. Instance
// configuration takes precedence over global configuration.
Instances []InstanceConnConfig
// QuotaProject is the ID of the Google Cloud project to use to track
// API request quotas.
QuotaProject string
// ImpersonationChain is a comma separated list of one or more service
// accounts. The first entry in the chain is the impersonation target. Any
// additional service accounts after the target are delegates. The
// roles/iam.serviceAccountTokenCreator must be configured for each account
// that will be impersonated.
ImpersonationChain string
// StructuredLogs sets all output to use JSON in the LogEntry format.
// See https://cloud.google.com/logging/docs/reference/v2/rest/v2/LogEntry
StructuredLogs bool
// Quiet controls whether only error messages are logged.
Quiet bool
// TelemetryProject enables sending metrics and traces to the specified project.
TelemetryProject string
// TelemetryPrefix sets a prefix for all emitted metrics.
TelemetryPrefix string
// TelemetryTracingSampleRate sets the rate at which traces are
// samples. A higher value means fewer traces.
TelemetryTracingSampleRate int
// ExitZeroOnSigterm exits with 0 exit code when Sigterm received
ExitZeroOnSigterm bool
// DisableTraces disables tracing when TelemetryProject is set.
DisableTraces bool
// DisableMetrics disables metrics when TelemetryProject is set.
DisableMetrics bool
// Prometheus enables a Prometheus endpoint served at the address and
// port specified by HTTPAddress and HTTPPort.
Prometheus bool
// PrometheusNamespace configures the namespace under which metrics are written.
PrometheusNamespace string
// HealthCheck enables a health check server. It's address and port are
// specified by HTTPAddress and HTTPPort.
HealthCheck bool
// HTTPAddress sets the address for the health check and prometheus server.
HTTPAddress string
// HTTPPort sets the port for the health check and prometheus server.
HTTPPort string
// AdminPort configures the port for the localhost-only admin server.
AdminPort string
// Debug enables a debug handler on localhost.
Debug bool
// QuitQuitQuit enables a handler that will shut the Proxy down upon
// receiving a GET or POST request.
QuitQuitQuit bool
// DebugLogs enables debug level logging.
DebugLogs bool
// OtherUserAgents is a list of space separate user agents that will be
// appended to the default user agent.
OtherUserAgents string
// RunConnectionTest determines whether the Proxy should attempt a connection
// to all specified instances to verify the network path is valid.
RunConnectionTest bool
}
// dialOptions interprets appropriate dial options for a particular instance
// configuration
func dialOptions(c Config, i InstanceConnConfig) []cloudsqlconn.DialOption {
var opts []cloudsqlconn.DialOption
if i.IAMAuthN != nil {
opts = append(opts, cloudsqlconn.WithDialIAMAuthN(*i.IAMAuthN))
}
switch {
// If private IP is enabled at the instance level, or private IP is enabled globally
// add the option.
case i.PrivateIP != nil && *i.PrivateIP || i.PrivateIP == nil && c.PrivateIP:
opts = append(opts, cloudsqlconn.WithPrivateIP())
// If PSC is enabled at the instance level, or PSC is enabled globally
// add the option.
case i.PSC != nil && *i.PSC || i.PSC == nil && c.PSC:
opts = append(opts, cloudsqlconn.WithPSC())
case c.AutoIP:
opts = append(opts, cloudsqlconn.WithAutoIP())
default:
// assume public IP by default
}
return opts
}
func parseImpersonationChain(chain string) (string, []string) {
accts := strings.Split(chain, ",")
target := accts[0]
// Assign delegates if the chain is more than one account. Delegation
// goes from last back towards target, e.g., With sa1,sa2,sa3, sa3
// delegates to sa2, which impersonates the target sa1.
var delegates []string
if l := len(accts); l > 1 {
for i := l - 1; i > 0; i-- {
delegates = append(delegates, accts[i])
}
}
return target, delegates
}
const iamLoginScope = "https://www.googleapis.com/auth/sqlservice.login"
func credentialsOpt(c Config, l cloudsql.Logger) (cloudsqlconn.Option, error) {
// If service account impersonation is configured, set up an impersonated
// credentials token source.
if c.ImpersonationChain != "" {
var iopts []option.ClientOption
switch {
case c.Token != "":
l.Infof("Impersonating service account with OAuth2 token")
iopts = append(iopts, option.WithTokenSource(
oauth2.StaticTokenSource(&oauth2.Token{AccessToken: c.Token}),
))
case c.CredentialsFile != "":
l.Infof("Impersonating service account with the credentials file at %q", c.CredentialsFile)
iopts = append(iopts, option.WithCredentialsFile(c.CredentialsFile))
case c.CredentialsJSON != "":
l.Infof("Impersonating service account with JSON credentials environment variable")
iopts = append(iopts, option.WithCredentialsJSON([]byte(c.CredentialsJSON)))
case c.GcloudAuth:
l.Infof("Impersonating service account with gcloud user credentials")
ts, err := gcloud.TokenSource()
if err != nil {
return nil, err
}
iopts = append(iopts, option.WithTokenSource(ts))
default:
l.Infof("Impersonating service account with Application Default Credentials")
}
target, delegates := parseImpersonationChain(c.ImpersonationChain)
ts, err := impersonate.CredentialsTokenSource(
context.Background(),
impersonate.CredentialsConfig{
TargetPrincipal: target,
Delegates: delegates,
Scopes: []string{sqladmin.SqlserviceAdminScope},
},
iopts...,
)
if err != nil {
return nil, err
}
if c.IAMAuthN {
iamLoginTS, err := impersonate.CredentialsTokenSource(
context.Background(),
impersonate.CredentialsConfig{
TargetPrincipal: target,
Delegates: delegates,
Scopes: []string{iamLoginScope},
},
iopts...,
)
if err != nil {
return nil, err
}
return cloudsqlconn.WithIAMAuthNTokenSources(ts, iamLoginTS), nil
}
return cloudsqlconn.WithTokenSource(ts), nil
}
// Otherwise, configure credentials as usual.
var opt cloudsqlconn.Option
switch {
case c.Token != "":
l.Infof("Authorizing with OAuth2 token")
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: c.Token})
if c.IAMAuthN {
lts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: c.LoginToken})
opt = cloudsqlconn.WithIAMAuthNTokenSources(ts, lts)
} else {
opt = cloudsqlconn.WithTokenSource(ts)
}
case c.CredentialsFile != "":
l.Infof("Authorizing with the credentials file at %q", c.CredentialsFile)
opt = cloudsqlconn.WithCredentialsFile(c.CredentialsFile)
case c.CredentialsJSON != "":
l.Infof("Authorizing with JSON credentials environment variable")
opt = cloudsqlconn.WithCredentialsJSON([]byte(c.CredentialsJSON))
case c.GcloudAuth:
l.Infof("Authorizing with gcloud user credentials")
ts, err := gcloud.TokenSource()
if err != nil {
return nil, err
}
opt = cloudsqlconn.WithTokenSource(ts)
default:
l.Infof("Authorizing with Application Default Credentials")
// Return no-op options to avoid having to handle nil in caller code
opt = cloudsqlconn.WithOptions()
}
return opt, nil
}
// DialerOptions builds appropriate list of options from the Config
// values for use by cloudsqlconn.NewClient()
func (c *Config) DialerOptions(l cloudsql.Logger) ([]cloudsqlconn.Option, error) {
opts := []cloudsqlconn.Option{
cloudsqlconn.WithDNSResolver(),
cloudsqlconn.WithUserAgent(c.UserAgent),
}
co, err := credentialsOpt(*c, l)
if err != nil {
return nil, err
}
opts = append(opts, co)
if c.DebugLogs {
// nolint:staticcheck
opts = append(opts, cloudsqlconn.WithDebugLogger(l))
}
if c.APIEndpointURL != "" {
opts = append(opts, cloudsqlconn.WithAdminAPIEndpoint(c.APIEndpointURL))
}
if c.UniverseDomain != "" {
opts = append(opts, cloudsqlconn.WithUniverseDomain(c.UniverseDomain))
}
if c.IAMAuthN {
opts = append(opts, cloudsqlconn.WithIAMAuthN())
}
if c.QuotaProject != "" {
opts = append(opts, cloudsqlconn.WithQuotaProject(c.QuotaProject))
}
if c.LazyRefresh {
opts = append(opts, cloudsqlconn.WithLazyRefresh())
}
return opts, nil
}
type portConfig struct {
global int
postgres int
mysql int
sqlserver int
}
func newPortConfig(global int) *portConfig {
return &portConfig{
global: global,
postgres: 5432,
mysql: 3306,
sqlserver: 1433,
}
}
// nextPort returns the next port based on the initial global value.
func (c *portConfig) nextPort() int {
p := c.global
c.global++
return p
}
func (c *portConfig) nextDBPort(version string) int {
switch {
case strings.HasPrefix(version, "MYSQL"):
p := c.mysql
c.mysql++
return p
case strings.HasPrefix(version, "POSTGRES"):
p := c.postgres
c.postgres++
return p
case strings.HasPrefix(version, "SQLSERVER"):
p := c.sqlserver
c.sqlserver++
return p
default:
// Unexpected engine version, use global port setting instead.
return c.nextPort()
}
}
// Client proxies connections from a local client to the remote server side
// proxy for multiple Cloud SQL instances.
type Client struct {
// connCount tracks the number of all open connections from the Client to
// all Cloud SQL instances.
connCount uint64
// conf is the configuration used to initialize the Client.
conf *Config
dialer cloudsql.Dialer
// mnts is a list of all mounted sockets for this client
mnts []*socketMount
logger cloudsql.Logger
connRefuseNotify func()
fuseMount
}
// NewClient completes the initial setup required to get the proxy to a "steady"
// state.
func NewClient(ctx context.Context, d cloudsql.Dialer, l cloudsql.Logger, conf *Config, connRefuseNotify func()) (*Client, error) {
// Check if the caller has configured a dialer.
// Otherwise, initialize a new one.
if d == nil {
dialerOpts, err := conf.DialerOptions(l)
if err != nil {
return nil, fmt.Errorf("error initializing dialer: %v", err)
}
d, err = cloudsqlconn.NewDialer(ctx, dialerOpts...)
if err != nil {
return nil, fmt.Errorf("error initializing dialer: %v", err)
}
}
c := &Client{
logger: l,
dialer: d,
connRefuseNotify: connRefuseNotify,
conf: conf,
}
if conf.FUSEDir != "" {
return configureFUSE(c, conf)
}
for _, inst := range conf.Instances {
// Initiate refresh operation and warm the cache.
go func(name string) { _, _ = d.EngineVersion(ctx, name) }(inst.Name)
}
var mnts []*socketMount
pc := newPortConfig(conf.Port)
for _, inst := range conf.Instances {
m, err := c.newSocketMount(ctx, conf, pc, inst)
if err != nil {
for _, m := range mnts {
mErr := m.Close()
if mErr != nil {
l.Errorf("failed to close mount: %v", mErr)
}
}
return nil, fmt.Errorf("[%v] Unable to mount socket: %v", inst.Name, err)
}
l.Infof("[%s] Listening on %s", inst.Name, m.Addr())
mnts = append(mnts, m)
}
c.mnts = mnts
return c, nil
}
// CheckConnections dials each registered instance and reports the number of
// connections checked and any errors that may have occurred.
func (c *Client) CheckConnections(ctx context.Context) (int, error) {
var (
wg sync.WaitGroup
errCh = make(chan error, len(c.mnts))
mnts = c.mnts
)
if c.fuseDir != "" {
mnts = c.fuseMounts()
}
for _, mnt := range mnts {
wg.Add(1)
go func(m *socketMount) {
defer wg.Done()
conn, err := c.dialer.Dial(ctx, m.inst, m.dialOpts...)
if err != nil {
errCh <- err
return
}
cErr := conn.Close()
if cErr != nil {
c.logger.Errorf(
"connection check failed to close connection for %v: %v",
m.inst, cErr,
)
}
}(mnt)
}
wg.Wait()
var mErr MultiErr
for i := 0; i < len(mnts); i++ {
select {
case err := <-errCh:
mErr = append(mErr, err)
default:
continue
}
}
mLen := len(mnts)
if len(mErr) > 0 {
return mLen, mErr
}
return mLen, nil
}
// ConnCount returns the number of open connections and the maximum allowed
// connections. Returns 0 when the maximum allowed connections have not been set.
func (c *Client) ConnCount() (uint64, uint64) {
return atomic.LoadUint64(&c.connCount), c.conf.MaxConnections
}
// Serve starts proxying connections for all configured instances using the
// associated socket.
func (c *Client) Serve(ctx context.Context, notify func()) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
if c.fuseDir != "" {
return c.serveFuse(ctx, notify)
}
if c.conf.RunConnectionTest {
c.logger.Infof("Connection test started")
if _, err := c.CheckConnections(ctx); err != nil {
c.logger.Errorf("Connection test failed")
return err
}
c.logger.Infof("Connection test passed")
}
exitCh := make(chan error)
for _, m := range c.mnts {
go func(mnt *socketMount) {
err := c.serveSocketMount(ctx, mnt)
if err != nil {
select {
// Best effort attempt to send error.
// If this send fails, it means the reading goroutine has
// already pulled a value out of the channel and is no longer
// reading any more values. In other words, we report only the
// first error.
case exitCh <- err:
default:
return
}
}
}(m)
}
notify()
return <-exitCh
}
// MultiErr is a group of errors wrapped into one.
type MultiErr []error
// Error returns a single string representing one or more errors.
func (m MultiErr) Error() string {
l := len(m)
if l == 1 {
return m[0].Error()
}
var errs []string
for _, e := range m {
errs = append(errs, e.Error())
}
return strings.Join(errs, ", ")
}
// Close triggers the proxyClient to shut down.
func (c *Client) Close() error {
mnts := c.mnts
var mErr MultiErr
// If FUSE is enabled, unmount it and save a reference to any existing
// socket mounts.
if c.fuseDir != "" {
if err := c.unmountFUSE(); err != nil {
mErr = append(mErr, err)
}
mnts = c.fuseMounts()
}
// Close the dialer to prevent any additional refreshes.
cErr := c.dialer.Close()
if cErr != nil {
mErr = append(mErr, cErr)
}
// Start a timer for clean shutdown (where all connections are closed).
// While the timer runs, additional connections will be accepted.
timeout := time.After(c.conf.WaitOnClose)
t := time.NewTicker(100 * time.Millisecond)
defer t.Stop()
for {
select {
case <-t.C:
if atomic.LoadUint64(&c.connCount) > 0 {
continue
}
case <-timeout:
}
break
}
// Close all open socket listeners. Time to complete shutdown.
for _, m := range mnts {
err := m.Close()
if err != nil {
mErr = append(mErr, err)
}
}
if c.fuseDir != "" {
c.waitForFUSEMounts()
}
// Verify that all connections are closed.
open := atomic.LoadUint64(&c.connCount)
if c.conf.WaitOnClose > 0 && open > 0 {
openErr := fmt.Errorf(
"%d connection(s) still open after waiting %v", open, c.conf.WaitOnClose)
mErr = append(mErr, openErr)
}
if len(mErr) > 0 {
return mErr
}
return nil
}
// serveSocketMount persistently listens to the socketMounts listener and proxies connections to a
// given Cloud SQL instance.
func (c *Client) serveSocketMount(_ context.Context, s *socketMount) error {
for {
cConn, err := s.Accept()
if err != nil {
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
c.logger.Errorf("[%s] Error accepting connection: %v", s.inst, err)
// For transient errors, wait a small amount of time to see if it resolves itself
time.Sleep(10 * time.Millisecond)
continue
}
return err
}
// handle the connection in a separate goroutine
go func() {
c.logger.Infof("[%s] Accepted connection from %s", s.inst, cConn.RemoteAddr())
// A client has established a connection to the local socket. Before
// we initiate a connection to the Cloud SQL backend, increment the
// connection counter. If the total number of connections exceeds
// the maximum, refuse to connect and close the client connection.
count := atomic.AddUint64(&c.connCount, 1)
defer atomic.AddUint64(&c.connCount, ^uint64(0))
if c.conf.MaxConnections > 0 && count > c.conf.MaxConnections {
c.logger.Infof("max connections (%v) exceeded, refusing new connection", c.conf.MaxConnections)
if c.connRefuseNotify != nil {
go c.connRefuseNotify()
}
_ = cConn.Close()
return
}
// give a max of 30 seconds to connect to the instance
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
sConn, err := c.dialer.Dial(ctx, s.inst, s.dialOpts...)
if err != nil {
c.logger.Errorf("[%s] failed to connect to instance: %v", s.inst, err)
_ = cConn.Close()
return
}
c.proxyConn(s.inst, cConn, sConn)
}()
}
}
// socketMount is a tcp/unix socket that listens for a Cloud SQL instance.
type socketMount struct {
inst string
listener net.Listener
dialOpts []cloudsqlconn.DialOption
}
func networkType(conf *Config, inst InstanceConnConfig) string {
if (conf.UnixSocket == "" && inst.UnixSocket == "" && inst.UnixSocketPath == "") ||
(inst.Addr != "" || inst.Port != 0) {
return "tcp"
}
return "unix"
}
func (c *Client) newSocketMount(ctx context.Context, conf *Config, pc *portConfig, inst InstanceConnConfig) (*socketMount, error) {
var (
// network is one of "tcp" or "unix"
network string
// address is either a TCP host port, or a Unix socket
address string
err error
)
// IF
// a global Unix socket directory is NOT set AND
// an instance-level Unix socket is NOT set
// (e.g., I didn't set a Unix socket globally or for this instance)
// OR
// an instance-level TCP address or port IS set
// (e.g., I'm overriding any global settings to use TCP for this
// instance)
// use a TCP listener.
// Otherwise, use a Unix socket.
if networkType(conf, inst) == "tcp" {
network = "tcp"
a := conf.Addr
if inst.Addr != "" {
a = inst.Addr
}
var np int
switch {
case inst.Port != 0:
np = inst.Port
case conf.Port != 0:
np = pc.nextPort()
default:
version, err := c.dialer.EngineVersion(ctx, inst.Name)
// Exit if the port is not specified for inactive instance
if err != nil {
c.logger.Errorf("[%v] could not resolve instance version: %v", inst.Name, err)
return nil, err
}
np = pc.nextDBPort(version)
}
address = net.JoinHostPort(a, fmt.Sprint(np))
} else {
network = "unix"
version, err := c.dialer.EngineVersion(ctx, inst.Name)
if err != nil {
c.logger.Errorf("[%v] could not resolve instance version: %v", inst.Name, err)
return nil, err
}
address, err = newUnixSocketMount(inst, conf.UnixSocket, strings.HasPrefix(version, "POSTGRES"))
if err != nil {
c.logger.Errorf("[%v] could not mount unix socket %q: %v", inst.Name, conf.UnixSocket, err)
return nil, err
}
}
lc := net.ListenConfig{KeepAlive: 30 * time.Second}
ln, err := lc.Listen(ctx, network, address)
if err != nil {
c.logger.Errorf("[%v] could not listen to address %v: %v", inst.Name, address, err)
return nil, err
}
// Change file permissions to allow access for user, group, and other.
if network == "unix" {
// Best effort. If this call fails, group and other won't have write
// access.
_ = os.Chmod(address, 0777)
}
opts := dialOptions(*conf, inst)
m := &socketMount{inst: inst.Name, dialOpts: opts, listener: ln}
return m, nil
}
// newUnixSocketMount parses the configuration and returns the path to the unix
// socket, or an error if that path is not valid.
func newUnixSocketMount(inst InstanceConnConfig, unixSocketDir string, postgres bool) (string, error) {
var (
// the path to the unix socket
address string
// the parent directory of the unix socket
dir string
)
if inst.UnixSocketPath != "" {
// When UnixSocketPath is set
address = inst.UnixSocketPath
// If UnixSocketPath ends .s.PGSQL.5432, remove it for consistency
if postgres && path.Base(address) == ".s.PGSQL.5432" {
address = path.Dir(address)
}
dir = path.Dir(address)
} else {
// When UnixSocket is set
dir = unixSocketDir
if dir == "" {
dir = inst.UnixSocket
}
address = UnixAddress(dir, inst.Name)
}
// if base directory does not exist, fail
if _, err := os.Stat(dir); err != nil {
return "", err
}
// When setting up a listener for Postgres, create address as a
// directory, and use the Postgres-specific socket name
// .s.PGSQL.5432.
if postgres {
// Make the directory only if it hasn't already been created.
if _, err := os.Stat(address); err != nil {
if err = os.Mkdir(address, 0777); err != nil {
return "", err
}
}
address = UnixAddress(address, ".s.PGSQL.5432")
}
return address, nil
}
func (s *socketMount) Addr() net.Addr {
return s.listener.Addr()
}
func (s *socketMount) Accept() (net.Conn, error) {
return s.listener.Accept()
}
// Close stops the mount from listening for any more connections
func (s *socketMount) Close() error {
return s.listener.Close()
}
// proxyConn sets up a bidirectional copy between two open connections
func (c *Client) proxyConn(inst string, client, server net.Conn) {
// only allow the first side to give an error for terminating a connection
var o sync.Once
cleanup := func(errDesc string, isErr bool) {
o.Do(func() {
_ = client.Close()
_ = server.Close()
if isErr {
c.logger.Errorf(errDesc)
} else {
c.logger.Infof(errDesc)
}
})
}
// copy bytes from client to server
go func() {
buf := make([]byte, 8*1024) // 8kb
for {
n, cErr := client.Read(buf)
var sErr error
if n > 0 {
_, sErr = server.Write(buf[:n])
}
switch {
case cErr == io.EOF:
cleanup(fmt.Sprintf("[%s] client closed the connection", inst), false)
return
case cErr != nil:
cleanup(fmt.Sprintf("[%s] connection aborted - error reading from client: %v", inst, cErr), true)
return
case sErr == io.EOF:
cleanup(fmt.Sprintf("[%s] instance closed the connection", inst), false)
return
case sErr != nil:
cleanup(fmt.Sprintf("[%s] connection aborted - error writing to instance: %v", inst, sErr), true)
return
}
}
}()
// copy bytes from server to client
buf := make([]byte, 8*1024) // 8kb
for {
n, sErr := server.Read(buf)
var cErr error
if n > 0 {
_, cErr = client.Write(buf[:n])
}
switch {
case sErr == io.EOF:
cleanup(fmt.Sprintf("[%s] instance closed the connection", inst), false)
return
case sErr != nil:
cleanup(fmt.Sprintf("[%s] connection aborted - error reading from instance: %v", inst, sErr), true)
return
case cErr == io.EOF:
cleanup(fmt.Sprintf("[%s] client closed the connection", inst), false)
return
case cErr != nil:
cleanup(fmt.Sprintf("[%s] connection aborted - error writing to client: %v", inst, cErr), true)
return
}
}
}