go/protocol/listener.go (190 lines of code) (raw):

// Copyright (c) Microsoft Corporation. // Licensed under the MIT License. package protocol import ( "context" "log/slog" "strings" "sync/atomic" "github.com/Azure/iot-operations-sdks/go/internal/log" "github.com/Azure/iot-operations-sdks/go/internal/mqtt" "github.com/Azure/iot-operations-sdks/go/protocol/errors" "github.com/Azure/iot-operations-sdks/go/protocol/internal" "github.com/Azure/iot-operations-sdks/go/protocol/internal/constants" "github.com/Azure/iot-operations-sdks/go/protocol/internal/errutil" "github.com/Azure/iot-operations-sdks/go/protocol/internal/version" "github.com/google/uuid" ) type ( // Listener represents an object which will listen to a MQTT topic. Listener interface { Start(context.Context) error Close() } // Listeners represents a collection of MQTT listeners. Listeners []Listener // Provide the shared implementation details for the MQTT listeners. listener[T any] struct { app *Application client MqttClient encoding Encoding[T] topic *internal.TopicFilter shareName string concurrency uint reqCorrelation bool supportedVersion []int log log.Logger handler interface { onMsg(context.Context, *mqtt.Message, *Message[T]) error onErr(context.Context, *mqtt.Message, error) error } done func() active atomic.Bool } message[T any] struct { Mqtt *mqtt.Message Message[T] } ) func (l *listener[T]) register() { handle, stop := internal.Concurrent(l.concurrency, l.handle) done := l.client.RegisterMessageHandler( func(ctx context.Context, m *mqtt.Message) { msg := &message[T]{Mqtt: m} var match bool msg.TopicTokens, match = l.topic.Tokens(m.Topic) if match { handle(ctx, msg) } else { m.Ack() } }, ) l.done = func() { done() stop() } } func (l *listener[T]) filter() string { // Make the subscription shared if specified. if l.shareName != "" { return "$share/" + l.shareName + "/" + l.topic.Filter() } return l.topic.Filter() } func (l *listener[T]) start(ctx context.Context, name string) error { if l.active.CompareAndSwap(false, true) { filter := l.filter() ack, err := l.client.Subscribe( ctx, filter, mqtt.WithQoS(1), mqtt.WithNoLocal(l.shareName == ""), ) if err := errutil.Mqtt(ctx, "subscribe", ack, err); err != nil { l.log.Warn(ctx, err) return err } l.log.Info(ctx, name+" started", slog.String("topic", filter)) } return nil } func (l *listener[T]) close(name string) { ctx := context.Background() if l.active.CompareAndSwap(true, false) { filter := l.filter() if ack, err := l.client.Unsubscribe(ctx, filter); err != nil { // Returning an error from a close function that is most likely to // be deferred is rarely useful, so just log it. l.log.Error(ctx, errutil.Mqtt(ctx, "unsubscribe", ack, err)) } l.log.Info(ctx, name+" closed", slog.String("topic", filter)) } l.done() } func (l *listener[T]) handle(ctx context.Context, msg *message[T]) { // The very first check must be the version, because if we don't support it, // nothing else is trustworthy. ver := msg.Mqtt.UserProperties[constants.ProtocolVersion] if !version.IsSupported(ver, l.supportedVersion) { l.error(ctx, msg.Mqtt, &errors.Remote{ Message: "request version not supported", Kind: errors.UnsupportedVersion{ ProtocolVersion: ver, SupportedMajorProtocolVersions: l.supportedVersion, }, }) return } msg.ClientID = msg.Mqtt.UserProperties[constants.SourceID] if l.reqCorrelation && len(msg.Mqtt.CorrelationData) == 0 { l.error(ctx, msg.Mqtt, &errors.Remote{ Message: "correlation data missing", Kind: errors.HeaderMissing{ HeaderName: constants.CorrelationData, }, }) return } if len(msg.Mqtt.CorrelationData) != 0 { correlationData, err := uuid.FromBytes(msg.Mqtt.CorrelationData) if err != nil { l.error(ctx, msg.Mqtt, &errors.Remote{ Message: "correlation data is not a valid UUID", Kind: errors.HeaderInvalid{ HeaderName: constants.CorrelationData, }, }) return } msg.CorrelationData = correlationData.String() } ts := msg.Mqtt.UserProperties[constants.Timestamp] if ts != "" { var err error msg.Timestamp, err = l.app.hlc.Parse(constants.Timestamp, ts) if err != nil { l.error(ctx, msg.Mqtt, &errors.Remote{ Message: "timestamp is not a valid RFC3339 timestamp", Kind: errors.HeaderInvalid{ HeaderName: constants.Timestamp, HeaderValue: ts, }, }) return } if err = l.app.hlc.Set(msg.Timestamp); err != nil { l.error(ctx, msg.Mqtt, err) return } } msg.Metadata = make(map[string]string, len(msg.Mqtt.UserProperties)) for key, val := range msg.Mqtt.UserProperties { if !strings.HasPrefix(key, constants.Protocol) { msg.Metadata[key] = val } } msg.Data = &Data{ msg.Mqtt.Payload, msg.Mqtt.ContentType, msg.Mqtt.PayloadFormat, } if err := l.handler.onMsg(ctx, msg.Mqtt, &msg.Message); err != nil { l.error(ctx, msg.Mqtt, err) return } } // Handle payload manually, since it may be ignored on errors. func (l *listener[T]) payload(msg *Message[T]) (T, error) { return deserialize(l.encoding, msg.Data) } func (l *listener[T]) error(ctx context.Context, pub *mqtt.Message, err error) { // Drop the message if the error handler fails. if e := l.handler.onErr(ctx, pub, err); e != nil { l.drop(ctx, pub, err) } } func (l *listener[T]) drop(ctx context.Context, _ *mqtt.Message, err error) { // Log dropped messages as an error, because we have no other way of // communicating this to the user. l.log.Error(ctx, err) } // Start listening to all underlying MQTT topics. func (ls Listeners) Start(ctx context.Context) error { for _, l := range ls { if err := l.Start(ctx); err != nil { return err } } return nil } // Close all underlying MQTT topics and free resources. func (ls Listeners) Close() { for _, l := range ls { l.Close() } }