go/mqtt/internal/connection_tracker.go (96 lines of code) (raw):

// Copyright (c) Microsoft Corporation. // Licensed under the MIT License. package internal import ( "context" "iter" "sync" ) type ( // Struct to track the connection state of the client, and retreive the // currently connected client. ConnectionTracker[Client comparable] struct { current CurrentConnection[Client] currentMu sync.RWMutex } // Mutex-protected connection data. CurrentConnection[Client comparable] struct { // Current instance of the client. Client Client // Error that caused the last disconnection. Error error // Channel that is closed when the connection is up (i.e., a new client // instance is created and connected to the server with a successful // CONNACK), used to notify goroutines that are waiting on a connection // to be re-established. up chan struct{} // Background state that is stopped when the the connection is down. // Used to notify goroutines that expect the connection to go down that // the manageConnection() goroutine has detected the disconnection and // is attempting to start a new connection. Down *Background // Counter for the current connection attempt. This is independent from // the client, since it also records unsuccessful connect attempts. Attempt uint64 } ) func NewConnectionTracker[Client comparable]() *ConnectionTracker[Client] { c := &ConnectionTracker[Client]{} c.current.up = make(chan struct{}) c.current.Down = NewBackground(context.Canceled) // Immediately close down to maintain the invariant that down is closed iff // the client is disconnected. c.current.Down.Close() return c } func (c *ConnectionTracker[Client]) Attempt() uint64 { c.currentMu.Lock() defer c.currentMu.Unlock() c.current.Error = nil c.current.Attempt++ return c.current.Attempt } func (c *ConnectionTracker[Client]) Connect(client Client) error { c.currentMu.Lock() defer c.currentMu.Unlock() // A disconnect was encountered between attempt and connect. // Don't connect and return the error. if c.current.Error != nil { return c.current.Error } c.current.Client = client close(c.current.up) c.current.Down = NewBackground(context.Canceled) return nil } func (c *ConnectionTracker[Client]) Disconnect(attempt uint64, err error) { c.currentMu.Lock() defer c.currentMu.Unlock() // This disconnect is for another attempt; don't change state. if c.current.Attempt != attempt { return } // Record the error if there isn't already one recorded. if c.current.Error == nil { c.current.Error = err } // An error was encountered before connect. Record it but don't disconnect. var zero Client if c.current.Client == zero { return } c.current.Client = zero c.current.up = make(chan struct{}) c.current.Down.Close() } func (c *ConnectionTracker[Client]) Current() CurrentConnection[Client] { c.currentMu.RLock() defer c.currentMu.RUnlock() return c.current } // Return the client for the current connection. Since the client gets replaced // when the we reconnect, this is represented as an iterator. The caller should // return from the loop once the call they're trying to make is complete, or // continue the loop if we need to reconnect and try again. The loop will only // terminate on its own via the context. It also provides a context which will // be closed if the client disconnects, in order to terminate any requests. func (c *ConnectionTracker[Client]) Client( ctx context.Context, ) iter.Seq2[context.Context, Client] { return func(yield func(context.Context, Client) bool) { for { current := c.Current() var zero Client if current.Client == zero { select { case <-ctx.Done(): return case <-current.up: continue } } if !func() bool { ctx, cancel := current.Down.With(ctx) defer cancel() return yield(ctx, current.Client) }() { return } // If we get here, the request failed because the connection is down // or because ctx was cancelled. select { case <-ctx.Done(): return case <-current.Down.Done(): // Connection is down, wait for the connection to come back up // and retry. } } } }