connectionpool.go

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40
 * Copyright (c) 2012, The Gocql authors,
 * provided under the BSD-3-Clause License.
 * See the NOTICE file distributed with this work for additional information.
 */

package gocql

import (
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"io/ioutil"
	"math/rand"
	"net"
	"sync"
	"sync/atomic"
	"time"
)

// SetHosts is the interface to implement to receive the host information.
type SetHosts interface {
	SetHosts(hosts []*HostInfo)
}

// SetPartitioner is the interface to implement to receive the partitioner value.
type SetPartitioner interface {
	SetPartitioner(partitioner string)
}

func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
	//  Config.InsecureSkipVerify | EnableHostVerification | Result
	//  Config is nil             | true                   | verify host
	//  Config is nil             | false                  | do not verify host
	//  false                     | false                  | verify host
	//  true                      | false                  | do not verify host
	//  false                     | true                   | verify host
	//  true                      | true                   | verify host
	var tlsConfig *tls.Config
	if sslOpts.Config == nil {
		tlsConfig = &tls.Config{
			InsecureSkipVerify: !sslOpts.EnableHostVerification,
		}
	} else {
		// use clone to avoid race.
		tlsConfig = sslOpts.Config.Clone()
	}

	if tlsConfig.InsecureSkipVerify && sslOpts.EnableHostVerification {
		tlsConfig.InsecureSkipVerify = false
	}

	// ca cert is optional
	if sslOpts.CaPath != "" {
		if tlsConfig.RootCAs == nil {
			tlsConfig.RootCAs = x509.NewCertPool()
		}

		pem, err := ioutil.ReadFile(sslOpts.CaPath)
		if err != nil {
			return nil, fmt.Errorf("connectionpool: unable to open CA certs: %v", err)
		}

		if !tlsConfig.RootCAs.AppendCertsFromPEM(pem) {
			return nil, errors.New("connectionpool: failed parsing CA certs")
		}
	}

	if sslOpts.CertPath != "" || sslOpts.KeyPath != "" {
		mycert, err := tls.LoadX509KeyPair(sslOpts.CertPath, sslOpts.KeyPath)
		if err != nil {
			return nil, fmt.Errorf("connectionpool: unable to load X509 key pair: %v", err)
		}
		tlsConfig.Certificates = append(tlsConfig.Certificates, mycert)
	}

	return tlsConfig, nil
}
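
// Illustrative sketch, not part of the driver API: shows how an SslOptions
// value flows through setupTLSConfig. The certificate paths are hypothetical
// placeholders, and the helper name is invented for illustration; applications
// normally set ClusterConfig.SslOpts and let connConfig (below) call
// setupTLSConfig for them.
func exampleSetupTLSConfig() (*tls.Config, error) {
	opts := &SslOptions{
		CaPath:                 "/path/to/ca.pem",   // hypothetical path
		CertPath:               "/path/to/cert.pem", // hypothetical path
		KeyPath:                "/path/to/key.pem",  // hypothetical path
		EnableHostVerification: true,                // verify the server hostname
	}
	return setupTLSConfig(opts)
}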
// policyConnPool manages one hostConnPool per host, keyed by host ID.
type policyConnPool struct {
	session *Session

	port     int
	numConns int
	keyspace string

	mu            sync.RWMutex
	hostConnPools map[string]*hostConnPool
}

// connConfig derives the per-connection configuration from the cluster config,
// building a host dialer (and TLS config, if requested) when none is provided.
func connConfig(cfg *ClusterConfig) (*ConnConfig, error) {
	var (
		err        error
		hostDialer HostDialer
	)
	hostDialer = cfg.HostDialer
	if hostDialer == nil {
		var tlsConfig *tls.Config

		// TODO(zariel): move tls config setup into session init.
		if cfg.SslOpts != nil {
			tlsConfig, err = setupTLSConfig(cfg.SslOpts)
			if err != nil {
				return nil, err
			}
		}

		dialer := cfg.Dialer
		if dialer == nil {
			d := &net.Dialer{
				Timeout: cfg.ConnectTimeout,
			}
			if cfg.SocketKeepalive > 0 {
				d.KeepAlive = cfg.SocketKeepalive
			}
			dialer = d
		}

		hostDialer = &defaultHostDialer{
			dialer:    dialer,
			tlsConfig: tlsConfig,
		}
	}

	return &ConnConfig{
		ProtoVersion:   cfg.ProtoVersion,
		CQLVersion:     cfg.CQLVersion,
		Timeout:        cfg.Timeout,
		WriteTimeout:   cfg.WriteTimeout,
		ConnectTimeout: cfg.ConnectTimeout,
		Dialer:         cfg.Dialer,
		HostDialer:     hostDialer,
		Compressor:     cfg.Compressor,
		Authenticator:  cfg.Authenticator,
		AuthProvider:   cfg.AuthProvider,
		Keepalive:      cfg.SocketKeepalive,
		Logger:         cfg.logger(),
	}, nil
}

func newPolicyConnPool(session *Session) *policyConnPool {
	// create the pool
	pool := &policyConnPool{
		session:       session,
		port:          session.cfg.Port,
		numConns:      session.cfg.NumConns,
		keyspace:      session.cfg.Keyspace,
		hostConnPools: map[string]*hostConnPool{},
	}

	return pool
}

// SetHosts reconciles the per-host pools with the given host list: pools are
// created for new up hosts and closed for hosts that have disappeared.
func (p *policyConnPool) SetHosts(hosts []*HostInfo) {
	p.mu.Lock()
	defer p.mu.Unlock()

	toRemove := make(map[string]struct{})
	for hostID := range p.hostConnPools {
		toRemove[hostID] = struct{}{}
	}

	pools := make(chan *hostConnPool)
	createCount := 0
	for _, host := range hosts {
		if !host.IsUp() {
			// don't create a connection pool for a down host
			continue
		}
		hostID := host.HostID()
		if _, exists := p.hostConnPools[hostID]; exists {
			// still have this host, so don't remove it
			delete(toRemove, hostID)
			continue
		}

		createCount++
		go func(host *HostInfo) {
			// create a connection pool for the host
			pools <- newHostConnPool(
				p.session,
				host,
				p.port,
				p.numConns,
				p.keyspace,
			)
		}(host)
	}

	// add created pools
	for createCount > 0 {
		pool := <-pools
		createCount--
		if pool.Size() > 0 {
			// add the pool only if it has connections available
			p.hostConnPools[pool.host.HostID()] = pool
		}
	}

	for addr := range toRemove {
		pool := p.hostConnPools[addr]
		delete(p.hostConnPools, addr)
		go pool.Close()
	}
}

func (p *policyConnPool) Size() int {
	p.mu.RLock()
	count := 0
	for _, pool := range p.hostConnPools {
		count += pool.Size()
	}
	p.mu.RUnlock()

	return count
}

func (p *policyConnPool) getPool(host *HostInfo) (pool *hostConnPool, ok bool) {
	hostID := host.HostID()
	p.mu.RLock()
	pool, ok = p.hostConnPools[hostID]
	p.mu.RUnlock()
	return
}

func (p *policyConnPool) getPoolByHostID(hostID string) (pool *hostConnPool, ok bool) {
	p.mu.RLock()
	pool, ok = p.hostConnPools[hostID]
	p.mu.RUnlock()
	return
}

func (p *policyConnPool) Close() {
	p.mu.Lock()
	defer p.mu.Unlock()

	// close the pools
	for addr, pool := range p.hostConnPools {
		delete(p.hostConnPools, addr)
		pool.Close()
	}
}

func (p *policyConnPool) addHost(host *HostInfo) {
	hostID := host.HostID()
	p.mu.Lock()
	pool, ok := p.hostConnPools[hostID]
	if !ok {
		pool = newHostConnPool(
			p.session,
			host,
			host.Port(), // TODO: if port == 0 use pool.port?
			p.numConns,
			p.keyspace,
		)

		p.hostConnPools[hostID] = pool
	}
	p.mu.Unlock()

	pool.fill()
}

func (p *policyConnPool) removeHost(hostID string) {
	p.mu.Lock()
	pool, ok := p.hostConnPools[hostID]
	if !ok {
		p.mu.Unlock()
		return
	}

	delete(p.hostConnPools, hostID)
	p.mu.Unlock()

	go pool.Close()
}
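
// Illustrative sketch, not part of the driver API: the typical lookup path the
// session follows when executing a query, going from a host to its per-host
// pool and then to a connection chosen by Pick (defined below). The helper
// name is hypothetical and error handling is elided.
func examplePickConn(p *policyConnPool, host *HostInfo) *Conn {
	pool, ok := p.getPool(host)
	if !ok {
		// no pool for this host (it is down or not yet filled)
		return nil
	}
	return pool.Pick()
}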
// hostConnPool is a connection pool for a single host.
// Pick selects the connection with the most available streams.
type hostConnPool struct {
	session  *Session
	host     *HostInfo
	port     int
	size     int
	keyspace string
	// protection for conns, closed, filling
	mu      sync.RWMutex
	conns   []*Conn
	closed  bool
	filling bool

	pos    uint32
	logger StdLogger
}

func (h *hostConnPool) String() string {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return fmt.Sprintf("[filling=%v closed=%v conns=%v size=%v host=%v]",
		h.filling, h.closed, len(h.conns), h.size, h.host)
}

func newHostConnPool(session *Session, host *HostInfo, port, size int, keyspace string) *hostConnPool {
	pool := &hostConnPool{
		session:  session,
		host:     host,
		port:     port,
		size:     size,
		keyspace: keyspace,
		conns:    make([]*Conn, 0, size),
		filling:  false,
		closed:   false,
		logger:   session.logger,
	}

	// the pool is not filled or connected
	return pool
}

// Pick a connection from this connection pool for the given query.
func (pool *hostConnPool) Pick() *Conn {
	pool.mu.RLock()
	defer pool.mu.RUnlock()

	if pool.closed {
		return nil
	}

	size := len(pool.conns)
	if size < pool.size {
		// try to fill the pool
		go pool.fill()

		if size == 0 {
			return nil
		}
	}

	pos := int(atomic.AddUint32(&pool.pos, 1) - 1)

	var (
		leastBusyConn    *Conn
		streamsAvailable int
	)

	// find the conn which has the most available streams; this is racy
	for i := 0; i < size; i++ {
		conn := pool.conns[(pos+i)%size]
		if streams := conn.AvailableStreams(); streams > streamsAvailable {
			leastBusyConn = conn
			streamsAvailable = streams
		}
	}

	return leastBusyConn
}

// Size returns the number of connections currently active in the pool
func (pool *hostConnPool) Size() int {
	pool.mu.RLock()
	defer pool.mu.RUnlock()

	return len(pool.conns)
}

// Close the connection pool
func (pool *hostConnPool) Close() {
	pool.mu.Lock()

	if pool.closed {
		pool.mu.Unlock()
		return
	}
	pool.closed = true

	// ensure we don't try to reacquire the lock in handleError
	// TODO: improve this as the following can happen
	// 1) we have locked pool.mu write lock
	// 2) conn.Close calls conn.closeWithError(nil)
	// 3) conn.closeWithError calls conn.Close() which returns an error
	// 4) conn.closeWithError calls pool.HandleError with the error from conn.Close
	// 5) pool.HandleError tries to lock pool.mu
	// deadlock

	// empty the pool
	conns := pool.conns
	pool.conns = nil

	pool.mu.Unlock()

	// close the connections
	for _, conn := range conns {
		conn.Close()
	}
}
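
// Filling overview (derived from the code below): fill is triggered from Pick
// and from policyConnPool.addHost. It sets the filling flag so only one
// goroutine fills at a time, dials the first connection synchronously so an
// unreachable host is detected quickly, and dials the remaining connections
// asynchronously via connectMany. fillingStopped clears the flag and, if no
// connections could be established, reports the host as down.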
// Fill the connection pool
func (pool *hostConnPool) fill() {
	pool.mu.RLock()
	// avoid filling a closed pool, or concurrent filling
	if pool.closed || pool.filling {
		pool.mu.RUnlock()
		return
	}

	// determine the filling work to be done
	startCount := len(pool.conns)
	fillCount := pool.size - startCount

	// avoid filling a full (or overfull) pool
	if fillCount <= 0 {
		pool.mu.RUnlock()
		return
	}

	// switch from read to write lock
	pool.mu.RUnlock()
	pool.mu.Lock()

	// double check everything since the lock was released
	startCount = len(pool.conns)
	fillCount = pool.size - startCount
	if pool.closed || pool.filling || fillCount <= 0 {
		// looks like another goroutine already beat this
		// goroutine to the filling
		pool.mu.Unlock()
		return
	}

	// ok fill the pool
	pool.filling = true

	// allow others to access the pool while filling
	pool.mu.Unlock()
	// only this goroutine should make calls to fill/empty the pool at this
	// point until after this routine or its subordinates call
	// fillingStopped

	// fill only the first connection synchronously
	if startCount == 0 {
		err := pool.connect()
		pool.logConnectErr(err)

		if err != nil {
			// probably unreachable host
			pool.fillingStopped(err)
			return
		}
		// notify the session that this node is connected
		go pool.session.handleNodeConnected(pool.host)

		// filled one
		fillCount--
	}

	// fill the rest of the pool asynchronously
	go func() {
		err := pool.connectMany(fillCount)

		// mark the end of filling
		pool.fillingStopped(err)

		if err == nil && startCount > 0 {
			// notify the session that this node is connected again
			go pool.session.handleNodeConnected(pool.host)
		}
	}()
}

func (pool *hostConnPool) logConnectErr(err error) {
	if opErr, ok := err.(*net.OpError); ok && (opErr.Op == "dial" || opErr.Op == "read") {
		// connection refused
		// these are typical during a node outage so avoid log spam.
		if gocqlDebug {
			pool.logger.Printf("gocql: unable to dial %q: %v\n", pool.host, err)
		}
	} else if err != nil {
		// unexpected error
		pool.logger.Printf("error: failed to connect to %q due to error: %v", pool.host, err)
	}
}

// transition back to a not-filling state.
func (pool *hostConnPool) fillingStopped(err error) {
	if err != nil {
		if gocqlDebug {
			pool.logger.Printf("gocql: filling stopped %q: %v\n", pool.host.ConnectAddress(), err)
		}
		// wait for some time to avoid back-to-back filling
		// this provides some time between failed attempts
		// to fill the pool for the host to recover
		time.Sleep(time.Duration(rand.Int31n(100)+31) * time.Millisecond)
	}

	pool.mu.Lock()
	pool.filling = false
	count := len(pool.conns)
	host := pool.host
	port := pool.port
	pool.mu.Unlock()

	// if we errored and the size is now zero, make sure the host is marked as down
	// see https://github.com/apache/cassandra-gocql-driver/issues/1614
	if gocqlDebug {
		pool.logger.Printf("gocql: conns of pool after stopped %q: %v\n", host.ConnectAddress(), count)
	}
	if err != nil && count == 0 {
		if pool.session.cfg.ConvictionPolicy.AddFailure(err, host) {
			pool.session.handleNodeDown(host.ConnectAddress(), port)
		}
	}
}

// connectMany creates new connections concurrently.
func (pool *hostConnPool) connectMany(count int) error {
	if count == 0 {
		return nil
	}
	var (
		wg         sync.WaitGroup
		mu         sync.Mutex
		connectErr error
	)
	wg.Add(count)
	for i := 0; i < count; i++ {
		go func() {
			defer wg.Done()
			err := pool.connect()
			pool.logConnectErr(err)
			if err != nil {
				mu.Lock()
				connectErr = err
				mu.Unlock()
			}
		}()
	}
	// wait for all connection attempts to complete
	wg.Wait()
	return connectErr
}
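
// Illustrative sketch, not part of the driver API: how connect (below) consumes
// a ReconnectionPolicy. Only the two methods used in this file are assumed:
// GetMaxRetries bounds the number of attempts and GetInterval(i) spaces
// attempt i from the next one. The helper name is hypothetical.
func exampleRetrySchedule(policy ReconnectionPolicy) []time.Duration {
	intervals := make([]time.Duration, 0, policy.GetMaxRetries())
	for i := 0; i < policy.GetMaxRetries(); i++ {
		intervals = append(intervals, policy.GetInterval(i))
	}
	return intervals
}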
// create a new connection to the host and add it to the pool
func (pool *hostConnPool) connect() (err error) {
	// TODO: provide a more robust connection retry mechanism, we should also
	// be able to detect hosts that come up by trying to connect to downed ones.

	// try to connect
	var conn *Conn
	reconnectionPolicy := pool.session.cfg.ReconnectionPolicy
	for i := 0; i < reconnectionPolicy.GetMaxRetries(); i++ {
		conn, err = pool.session.connect(pool.session.ctx, pool.host, pool)
		if err == nil {
			break
		}
		if opErr, isOpErr := err.(*net.OpError); isOpErr {
			// if the error is not a temporary error (ex: network unreachable) don't
			// retry
			if !opErr.Temporary() {
				break
			}
		}
		if gocqlDebug {
			pool.logger.Printf("gocql: connection failed %q: %v, reconnecting with %T\n",
				pool.host.ConnectAddress(), err, reconnectionPolicy)
		}
		time.Sleep(reconnectionPolicy.GetInterval(i))
	}
	if err != nil {
		return err
	}

	if pool.keyspace != "" {
		// set the keyspace
		if err = conn.UseKeyspace(pool.keyspace); err != nil {
			conn.Close()
			return err
		}
	}

	// add the Conn to the pool
	pool.mu.Lock()
	defer pool.mu.Unlock()

	if pool.closed {
		conn.Close()
		return nil
	}

	pool.conns = append(pool.conns, conn)

	return nil
}

// handle any error from a Conn
func (pool *hostConnPool) HandleError(conn *Conn, err error, closed bool) {
	if !closed {
		// still an open connection, so continue using it
		return
	}

	// TODO: track the number of errors per host and detect when a host is dead,
	// then also have something which can detect when a host comes back.

	pool.mu.Lock()
	defer pool.mu.Unlock()

	if pool.closed {
		// pool closed
		return
	}

	if gocqlDebug {
		pool.logger.Printf("gocql: pool connection error %q: %v\n", conn.addr, err)
	}

	// find the connection index
	for i, candidate := range pool.conns {
		if candidate == conn {
			// remove the connection, not preserving order
			pool.conns[i], pool.conns = pool.conns[len(pool.conns)-1], pool.conns[:len(pool.conns)-1]

			// lost a connection, so fill the pool
			go pool.fill()
			break
		}
	}
}
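
// Recovery cycle (derived from the code above): when a Conn closes with an
// error it calls HandleError, which drops the connection from the pool and
// kicks off fill; Pick also refills on demand when the pool is below its
// target size. If filling cannot establish any connection, fillingStopped
// consults the ConvictionPolicy and may mark the host down, after which
// SetHosts or addHost rebuild the pool when the host is seen as up again.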