internal/remoting/pool/single_pool.go (119 lines of code) (raw):

/* * Copyright (c) 2023 Alibaba Group Holding Ltd. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package pool import ( "context" "net" "sync" "github.com/alibaba/schedulerx-worker-go/logger" ) var ( _ ConnPool = &singleConnPool{} connPool ConnPool once sync.Once lock sync.RWMutex ) func InitConnPool(pool ConnPool) { once.Do(func() { connPool = pool }) } // GetConnPool first executes InitConnPool and then calls it, otherwise it returns nil func GetConnPool() ConnPool { lock.RLock() defer lock.RUnlock() return connPool } type ConnPool interface { Get(ctx context.Context) (net.Conn, error) ReconnectTrigger() chan struct{} } type singleConnPool struct { lock sync.RWMutex conn net.Conn dialer func() (net.Conn, error) reconnectSignalCh chan struct{} options *Options } type Options struct { postDialer func(context.Context, net.Conn) error addrChangedSignalCh chan struct{} } type Option func(*Options) func WithPostDialer(postDialer func(context.Context, net.Conn) error) Option { return func(o *Options) { o.postDialer = postDialer } } func WithAddrChangedSignalCh(addrChangedSignalCh chan struct{}) Option { return func(o *Options) { o.addrChangedSignalCh = addrChangedSignalCh } } func NewSingleConnPool(ctx context.Context, dialer func() (net.Conn, error), opts ...Option) ConnPool { options := new(Options) for _, opt := range opts { opt(options) } pool := &singleConnPool{ dialer: dialer, reconnectSignalCh: make(chan struct{}, 3), options: options, } // network is broken or heartbeat timeout go pool.onReconnectTrigger(ctx) // server addr changed if options.addrChangedSignalCh != nil { go pool.onAddrChanged(ctx) } return pool } func (p *singleConnPool) newConn(ctx context.Context) (net.Conn, error) { p.clean() conn, err := p.dialer() if err != nil { return nil, err } p.lock.Lock() defer p.lock.Unlock() p.conn = conn // handshake success means connection is truly established if postDialer := p.options.postDialer; postDialer != nil { if err := postDialer(ctx, conn); err != nil { return nil, err } } return conn, nil } func (p *singleConnPool) Get(ctx context.Context) (net.Conn, error) { if !p.isConnExisted() { // create a new connection if there is no existing connection return p.newConn(ctx) } p.lock.RLock() defer p.lock.RUnlock() return p.conn, nil } func (p *singleConnPool) ReconnectTrigger() chan struct{} { return p.reconnectSignalCh } func (p *singleConnPool) isConnExisted() bool { p.lock.RLock() defer p.lock.RUnlock() return p.conn != nil } func (p *singleConnPool) onReconnectTrigger(ctx context.Context) { for range p.reconnectSignalCh { if _, err := p.newConn(ctx); err != nil { logger.Errorf("Reconnect server failed after connection isn't available, err=%s", err.Error()) } } } func (p *singleConnPool) onAddrChanged(ctx context.Context) { for range p.options.addrChangedSignalCh { if _, err := p.newConn(ctx); err != nil { logger.Errorf("Reconnect server failed after addr if changed, err=%s", err.Error()) } } } func (p *singleConnPool) clean() { p.lock.Lock() defer p.lock.Unlock() if p.conn != nil { _ = p.conn.Close() p.conn = nil } }