pulsar/internal/connection.go (910 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package internal import ( "context" "crypto/tls" "crypto/x509" "errors" "fmt" "net" "net/url" "os" "sync" "sync/atomic" "time" "github.com/apache/pulsar-client-go/pulsar/auth" "google.golang.org/protobuf/proto" pb "github.com/apache/pulsar-client-go/pulsar/internal/pulsar_proto" "github.com/apache/pulsar-client-go/pulsar/log" ) const ( PulsarProtocolVersion = int32(pb.ProtocolVersion_v20) ) type TLSOptions struct { KeyFile string CertFile string TrustCertsFilePath string AllowInsecureConnection bool ValidateHostname bool ServerName string CipherSuites []uint16 MinVersion uint16 MaxVersion uint16 TLSConfig *tls.Config } var ( errConnectionClosed = errors.New("connection closed") errUnableRegisterListener = errors.New("unable register listener when con closed") errUnableAddConsumeHandler = errors.New("unable add consumer handler when con closed") ) // ConnectionListener is a user of a connection (eg. a producer or // a consumer) that can register itself to get notified // when the connection is closed. type ConnectionListener interface { // ReceivedSendReceipt receive and process the return value of the send command. ReceivedSendReceipt(response *pb.CommandSendReceipt) // ConnectionClosed close the TCP connection. ConnectionClosed(closeProducer *pb.CommandCloseProducer) // SetRedirectedClusterURI set the redirected cluster URI for lookups SetRedirectedClusterURI(redirectedClusterURI string) } // Connection is a interface of client cnx. type Connection interface { SendRequest(requestID uint64, req *pb.BaseCommand, callback func(*pb.BaseCommand, error)) SendRequestNoWait(req *pb.BaseCommand) error WriteData(ctx context.Context, data Buffer) RegisterListener(id uint64, listener ConnectionListener) error UnregisterListener(id uint64) AddConsumeHandler(id uint64, handler ConsumerHandler) error DeleteConsumeHandler(id uint64) ID() string GetMaxMessageSize() int32 Close() WaitForClose() <-chan struct{} IsProxied() bool } type ConsumerHandler interface { MessageReceived(response *pb.CommandMessage, headersAndPayload Buffer) error ActiveConsumerChanged(isActive bool) // ConnectionClosed close the TCP connection. ConnectionClosed(closeConsumer *pb.CommandCloseConsumer) // SetRedirectedClusterURI set the redirected cluster URI for lookups SetRedirectedClusterURI(redirectedClusterURI string) } type connectionState int32 const ( connectionInit = iota connectionReady connectionClosed ) func (s connectionState) String() string { switch s { case connectionInit: return "Initializing" case connectionReady: return "Ready" case connectionClosed: return "Closed" default: return "Unknown" } } type request struct { id *uint64 cmd *pb.BaseCommand callback func(command *pb.BaseCommand, err error) } type dataRequest struct { ctx context.Context data Buffer } type connection struct { started int32 connectionTimeout time.Duration closeOnce sync.Once // mu protects the fields below against concurrency accesses. mu sync.RWMutex state atomic.Int32 cnx net.Conn listeners map[uint64]ConnectionListener consumerHandlers map[uint64]ConsumerHandler logicalAddr *url.URL physicalAddr *url.URL writeBufferLock sync.Mutex writeBuffer Buffer reader *connectionReader lastDataReceivedLock sync.Mutex lastDataReceivedTime time.Time log log.Logger incomingRequestsWG sync.WaitGroup incomingRequestsCh chan *request closeCh chan struct{} readyCh chan struct{} writeRequestsCh chan *dataRequest pendingLock sync.Mutex pendingReqs map[uint64]*request tlsOptions *TLSOptions auth auth.Provider maxMessageSize int32 metrics *Metrics keepAliveInterval time.Duration lastActive time.Time description string } // connectionOptions defines configurations for creating connection. type connectionOptions struct { logicalAddr *url.URL physicalAddr *url.URL tls *TLSOptions connectionTimeout time.Duration auth auth.Provider logger log.Logger metrics *Metrics keepAliveInterval time.Duration description string } func newConnection(opts connectionOptions) *connection { cnx := &connection{ connectionTimeout: opts.connectionTimeout, keepAliveInterval: opts.keepAliveInterval, logicalAddr: opts.logicalAddr, physicalAddr: opts.physicalAddr, writeBuffer: NewBuffer(4096), log: opts.logger.SubLogger(log.Fields{"remote_addr": opts.physicalAddr}), pendingReqs: make(map[uint64]*request), lastDataReceivedTime: time.Now(), tlsOptions: opts.tls, auth: opts.auth, closeCh: make(chan struct{}), readyCh: make(chan struct{}), incomingRequestsCh: make(chan *request, 10), // This channel is used to pass data from producers to the connection // go routine. It can become contended or blocking if we have multiple // partition produces writing on a single connection. In general it's // good to keep this above the number of partition producers assigned // to a single connection. writeRequestsCh: make(chan *dataRequest, 256), listeners: make(map[uint64]ConnectionListener), consumerHandlers: make(map[uint64]ConsumerHandler), metrics: opts.metrics, description: opts.description, } cnx.state.Store(int32(connectionInit)) cnx.reader = newConnectionReader(cnx) return cnx } func (c *connection) start() { if !atomic.CompareAndSwapInt32(&c.started, 0, 1) { c.log.Warnf("connection has already started") return } // Each connection gets its own goroutine that will go func() { if c.connect() { if c.doHandshake() { c.metrics.ConnectionsOpened.Inc() c.run() } else { c.metrics.ConnectionsHandshakeErrors.Inc() c.Close() } } else { c.metrics.ConnectionsEstablishmentErrors.Inc() c.Close() } }() } func (c *connection) connect() bool { c.log.Info("Connecting to broker") var ( err error cnx net.Conn tlsConfig *tls.Config ) if c.tlsOptions == nil { // Clear text connection if c.connectionTimeout.Nanoseconds() > 0 { cnx, err = net.DialTimeout("tcp", c.physicalAddr.Host, c.connectionTimeout) } else { cnx, err = net.Dial("tcp", c.physicalAddr.Host) } } else { // TLS connection tlsConfig, err = c.getTLSConfig() if err != nil { c.log.WithError(err).Warn("Failed to configure TLS ") return false } // time.Duration is initialized to 0 by default, net.Dialer's default timeout is no timeout // therefore if c.connectionTimeout is 0, it means no timeout d := &net.Dialer{Timeout: c.connectionTimeout} cnx, err = tls.DialWithDialer(d, "tcp", c.physicalAddr.Host, tlsConfig) } if err != nil { c.log.WithError(err).Warn("Failed to connect to broker.") c.Close() return false } c.mu.Lock() c.cnx = cnx c.log = c.log.SubLogger(log.Fields{"local_addr": c.cnx.LocalAddr()}) c.log.Info("TCP connection established") c.mu.Unlock() return true } func (c *connection) doHandshake() bool { // Send 'Connect' command to initiate handshake authData, err := c.auth.GetData() if err != nil { c.log.WithError(err).Warn("Failed to load auth credentials") return false } // During the initial handshake, the internal keep alive is not // active yet, so we need to timeout write and read requests c.cnx.SetDeadline(time.Now().Add(c.keepAliveInterval)) cmdConnect := &pb.CommandConnect{ ProtocolVersion: proto.Int32(PulsarProtocolVersion), ClientVersion: proto.String(c.getClientVersion()), AuthMethodName: proto.String(c.auth.Name()), AuthData: authData, FeatureFlags: &pb.FeatureFlags{ SupportsAuthRefresh: proto.Bool(true), SupportsBrokerEntryMetadata: proto.Bool(true), }, } if c.IsProxied() { cmdConnect.ProxyToBrokerUrl = proto.String(c.logicalAddr.Host) } c.writeCommand(baseCommand(pb.BaseCommand_CONNECT, cmdConnect)) cmd, _, err := c.reader.readSingleCommand() if err != nil { c.log.WithError(err).Warn("Failed to perform initial handshake") return false } // Reset the deadline so that we don't use read timeouts c.cnx.SetDeadline(time.Time{}) if cmd.Connected == nil { c.log.Warnf("Failed to establish connection with broker: '%s'", cmd.Error.GetMessage()) return false } if cmd.Connected.MaxMessageSize != nil && *cmd.Connected.MaxMessageSize > 0 { c.log.Debug("Got MaxMessageSize from handshake response:", *cmd.Connected.MaxMessageSize) c.maxMessageSize = *cmd.Connected.MaxMessageSize } else { c.log.Debug("No MaxMessageSize from handshake response, use default: ", MaxMessageSize) c.maxMessageSize = MaxMessageSize } c.log.Info("Connection is ready") c.setLastDataReceived(time.Now()) c.setStateReady() close(c.readyCh) // broadcast the readiness of the connection. return true } func (c *connection) getClientVersion() string { var clientVersion string if c.description == "" { clientVersion = ClientVersionString } else { clientVersion = fmt.Sprintf("%s-%s", ClientVersionString, c.description) } return clientVersion } func (c *connection) IsProxied() bool { return c.logicalAddr.Host != c.physicalAddr.Host } func (c *connection) waitUntilReady() error { select { case <-c.readyCh: return nil case <-c.closeCh: // Connection has been closed while waiting for the readiness. return errors.New("connection error") } } func (c *connection) failLeftRequestsWhenClose() { // wait for outstanding incoming requests to complete before draining // and closing the channel c.incomingRequestsWG.Wait() ch := c.incomingRequestsCh go func() { // send a nil message to drain instead of // closing the channel and causing a potential panic // // if other requests come in after the nil message // then the RPC client will time out ch <- nil }() for req := range ch { if nil == req { break // we have drained the requests } c.internalSendRequest(req) } } func (c *connection) run() { pingSendTicker := time.NewTicker(c.keepAliveInterval) pingCheckTicker := time.NewTicker(c.keepAliveInterval) defer func() { // stop tickers pingSendTicker.Stop() pingCheckTicker.Stop() // all the accesses to the pendingReqs should be happened in this run loop thread, // including the final cleanup, to avoid the issue // https://github.com/apache/pulsar-client-go/issues/239 c.failPendingRequests(errConnectionClosed) c.Close() }() // All reads come from the reader goroutine go c.reader.readFromConnection() go c.runPingCheck(pingCheckTicker) c.log.Debugf("Connection run starting with request capacity=%d queued=%d", cap(c.incomingRequestsCh), len(c.incomingRequestsCh)) for { select { case <-c.closeCh: c.failLeftRequestsWhenClose() return case req := <-c.incomingRequestsCh: if req == nil { return // TODO: this never gonna be happen } c.internalSendRequest(req) case req := <-c.writeRequestsCh: if req == nil { return } c.internalWriteData(req.ctx, req.data) case <-pingSendTicker.C: c.sendPing() } } } func (c *connection) runPingCheck(pingCheckTicker *time.Ticker) { for { select { case <-c.closeCh: return case <-pingCheckTicker.C: if c.lastDataReceived().Add(2 * c.keepAliveInterval).Before(time.Now()) { // We have not received a response to the previous Ping request, the // connection to broker is stale c.log.Warn("Detected stale connection to broker") c.Close() return } } } } func (c *connection) WriteData(ctx context.Context, data Buffer) { select { case c.writeRequestsCh <- &dataRequest{ctx: ctx, data: data}: // Channel is not full return case <-ctx.Done(): c.log.Debug("Write data context cancelled") return default: // Channel full, fallback to probe if connection is closed } for { select { case c.writeRequestsCh <- &dataRequest{ctx: ctx, data: data}: // Successfully wrote on the channel return case <-ctx.Done(): c.log.Debug("Write data context cancelled") return case <-time.After(100 * time.Millisecond): // The channel is either: // 1. blocked, in which case we need to wait until we have space // 2. the connection is already closed, then we need to bail out c.log.Debug("Couldn't write on connection channel immediately") if c.getState() != connectionReady { c.log.Debug("Connection was already closed") return } } } } func (c *connection) internalWriteData(ctx context.Context, data Buffer) { c.log.Debug("Write data: ", data.ReadableBytes()) select { case <-ctx.Done(): return default: if _, err := c.cnx.Write(data.ReadableSlice()); err != nil { c.log.WithError(err).Warn("Failed to write on connection") c.Close() } } } func (c *connection) writeCommand(cmd *pb.BaseCommand) { // Wire format // [FRAME_SIZE] [CMD_SIZE][CMD] cmdSize := uint32(proto.Size(cmd)) frameSize := cmdSize + 4 c.writeBufferLock.Lock() defer c.writeBufferLock.Unlock() c.writeBuffer.Clear() c.writeBuffer.WriteUint32(frameSize) c.writeBuffer.WriteUint32(cmdSize) c.writeBuffer.ResizeIfNeeded(cmdSize) err := MarshalToSizedBuffer(cmd, c.writeBuffer.WritableSlice()[:cmdSize]) if err != nil { c.log.WithError(err).Error("Protobuf serialization error") panic("Protobuf serialization error") } c.writeBuffer.WrittenBytes(cmdSize) c.internalWriteData(context.Background(), c.writeBuffer) } func (c *connection) receivedCommand(cmd *pb.BaseCommand, headersAndPayload Buffer) { c.log.Debugf("Received command: %s -- payload: %v", cmd, headersAndPayload) c.setLastDataReceived(time.Now()) switch *cmd.Type { case pb.BaseCommand_SUCCESS: c.handleResponse(cmd.Success.GetRequestId(), cmd) case pb.BaseCommand_PRODUCER_SUCCESS: if !cmd.ProducerSuccess.GetProducerReady() { request, ok := c.findPendingRequest(cmd.ProducerSuccess.GetRequestId()) if ok { request.callback(cmd, nil) } } else { c.handleResponse(cmd.ProducerSuccess.GetRequestId(), cmd) } case pb.BaseCommand_PARTITIONED_METADATA_RESPONSE: c.checkServerError(cmd.PartitionMetadataResponse.Error) c.handleResponse(cmd.PartitionMetadataResponse.GetRequestId(), cmd) case pb.BaseCommand_LOOKUP_RESPONSE: lookupResult := cmd.LookupTopicResponse c.checkServerError(lookupResult.Error) c.handleResponse(lookupResult.GetRequestId(), cmd) case pb.BaseCommand_CONSUMER_STATS_RESPONSE: c.handleResponse(cmd.ConsumerStatsResponse.GetRequestId(), cmd) case pb.BaseCommand_GET_LAST_MESSAGE_ID_RESPONSE: c.handleResponse(cmd.GetLastMessageIdResponse.GetRequestId(), cmd) case pb.BaseCommand_GET_TOPICS_OF_NAMESPACE_RESPONSE: c.handleResponse(cmd.GetTopicsOfNamespaceResponse.GetRequestId(), cmd) case pb.BaseCommand_GET_SCHEMA_RESPONSE: c.handleResponse(cmd.GetSchemaResponse.GetRequestId(), cmd) case pb.BaseCommand_GET_OR_CREATE_SCHEMA_RESPONSE: c.handleResponse(cmd.GetOrCreateSchemaResponse.GetRequestId(), cmd) case pb.BaseCommand_ERROR: c.handleResponseError(cmd.GetError()) case pb.BaseCommand_SEND_ERROR: c.handleSendError(cmd.GetSendError()) case pb.BaseCommand_CLOSE_PRODUCER: c.handleCloseProducer(cmd.GetCloseProducer()) case pb.BaseCommand_CLOSE_CONSUMER: c.handleCloseConsumer(cmd.GetCloseConsumer()) case pb.BaseCommand_TOPIC_MIGRATED: c.handleTopicMigrated(cmd.GetTopicMigrated()) case pb.BaseCommand_AUTH_CHALLENGE: c.handleAuthChallenge(cmd.GetAuthChallenge()) case pb.BaseCommand_SEND_RECEIPT: c.handleSendReceipt(cmd.GetSendReceipt()) case pb.BaseCommand_MESSAGE: c.handleMessage(cmd.GetMessage(), headersAndPayload) case pb.BaseCommand_ACK_RESPONSE: c.handleAckResponse(cmd.GetAckResponse()) case pb.BaseCommand_PING: c.handlePing() case pb.BaseCommand_PONG: c.handlePong() case pb.BaseCommand_TC_CLIENT_CONNECT_RESPONSE: c.handleResponse(cmd.TcClientConnectResponse.GetRequestId(), cmd) case pb.BaseCommand_NEW_TXN_RESPONSE: c.handleResponse(cmd.NewTxnResponse.GetRequestId(), cmd) case pb.BaseCommand_ADD_PARTITION_TO_TXN_RESPONSE: c.handleResponse(cmd.AddPartitionToTxnResponse.GetRequestId(), cmd) case pb.BaseCommand_ADD_SUBSCRIPTION_TO_TXN_RESPONSE: c.handleResponse(cmd.AddSubscriptionToTxnResponse.GetRequestId(), cmd) case pb.BaseCommand_END_TXN_RESPONSE: c.handleResponse(cmd.EndTxnResponse.GetRequestId(), cmd) case pb.BaseCommand_ACTIVE_CONSUMER_CHANGE: c.handleActiveConsumerChange(cmd.GetActiveConsumerChange()) default: c.log.Errorf("Received invalid command type: %s", cmd.Type) c.Close() } } func (c *connection) checkServerError(err *pb.ServerError) { if err == nil { return } if *err == pb.ServerError_ServiceNotReady { c.Close() } } func (c *connection) SendRequest(requestID uint64, req *pb.BaseCommand, callback func(command *pb.BaseCommand, err error)) { c.incomingRequestsWG.Add(1) defer c.incomingRequestsWG.Done() if c.getState() == connectionClosed { callback(req, ErrConnectionClosed) } else { select { case <-c.closeCh: callback(req, ErrConnectionClosed) case c.incomingRequestsCh <- &request{ id: &requestID, cmd: req, callback: callback, }: } } } func (c *connection) SendRequestNoWait(req *pb.BaseCommand) error { c.incomingRequestsWG.Add(1) defer c.incomingRequestsWG.Done() if c.getState() == connectionClosed { return ErrConnectionClosed } select { case <-c.closeCh: return ErrConnectionClosed case c.incomingRequestsCh <- &request{ id: nil, cmd: req, callback: nil, }: return nil } } func (c *connection) internalSendRequest(req *request) { if c.closed() { c.log.Warnf("internalSendRequest failed for connectionClosed") if req.callback != nil { req.callback(req.cmd, ErrConnectionClosed) } } else { c.pendingLock.Lock() if req.id != nil { c.pendingReqs[*req.id] = req } c.pendingLock.Unlock() c.writeCommand(req.cmd) } } func (c *connection) handleResponse(requestID uint64, response *pb.BaseCommand) { request, ok := c.deletePendingRequest(requestID) if !ok { c.log.Warnf("Received unexpected response for request %d of type %s", requestID, response.Type) return } request.callback(response, nil) } func (c *connection) handleResponseError(serverError *pb.CommandError) { requestID := serverError.GetRequestId() request, ok := c.deletePendingRequest(requestID) if !ok { c.log.Warnf("Received unexpected error response for request %d of type %s", requestID, serverError.GetError()) return } errMsg := fmt.Sprintf("server error: %s: %s", serverError.GetError(), serverError.GetMessage()) request.callback(nil, errors.New(errMsg)) } func (c *connection) handleAckResponse(ackResponse *pb.CommandAckResponse) { requestID := ackResponse.GetRequestId() consumerID := ackResponse.GetConsumerId() request, ok := c.deletePendingRequest(requestID) if !ok { c.log.Warnf("AckResponse has complete when receive response! requestId : %d, consumerId : %d", requestID, consumerID) return } if ackResponse.GetMessage() == "" { request.callback(nil, nil) return } errMsg := fmt.Sprintf("ack response error: %s: %s", ackResponse.GetError(), ackResponse.GetMessage()) request.callback(nil, errors.New(errMsg)) } func (c *connection) handleSendReceipt(response *pb.CommandSendReceipt) { producerID := response.GetProducerId() c.mu.RLock() producer, ok := c.listeners[producerID] c.mu.RUnlock() if ok { producer.ReceivedSendReceipt(response) } else { c.log. WithField("producerID", producerID). Warn("Got unexpected send receipt for messageID=%+v", response.MessageId) } } func (c *connection) handleMessage(response *pb.CommandMessage, payload Buffer) { c.log.Debug("Got Message: ", response) consumerID := response.GetConsumerId() if consumer, ok := c.consumerHandler(consumerID); ok { err := consumer.MessageReceived(response, payload) if err != nil { c.log. WithError(err). WithField("consumerID", consumerID). Error("handle message Id: ", response.MessageId) } } else { c.log.WithField("consumerID", consumerID).Warn("Got unexpected message: ", response.MessageId) } } func (c *connection) deletePendingRequest(requestID uint64) (*request, bool) { c.pendingLock.Lock() defer c.pendingLock.Unlock() request, ok := c.pendingReqs[requestID] if ok { delete(c.pendingReqs, requestID) } return request, ok } func (c *connection) findPendingRequest(requestID uint64) (*request, bool) { c.pendingLock.Lock() defer c.pendingLock.Unlock() request, ok := c.pendingReqs[requestID] return request, ok } func (c *connection) failPendingRequests(err error) bool { c.pendingLock.Lock() defer c.pendingLock.Unlock() for id, req := range c.pendingReqs { req.callback(nil, err) delete(c.pendingReqs, id) } return true } func (c *connection) lastDataReceived() time.Time { c.lastDataReceivedLock.Lock() defer c.lastDataReceivedLock.Unlock() t := c.lastDataReceivedTime return t } func (c *connection) setLastDataReceived(t time.Time) { c.lastDataReceivedLock.Lock() defer c.lastDataReceivedLock.Unlock() c.lastDataReceivedTime = t } func (c *connection) sendPing() { c.log.Debug("Sending PING") c.writeCommand(baseCommand(pb.BaseCommand_PING, &pb.CommandPing{})) } func (c *connection) handlePong() { c.log.Debug("Received PONG response") } func (c *connection) handlePing() { c.log.Debug("Responding to PING request") c.writeCommand(baseCommand(pb.BaseCommand_PONG, &pb.CommandPong{})) } func (c *connection) handleAuthChallenge(authChallenge *pb.CommandAuthChallenge) { c.log.Debugf("Received auth challenge from broker: %s", authChallenge.GetChallenge().GetAuthMethodName()) // Get new credentials from the provider authData, err := c.auth.GetData() if err != nil { c.log.WithError(err).Warn("Failed to load auth credentials") c.Close() return } // Brokers expect authData to be not nil if authData == nil { authData = []byte{} } cmdAuthResponse := &pb.CommandAuthResponse{ ProtocolVersion: proto.Int32(PulsarProtocolVersion), ClientVersion: proto.String(c.getClientVersion()), Response: &pb.AuthData{ AuthMethodName: proto.String(c.auth.Name()), AuthData: authData, }, } c.writeCommand(baseCommand(pb.BaseCommand_AUTH_RESPONSE, cmdAuthResponse)) } func (c *connection) handleSendError(sendError *pb.CommandSendError) { c.log.Warnf("Received send error from server: [%v] : [%s]", sendError.GetError(), sendError.GetMessage()) producerID := sendError.GetProducerId() switch sendError.GetError() { case pb.ServerError_NotAllowedError: _, ok := c.deletePendingProducers(producerID) if !ok { c.log.Warnf("Received unexpected error response for request %d of type %s", producerID, sendError.GetError()) return } c.log.Warnf("server error: %s: %s", sendError.GetError(), sendError.GetMessage()) case pb.ServerError_TopicTerminatedError: _, ok := c.deletePendingProducers(producerID) if !ok { c.log.Warnf("Received unexpected error response for producer %d of type %s", producerID, sendError.GetError()) return } c.log.Warnf("server error: %s: %s", sendError.GetError(), sendError.GetMessage()) default: // By default, for transient error, let the reconnection logic // to take place and re-establish the produce again c.Close() } } func (c *connection) deletePendingProducers(producerID uint64) (ConnectionListener, bool) { c.mu.Lock() defer c.mu.Unlock() producer, ok := c.listeners[producerID] if ok { delete(c.listeners, producerID) } return producer, ok } func (c *connection) handleCloseConsumer(closeConsumer *pb.CommandCloseConsumer) { consumerID := closeConsumer.GetConsumerId() c.log.Infof("Broker notification of Closed consumer: %d", consumerID) if consumer, ok := c.consumerHandler(consumerID); ok { consumer.ConnectionClosed(closeConsumer) c.DeleteConsumeHandler(consumerID) } else { c.log.WithField("consumerID", consumerID).Warnf("Consumer with ID not found while closing consumer") } } func (c *connection) handleActiveConsumerChange(consumerChange *pb.CommandActiveConsumerChange) { consumerID := consumerChange.GetConsumerId() isActive := consumerChange.GetIsActive() if consumer, ok := c.consumerHandler(consumerID); ok { consumer.ActiveConsumerChanged(isActive) } else { c.log.WithField("consumerID", consumerID).Warnf("Consumer not found while active consumer change") } } func (c *connection) handleCloseProducer(closeProducer *pb.CommandCloseProducer) { c.log.Infof("Broker notification of Closed producer: %d", closeProducer.GetProducerId()) producerID := closeProducer.GetProducerId() producer, ok := c.deletePendingProducers(producerID) // did we find a producer? if ok { producer.ConnectionClosed(closeProducer) } else { c.log.WithField("producerID", producerID).Warn("Producer with ID not found while closing producer") } } func (c *connection) getMigratedBrokerServiceURL(commandTopicMigrated *pb.CommandTopicMigrated) string { if c.tlsOptions == nil { if commandTopicMigrated.GetBrokerServiceUrl() != "" { return commandTopicMigrated.GetBrokerServiceUrl() } } else if commandTopicMigrated.GetBrokerServiceUrlTls() != "" { return commandTopicMigrated.GetBrokerServiceUrlTls() } return "" } func (c *connection) handleTopicMigrated(commandTopicMigrated *pb.CommandTopicMigrated) { resourceID := commandTopicMigrated.GetResourceId() migratedBrokerServiceURL := c.getMigratedBrokerServiceURL(commandTopicMigrated) if migratedBrokerServiceURL == "" { c.log.Warnf("Failed to find the migrated broker url for resource: %d, migratedBrokerUrl: %s, migratedBrokerUrlTls:%s", resourceID, commandTopicMigrated.GetBrokerServiceUrl(), commandTopicMigrated.GetBrokerServiceUrlTls()) return } if commandTopicMigrated.GetResourceType() == pb.CommandTopicMigrated_Producer { c.mu.RLock() producer, ok := c.listeners[resourceID] c.mu.RUnlock() if ok { producer.SetRedirectedClusterURI(migratedBrokerServiceURL) c.log.Infof("producerID:{%d} migrated to RedirectedClusterURI:{%s}", resourceID, migratedBrokerServiceURL) } else { c.log.WithField("producerID", resourceID).Warn("Failed to SetRedirectedClusterURI") } } else { consumer, ok := c.consumerHandler(resourceID) if ok { consumer.SetRedirectedClusterURI(migratedBrokerServiceURL) c.log.Infof("consumerID:{%d} migrated to RedirectedClusterURI:{%s}", resourceID, migratedBrokerServiceURL) } else { c.log.WithField("consumerID", resourceID).Warn("Failed to SetRedirectedClusterURI") } } } func (c *connection) RegisterListener(id uint64, listener ConnectionListener) error { c.mu.Lock() defer c.mu.Unlock() if c.getState() == connectionClosed { c.log.Warnf("Connection closed unable register listener id=%+v", id) return errUnableRegisterListener } c.listeners[id] = listener return nil } func (c *connection) UnregisterListener(id uint64) { c.mu.Lock() defer c.mu.Unlock() delete(c.listeners, id) } func (c *connection) ResetLastActive() { c.mu.Lock() defer c.mu.Unlock() c.lastActive = time.Now() } func (c *connection) isIdle() bool { return len(c.listeners) == 0 && len(c.consumerHandlers) == 0 && len(c.incomingRequestsCh) == 0 && len(c.writeRequestsCh) == 0 } func (c *connection) CheckIdle(maxIdleTime time.Duration) bool { c.pendingLock.Lock() sizePendingReqs := len(c.pendingReqs) c.pendingLock.Unlock() c.mu.Lock() defer c.mu.Unlock() if sizePendingReqs != 0 || !c.isIdle() { c.lastActive = time.Now() } return time.Since(c.lastActive) > maxIdleTime } func (c *connection) WaitForClose() <-chan struct{} { return c.closeCh } // Close closes the connection by // closing underlying socket connection and closeCh. // This also triggers callbacks to the ConnectionClosed listeners. func (c *connection) Close() { c.closeOnce.Do(func() { listeners, consumerHandlers, cnx := c.closeAndEmptyObservers() if cnx != nil { _ = cnx.Close() } close(c.closeCh) // notify producers connection closed for _, listener := range listeners { listener.ConnectionClosed(nil) } // notify consumers connection closed for _, handler := range consumerHandlers { handler.ConnectionClosed(nil) } c.metrics.ConnectionsClosed.Inc() }) } func (c *connection) closeAndEmptyObservers() ([]ConnectionListener, []ConsumerHandler, net.Conn) { c.mu.Lock() defer c.mu.Unlock() c.setStateClosed() listeners := make([]ConnectionListener, 0, len(c.listeners)) for _, listener := range c.listeners { listeners = append(listeners, listener) } handlers := make([]ConsumerHandler, 0, len(c.consumerHandlers)) for _, handler := range c.consumerHandlers { handlers = append(handlers, handler) } return listeners, handlers, c.cnx } func (c *connection) getState() connectionState { return connectionState(c.state.Load()) } func (c *connection) setStateReady() { c.state.CompareAndSwap(int32(connectionInit), int32(connectionReady)) } func (c *connection) setStateClosed() { c.state.Store(int32(connectionClosed)) } func (c *connection) closed() bool { return connectionClosed == c.getState() } func (c *connection) getTLSConfig() (*tls.Config, error) { if c.tlsOptions.TLSConfig != nil { return c.tlsOptions.TLSConfig, nil } tlsConfig := &tls.Config{ InsecureSkipVerify: c.tlsOptions.AllowInsecureConnection, CipherSuites: c.tlsOptions.CipherSuites, MinVersion: c.tlsOptions.MinVersion, MaxVersion: c.tlsOptions.MaxVersion, } if c.tlsOptions.TrustCertsFilePath != "" { caCerts, err := os.ReadFile(c.tlsOptions.TrustCertsFilePath) if err != nil { return nil, err } tlsConfig.RootCAs = x509.NewCertPool() ok := tlsConfig.RootCAs.AppendCertsFromPEM(caCerts) if !ok { return nil, errors.New("failed to parse root CAs certificates") } } if c.tlsOptions.ValidateHostname { if c.tlsOptions.ServerName != "" { tlsConfig.ServerName = c.tlsOptions.ServerName } else { tlsConfig.ServerName = c.physicalAddr.Hostname() } c.log.Debugf("getTLSConfig(): setting tlsConfig.ServerName = %+v", tlsConfig.ServerName) } if c.tlsOptions.CertFile != "" && c.tlsOptions.KeyFile != "" { cert, err := tls.LoadX509KeyPair(c.tlsOptions.CertFile, c.tlsOptions.KeyFile) if err != nil { return nil, errors.New(err.Error()) } tlsConfig.Certificates = []tls.Certificate{cert} } cert, err := c.auth.GetTLSCertificate() if err != nil { return nil, err } if cert != nil { tlsConfig.Certificates = []tls.Certificate{*cert} } return tlsConfig, nil } func (c *connection) AddConsumeHandler(id uint64, handler ConsumerHandler) error { c.mu.Lock() defer c.mu.Unlock() if c.getState() == connectionClosed { c.log.Warnf("Closed connection unable add consumer with id=%+v", id) return errUnableAddConsumeHandler } c.consumerHandlers[id] = handler return nil } func (c *connection) DeleteConsumeHandler(id uint64) { c.mu.Lock() defer c.mu.Unlock() delete(c.consumerHandlers, id) } func (c *connection) consumerHandler(id uint64) (ConsumerHandler, bool) { c.mu.RLock() defer c.mu.RUnlock() h, ok := c.consumerHandlers[id] return h, ok } func (c *connection) ID() string { c.mu.RLock() defer c.mu.RUnlock() return fmt.Sprintf("%s -> %s", c.cnx.LocalAddr(), c.cnx.RemoteAddr()) } func (c *connection) GetMaxMessageSize() int32 { return c.maxMessageSize }