go/mqtt/connect.go (278 lines of code) (raw):
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
package mqtt
import (
"context"
"io"
"log/slog"
"math"
"github.com/Azure/iot-operations-sdks/go/mqtt/internal"
"github.com/eclipse/paho.golang/paho"
)
// RegisterConnectEventHandler adds handler to the session client's list of
// connect event handlers. The handlers on that list are invoked synchronously,
// in registration order, each time the session client successfully establishes
// an MQTT connection. Because the invocation is synchronous, a handler that
// blocks for long also blocks the session client.
//
// The returned func() is the value produced by the handler list's AppendEntry;
// presumably calling it removes the handler again (confirm against the list
// implementation).
func (c *SessionClient) RegisterConnectEventHandler(
	handler ConnectEventHandler,
) func() {
	remove := c.connectEventHandlers.AppendEntry(handler)
	return remove
}
// RegisterDisconnectEventHandler adds handler to the session client's list of
// disconnect event handlers. The handlers on that list are invoked
// synchronously, in registration order, each time the session client detects a
// disconnection from the MQTT server. Because the invocation is synchronous, a
// handler that blocks for long also blocks the session client.
//
// The returned func() is the value produced by the handler list's AppendEntry;
// presumably calling it removes the handler again (confirm against the list
// implementation).
func (c *SessionClient) RegisterDisconnectEventHandler(
	handler DisconnectEventHandler,
) func() {
	remove := c.disconnectEventHandlers.AppendEntry(handler)
	return remove
}
// RegisterFatalErrorHandler adds handler to the session client's list of fatal
// error handlers. If the session client terminates due to a fatal error, every
// handler on that list is started in its own goroutine and receives the error.
//
// The returned func() is the value produced by the handler list's AppendEntry;
// presumably calling it removes the handler again (confirm against the list
// implementation).
func (c *SessionClient) RegisterFatalErrorHandler(
	handler func(error),
) func() {
	remove := c.fatalErrorHandlers.AppendEntry(handler)
	return remove
}
// Start launches the session client's background goroutines: one that manages
// the connection and one that manages outgoing publishes. To terminate the
// session client and clean those goroutines up, Stop() must be called after
// Start().
//
// Start returns a *ClientStateError with State Started if the session client
// has already been started, and nil otherwise.
func (c *SessionClient) Start() error {
	// Only the caller that flips the flag from false to true may proceed.
	if first := c.sessionStarted.CompareAndSwap(false, true); !first {
		return &ClientStateError{State: Started}
	}

	c.shutdown = internal.NewBackground(&ClientStateError{State: ShutDown})
	ctx, _ := c.shutdown.With(context.Background())

	runConnection := func() {
		// Whatever the reason the connection loop ended, close the shutdown
		// handle so everything tied to it stops as well.
		defer c.shutdown.Close()

		err := c.manageConnection(ctx)
		if err == nil {
			return
		}

		// The connection loop ended with an error: log it and notify every
		// fatal error handler, each in its own goroutine.
		c.log.Error(ctx, err)
		for notify := range c.fatalErrorHandlers.All() {
			go notify(err)
		}
	}

	go runConnection()
	go c.manageOutgoingPublishes(ctx)

	return nil
}
// Stop shuts the session client down by closing its shutdown handle, which
// terminates pending operations and lets the background goroutines exit.
//
// Stop returns a *ClientStateError with State NotStarted if Start has not been
// called, and nil otherwise.
func (c *SessionClient) Stop() error {
	if c.sessionStarted.Load() {
		c.shutdown.Close()
		return nil
	}
	return &ClientStateError{State: NotStarted}
}
// manageConnection attempts an initial connection and then listens for
// disconnections in order to attempt reconnections. It blocks until ctx is
// cancelled or the connection can no longer be maintained (due to a fatal
// error or retry policy exhaustion).
//
// It returns nil when ctx is cancelled while a connection is up, the error
// produced by the connection retry when connecting fails, or the
// *FatalDisconnectError recorded for the current connection. cleanup runs on
// every exit path.
func (c *SessionClient) manageConnection(ctx context.Context) error {
	defer c.cleanup(ctx)

	// False for the first connection; set to true after any successful
	// connection so that later attempts are made as reconnections.
	var reconnect bool
	for {
		// Filled in by the retry callback below.
		var connack *paho.Connack
		err := c.options.ConnectionRetry.Start(ctx, "connect",
			func(ctx context.Context) (bool, error) {
				var err error

				// Bound a single connection attempt by the configured timeout,
				// if one is set; otherwise use the retry context as-is.
				connCtx := ctx
				if c.options.ConnectionTimeout > 0 {
					var cancel func()
					connCtx, cancel = context.WithTimeout(
						ctx,
						c.options.ConnectionTimeout,
					)
					defer cancel()
				}
				connack, err = c.connect(connCtx, reconnect)

				// Decide to retry depending on whether we consider this error
				// to be fatal. We don't wrap these errors, so we can use a
				// simple type-switch instead of Go error wrapping.
				//
				// A nil error also takes the default branch; presumably the
				// retry policy treats a nil error as success regardless of the
				// flag (confirm against the retry implementation).
				switch err.(type) {
				case *InvalidArgumentError,
					*SessionLostError,
					*FatalConnackError,
					*FatalDisconnectError:
					return false, err
				default:
					return true, err
				}
			},
		)
		if err != nil {
			// Connecting failed for good; give up on the session.
			return err
		}

		// NOTE: signalConnection and signalDisconnection must only be called
		// together in this loop to ensure ordering between the two.
		c.signalConnection(ctx, &ConnectEvent{ReasonCode: connack.ReasonCode})
		reconnect = true

		// Wait for either the current connection to go down or shutdown.
		select {
		case <-c.conn.Current().Down.Done():
			// Current paho instance got disconnected.
			switch err := c.conn.Current().Error.(type) {
			case *FatalDisconnectError:
				// Report the server's reason code, then stop: a fatal
				// disconnect ends the session.
				c.signalDisconnection(ctx, &DisconnectEvent{
					ReasonCode: &err.ReasonCode,
				})
				return err
			case *DisconnectError:
				// Non-fatal server disconnect: report the reason code and fall
				// through to reconnect.
				c.signalDisconnection(ctx, &DisconnectEvent{
					ReasonCode: &err.ReasonCode,
				})
			default:
				// Any other recorded error is reported as-is, and a
				// reconnection is attempted.
				c.signalDisconnection(ctx, &DisconnectEvent{
					Error: err,
				})
			}

		case <-ctx.Done():
			// Session client is shutting down.
			return nil
		}

		// If we get here, a reconnection will be attempted.
	}
}
// connect creates a new Paho client instance on a fresh connection from the
// connection provider and attempts to connect it to the MQTT server.
//
// reconnect indicates that a previous connection in this session already
// succeeded; it is forwarded to buildConnectPacket and is used to detect a
// lost session.
//
// On success the Paho client is recorded through c.conn.Connect and the
// server's CONNACK is returned. On failure the returned CONNACK is nil and the
// error is one of:
//   - the error from the connection provider or from buildConnectPacket;
//   - the error returned by Paho's Connect when it produced no CONNACK;
//   - *FatalConnackError when the CONNACK reason code is considered fatal;
//   - *ConnackError for any other failure reason code (0x80 or greater);
//   - *SessionLostError when reconnecting and the server reports no session;
//   - the error from c.conn.Connect.
func (c *SessionClient) connect(
	ctx context.Context,
	reconnect bool,
) (*paho.Connack, error) {
	attempt := c.conn.Attempt()

	conn, err := c.connectionProvider(ctx)
	if err != nil {
		return nil, err
	}

	var auther paho.Auther
	if c.options.Auth != nil {
		auther = &pahoAuther{c}
	}

	pahoClient := paho.NewClient(paho.ClientConfig{
		ClientID:    c.clientID,
		Session:     c.session,
		Conn:        conn,
		AuthHandler: auther,
		// Set Paho's packet timeout to the maximum possible value to
		// effectively disable it. We can still control any timeouts through the
		// contexts we pass into Paho.
		PacketTimeout: math.MaxInt64,
		// Disable automatic acking in Paho. The session client will manage acks
		// instead.
		EnableManualAcknowledgment: true,
		OnPublishReceived: []func(paho.PublishReceived) (bool, error){
			// The listener is tied to this connection attempt.
			// NOTE(review): an earlier comment said 1 is added to the conn
			// count here because the listener takes effect after the
			// connection succeeds, but attempt is passed unchanged; confirm
			// that Attempt() already accounts for this.
			c.makeOnPublishReceived(ctx, attempt),
		},
		OnServerDisconnect: func(d *paho.Disconnect) {
			// Record the server-initiated disconnect against this attempt,
			// classifying it by its reason code.
			if isFatalDisconnectReasonCode(d.ReasonCode) {
				c.conn.Disconnect(attempt, &FatalDisconnectError{d.ReasonCode})
			} else {
				c.conn.Disconnect(attempt, &DisconnectError{d.ReasonCode})
			}
		},
		OnClientError: func(err error) {
			c.conn.Disconnect(attempt, err)
		},
	})

	connect, err := c.buildConnectPacket(ctx, reconnect)
	if err != nil {
		// pahoClient.Connect is never reached on this path, so close the
		// connection obtained above rather than abandoning it. The caller
		// needs the packet-building error, so a Close failure is deliberately
		// dropped.
		if conn != nil {
			_ = conn.Close()
		}
		return nil, err
	}

	c.log.Packet(ctx, "connect", connect)
	connack, err := pahoClient.Connect(ctx, connect)
	c.log.Packet(ctx, "connack", connack)

	switch {
	case connack == nil:
		// This assumes that all errors returned by Paho's connect method
		// without a CONNACK are non-fatal.
		return nil, err

	case isFatalConnackReasonCode(connack.ReasonCode):
		return nil, &FatalConnackError{connack.ReasonCode}

	case connack.ReasonCode >= 0x80:
		// MQTT v5 reason codes of 0x80 (128) or greater indicate failure.
		return nil, &ConnackError{connack.ReasonCode}

	case reconnect && !connack.SessionPresent:
		// We expected to resume an existing session but the server has none.
		// Disconnect immediately and report the session as lost.
		c.forceDisconnect(ctx, pahoClient)
		return nil, &SessionLostError{}

	default:
		if err := c.conn.Connect(pahoClient); err != nil {
			return nil, err
		}

		if c.options.Auth != nil && connack.Properties.AuthMethod == "" {
			// Ensure the auth provider is notified of success even if the MQTT
			// server fails to echo the auth method.
			c.options.Auth.AuthSuccess(c.requestReauth)
		}

		return connack, nil
	}
}
// signalConnection logs a "connected" message carrying the event's reason code
// and then passes event to every registered connect event handler in turn,
// calling each one synchronously.
func (c *SessionClient) signalConnection(
	ctx context.Context,
	event *ConnectEvent,
) {
	reasonCode := slog.Int("reason_code", int(event.ReasonCode))
	c.log.Info(ctx, "connected", reasonCode)

	for notify := range c.connectEventHandlers.All() {
		notify(event)
	}
}
// signalDisconnection logs a "disconnected" warning and then passes event to
// every registered disconnect event handler in turn, calling each one
// synchronously. The warning carries the event's reason code when one is set,
// otherwise the event's error text when an error is set, and no extra
// attributes when neither is present.
func (c *SessionClient) signalDisconnection(
	ctx context.Context,
	event *DisconnectEvent,
) {
	// The reason code takes precedence over the error when both are set.
	if code := event.ReasonCode; code != nil {
		c.log.Warn(ctx, "disconnected",
			slog.Int("reason_code", int(*code)),
		)
	} else if cause := event.Error; cause != nil {
		c.log.Warn(ctx, "disconnected",
			slog.String("error", cause.Error()),
		)
	} else {
		c.log.Warn(ctx, "disconnected")
	}

	for notify := range c.disconnectEventHandlers.All() {
		notify(event)
	}
}
// forceDisconnect logs and sends a DISCONNECT packet on client with the normal
// disconnection reason code and a session expiry interval of zero. The result
// of sending the packet is ignored.
func (c *SessionClient) forceDisconnect(
	ctx context.Context,
	client *paho.Client,
) {
	// Zero value: the packet carries a session expiry interval of 0.
	var expireNow uint32

	props := &paho.DisconnectProperties{
		SessionExpiryInterval: &expireNow,
	}
	packet := &paho.Disconnect{
		ReasonCode: disconnectNormalDisconnection,
		Properties: props,
	}

	c.log.Packet(ctx, "disconnect", packet)
	_ = client.Disconnect(packet)
}
// cleanup tears down whatever the session client still holds. If the current
// connection has a Paho client, a DISCONNECT packet is sent on it and a
// disconnection (with no reason code or error) is signalled. If the configured
// auth provider implements io.Closer, it is closed, and a Close error is
// logged.
func (c *SessionClient) cleanup(ctx context.Context) {
	if active := c.conn.Current().Client; active != nil {
		c.forceDisconnect(ctx, active)
		c.signalDisconnection(ctx, &DisconnectEvent{})
	}

	// Only auth providers that have cleanup to do implement io.Closer.
	closer, closable := c.options.Auth.(io.Closer)
	if !closable {
		return
	}
	if err := closer.Close(); err != nil {
		c.log.Error(ctx, err)
	}
}
// buildConnectPacket assembles the CONNECT packet for a connection attempt from
// the session client's options.
//
// CleanStart is only set when the option is enabled and this is not a
// reconnection. The username, password, and auth values are obtained, in that
// order, from the corresponding option when it is configured. A username or
// password is only written to the packet when its provider reports that it is
// present.
//
// If any provider fails, the packet is nil and the error is an
// *InvalidArgumentError wrapping the provider's error.
func (c *SessionClient) buildConnectPacket(
	ctx context.Context,
	reconnect bool,
) (*paho.Connect, error) {
	props := &paho.ConnectProperties{
		SessionExpiryInterval: &c.options.SessionExpiry,
		ReceiveMaximum:        &c.options.ReceiveMaximum,
		RequestProblemInfo:    true,
		User: internal.MapToUserProperties(
			c.options.ConnectUserProperties,
		),
	}
	packet := &paho.Connect{
		ClientID:   c.clientID,
		CleanStart: c.options.CleanStart && !reconnect,
		KeepAlive:  c.options.KeepAlive,
		Properties: props,
	}

	if getUsername := c.options.Username; getUsername != nil {
		name, present, err := getUsername(ctx)
		if err != nil {
			return nil, &InvalidArgumentError{
				message: "error getting username",
				wrapped: err,
			}
		}
		if present {
			packet.UsernameFlag = true
			packet.Username = name
		}
	}

	if getPassword := c.options.Password; getPassword != nil {
		secret, present, err := getPassword(ctx)
		if err != nil {
			return nil, &InvalidArgumentError{
				message: "error getting password",
				wrapped: err,
			}
		}
		if present {
			packet.PasswordFlag = true
			packet.Password = secret
		}
	}

	if auth := c.options.Auth; auth != nil {
		values, err := auth.InitiateAuth(false)
		if err != nil {
			return nil, &InvalidArgumentError{
				message: "error getting auth values",
				wrapped: err,
			}
		}
		props.AuthData = values.AuthData
		props.AuthMethod = values.AuthMethod
	}

	return packet, nil
}