go/protocol/command_executor.go (359 lines of code) (raw):

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
package protocol

import (
	"context"
	"fmt"
	"log/slog"
	"maps"
	"time"

	"github.com/Azure/iot-operations-sdks/go/internal/log"
	"github.com/Azure/iot-operations-sdks/go/internal/mqtt"
	"github.com/Azure/iot-operations-sdks/go/internal/options"
	"github.com/Azure/iot-operations-sdks/go/internal/wallclock"
	"github.com/Azure/iot-operations-sdks/go/protocol/errors"
	"github.com/Azure/iot-operations-sdks/go/protocol/internal"
	"github.com/Azure/iot-operations-sdks/go/protocol/internal/caching"
	"github.com/Azure/iot-operations-sdks/go/protocol/internal/constants"
	"github.com/Azure/iot-operations-sdks/go/protocol/internal/errutil"
	"github.com/Azure/iot-operations-sdks/go/protocol/internal/version"
)

type (
	// CommandExecutor provides the ability to execute a single command.
	// It subscribes to a request topic, runs the user handler for each
	// request, and publishes the handler's result (or error) to the
	// request's response topic.
	CommandExecutor[Req any, Res any] struct {
		// Receives requests; ce itself is registered as its handler.
		listener *listener[Req]
		// Sends responses built by build.
		publisher *publisher[Res]
		// User-provided command implementation.
		handler CommandHandler[Req, Res]
		// Execution timeout ("ExecutionTimeout") applied to each handler call.
		timeout *internal.Timeout
		// Response cache consulted for every request in onMsg via cache.Exec.
		cache *caching.Cache
		log   log.Logger
	}

	// CommandExecutorOption represents a single command executor option.
	CommandExecutorOption interface{ commandExecutor(*CommandExecutorOptions) }

	// CommandExecutorOptions are the resolved command executor options.
	CommandExecutorOptions struct {
		// NOTE(review): Idempotent is not read anywhere in this file;
		// onMsg routes every request through the cache regardless.
		Idempotent bool
		// Passed through to the request listener.
		Concurrency uint
		// Handler execution timeout; validated by NewCommandExecutor.
		Timeout time.Duration
		// Optional shared-subscription name; validated by NewCommandExecutor.
		ShareName string
		// Namespace and token values used to resolve requestTopicPattern.
		TopicNamespace string
		TopicTokens    map[string]string
		// Optional logger, wrapped together with the application logger.
		Logger *slog.Logger
	}

	// CommandHandler is the user-provided implementation of a single command
	// execution. It is treated as blocking; all parallelism is handled by the
	// library. This *must* be thread-safe.
	CommandHandler[Req any, Res any] = func(
		context.Context,
		*CommandRequest[Req],
	) (*CommandResponse[Res], error)

	// CommandRequest contains per-message data and methods that are exposed to
	// the command handlers.
	CommandRequest[Req any] struct {
		Message[Req]
	}

	// CommandResponse contains per-message data and methods that are returned
	// by the command handlers.
	CommandResponse[Res any] struct {
		Message[Res]
	}

	// WithIdempotent marks the command as idempotent.
	WithIdempotent bool

	// RespondOption represents a single per-response option.
	RespondOption interface{ respond(*RespondOptions) }

	// RespondOptions are the resolved per-response options.
	RespondOptions struct {
		// Metadata copied onto the response message by Respond.
		Metadata map[string]string
	}
)

const (
	// Component name passed to the listener's start/close.
	commandExecutorComponentName = "command executor"
	// Description used in timeout and context error messages.
	commandExecutorErrStr = "command execution"
)

// NewCommandExecutor creates a new command executor.
//
// It validates that client, both encodings and handler are non-nil, that the
// execution timeout and share name are valid, and that requestTopicPattern
// resolves (with the option-supplied tokens and namespace) to a topic filter.
// On success the executor's listener is registered; call Start to begin
// receiving requests. Every returned error is passed through
// errutil.Return(err, logger, true) by the deferred call below (presumably
// for logging/normalization — confirm in errutil).
func NewCommandExecutor[Req, Res any](
	app *Application,
	client MqttClient,
	requestEncoding Encoding[Req],
	responseEncoding Encoding[Res],
	requestTopicPattern string,
	handler CommandHandler[Req, Res],
	opt ...CommandExecutorOption,
) (ce *CommandExecutor[Req, Res], err error) {
	var opts CommandExecutorOptions
	opts.Apply(opt)
	logger := log.Wrap(opts.Logger, app.log)

	// Uses the named result so every return path (including the shadowed
	// `err` returns below) is post-processed here.
	defer func() { err = errutil.Return(err, logger, true) }()

	if err := errutil.ValidateNonNil(map[string]any{
		"client":           client,
		"requestEncoding":  requestEncoding,
		"responseEncoding": responseEncoding,
		"handler":          handler,
	}); err != nil {
		return nil, err
	}

	to := &internal.Timeout{
		Duration: opts.Timeout,
		Name:     "ExecutionTimeout",
		Text:     commandExecutorErrStr,
	}
	if err := to.Validate(); err != nil {
		return nil, err
	}

	if err := internal.ValidateShareName(opts.ShareName); err != nil {
		return nil, err
	}

	reqTP, err := internal.NewTopicPattern(
		"requestTopicPattern",
		requestTopicPattern,
		opts.TopicTokens,
		opts.TopicNamespace,
	)
	if err != nil {
		return nil, err
	}

	reqTF, err := reqTP.Filter()
	if err != nil {
		return nil, err
	}

	ce = &CommandExecutor[Req, Res]{
		handler: handler,
		timeout: to,
		cache:   caching.New(wallclock.Instance),
		log:     logger,
	}
	// The listener calls back into ce (handler: ce) for each request.
	ce.listener = &listener[Req]{
		app:              app,
		client:           client,
		encoding:         requestEncoding,
		topic:            reqTF,
		shareName:        opts.ShareName,
		concurrency:      opts.Concurrency,
		reqCorrelation:   true,
		supportedVersion: version.RPCSupported,
		log:              logger,
		handler:          ce,
	}
	ce.publisher = &publisher[Res]{
		app:      app,
		client:   client,
		encoding: responseEncoding,
		version:  version.RPC,
	}

	ce.listener.register()
	return ce, nil
}

// Start listening to the MQTT request topic.
func (ce *CommandExecutor[Req, Res]) Start(ctx context.Context) error {
	return ce.listener.start(ctx, commandExecutorComponentName)
}

// Close the command executor to free its resources.
func (ce *CommandExecutor[Req, Res]) Close() {
	ce.listener.close(commandExecutorComponentName)
}

// onMsg handles one incoming request. It is presumably invoked by the
// listener, since ce is registered as the listener's handler.
//
// The request is rejected with an *errors.Remote if it has a missing or
// invalid response topic, or no message expiry. Otherwise the payload decode,
// handler call and response build run inside ce.cache.Exec keyed on pub.
// The handler context is bounded by both the executor timeout and the
// request's message expiry (assuming Timeout.Context behaves like
// context.WithTimeout — confirm in internal).
//
// The request is only acked here once cache.Exec succeeds. Any error returned
// before that point leaves acking to the caller (presumably via onErr, which
// always acks — verify in listener).
func (ce *CommandExecutor[Req, Res]) onMsg(
	ctx context.Context,
	pub *mqtt.Message,
	msg *Message[Req],
) error {
	ce.log.Debug(ctx, "request received",
		slog.String("topic", pub.Topic),
		slog.Any("correlation_data", pub.CorrelationData),
	)

	if err := ignoreRequest(pub); err != nil {
		return err
	}

	// The expiry is required because it also bounds handler execution
	// (see pubTimeout below).
	if pub.MessageExpiry == 0 {
		return &errors.Remote{
			Message: "message expiry missing",
			Kind: errors.HeaderMissing{
				HeaderName: constants.MessageExpiry,
			},
		}
	}

	rpub, err := ce.cache.Exec(pub, func() (*mqtt.Message, error) {
		req := &CommandRequest[Req]{Message: *msg}
		var err error

		req.Payload, err = ce.listener.payload(msg)
		if err != nil {
			return nil, err
		}

		// Apply the executor timeout, then narrow further by the request's
		// own expiry.
		handlerCtx, cancel := ce.timeout.Context(ctx)
		defer cancel()

		handlerCtx, cancel = pubTimeout(pub).Context(handlerCtx)
		defer cancel()

		res, err := ce.handle(handlerCtx, req)
		if err != nil {
			return nil, err
		}

		rpub, err := ce.build(pub, res, nil)
		if err != nil {
			return nil, err
		}

		return rpub, nil
	})
	if err != nil {
		return err
	}

	// From here on the request is acked on every path.
	defer ce.ack(ctx, pub)
	// NOTE(review): a nil response from cache.Exec means there is nothing to
	// publish; the exact cases are defined in the caching package.
	if rpub == nil {
		return nil
	}

	if err = ce.publisher.publish(ctx, rpub); err != nil {
		// If the publish fails onErr will also fail, so just drop the message.
		ce.listener.drop(ctx, pub, err)
	} else {
		ce.log.Debug(ctx, "response sent",
			slog.String("topic", rpub.Topic),
			slog.Any("correlation_data", rpub.CorrelationData),
		)
	}

	return nil
}

// onErr reports err back to the requester as an error response. The request
// is always acked, via the defer below, regardless of outcome.
//
// Nothing is published when the response topic is missing or invalid, or
// when errutil.IsNoReturn classifies err as no-return. In those cases the
// error from the respective check is returned instead.
func (ce *CommandExecutor[Req, Res]) onErr(
	ctx context.Context,
	pub *mqtt.Message,
	err error,
) error {
	defer ce.ack(ctx, pub)

	if e := ignoreRequest(pub); e != nil {
		return e
	}

	// If the error is a no-return error, don't send it.
	if no, e := errutil.IsNoReturn(err); no {
		return e
	}

	rpub, e := ce.build(pub, nil, err)
	if e != nil {
		return e
	}

	if e := ce.publisher.publish(ctx, rpub); e != nil {
		return e
	}

	// We successfully returned the error in the response, so just log it as a
	// warning.
	ce.log.Warn(ctx, err)
	return nil
}

// Call handler with panic catch.
//
// The user handler runs in its own goroutine so that this call can return as
// soon as ctx is done. Result translation:
//   - a handler panic becomes an *errors.Remote with Kind ExecutionError;
//   - if ctx is done after the handler returns, the context error wins;
//   - any other handler error is wrapped as an ExecutionError *errors.Remote;
//   - a nil response with a nil error is itself an ExecutionError.
//
// rchan is unbuffered. Both the worker's deferred send and the receive below
// select on ctx.Done, so neither side blocks forever once ctx is done.
func (ce *CommandExecutor[Req, Res]) handle(
	ctx context.Context,
	req *CommandRequest[Req],
) (*CommandResponse[Res], error) {
	rchan := make(chan commandReturn[Res])

	// TODO: This goroutine will leak if the handler blocks without respecting
	// the context. This is a known limitation to align to the C# behavior, and
	// should be changed if that behavior is revisited.
	go func() {
		var ret commandReturn[Res]
		defer func() {
			// On panic the post-handler checks below never ran, so the panic
			// value alone determines the error.
			if ePanic := recover(); ePanic != nil {
				ret.err = &errors.Remote{
					Message: fmt.Sprint(ePanic),
					Kind:    errors.ExecutionError{},
				}
			}

			select {
			case rchan <- ret:
			case <-ctx.Done():
			}
		}()

		ret.res, ret.err = ce.handler(ctx, req)
		if e := errutil.Context(ctx, commandExecutorErrStr); e != nil {
			// An error from the context overrides any return value.
			ret.err = e
		} else if ret.err != nil {
			ret.err = &errors.Remote{
				Message: ret.err.Error(),
				Kind:    errors.ExecutionError{},
			}
		} else if ret.res == nil {
			ret.err = &errors.Remote{
				Message: "command handler returned no response",
				Kind:    errors.ExecutionError{},
			}
		}
	}()

	select {
	case ret := <-rchan:
		return ret.res, ret.err
	case <-ctx.Done():
		return nil, errutil.Context(ctx, commandExecutorErrStr)
	}
}

// Build the response publish packet.
func (ce *CommandExecutor[Req, Res]) build( pub *mqtt.Message, res *CommandResponse[Res], resErr error, ) (*mqtt.Message, error) { var msg *Message[Res] if res != nil { msg = &res.Message } rpub, err := ce.publisher.build(msg, nil, pubTimeout(pub)) if err != nil { return nil, err } rpub.CorrelationData = pub.CorrelationData rpub.Topic = pub.ResponseTopic rpub.MessageExpiry = pub.MessageExpiry maps.Copy(rpub.UserProperties, errutil.ToUserProp(resErr)) return rpub, nil } // Check whether this message should be ignored and why. func ignoreRequest(pub *mqtt.Message) error { if pub.ResponseTopic == "" { return &errors.Remote{ Message: "missing response topic", Kind: errors.HeaderMissing{ HeaderName: constants.ResponseTopic, }, } } if !internal.ValidTopic(pub.ResponseTopic) { return &errors.Remote{ Message: "invalid response topic", Kind: errors.HeaderInvalid{ HeaderName: constants.ResponseTopic, HeaderValue: pub.ResponseTopic, }, } } return nil } // Ack the request and log it. func (ce *CommandExecutor[Req, Res]) ack( ctx context.Context, pub *mqtt.Message, ) { pub.Ack() ce.log.Debug(ctx, "request acked", slog.String("topic", pub.Topic), slog.Any("correlation_data", pub.CorrelationData), ) } // Build a timeout based on the message's expiry. func pubTimeout(pub *mqtt.Message) *internal.Timeout { return &internal.Timeout{ Duration: time.Duration(pub.MessageExpiry) * time.Second, Name: "MessageExpiry", Text: commandExecutorErrStr, } } // Respond is a shorthand to create a command response with required values and // options set appropriately. Note that the response may be incomplete and will // be filled out by the library after being returned. func Respond[Res any]( payload Res, opt ...RespondOption, ) (*CommandResponse[Res], error) { var opts RespondOptions opts.Apply(opt) return &CommandResponse[Res]{Message[Res]{ Payload: payload, Metadata: opts.Metadata, }}, nil } // Apply resolves the provided list of options. 
func (o *CommandExecutorOptions) Apply(
	opts []CommandExecutorOption,
	rest ...CommandExecutorOption,
) {
	// Each option yielded by options.Apply mutates o in turn.
	for item := range options.Apply[CommandExecutorOption](opts, rest...) {
		item.commandExecutor(o)
	}
}

// ApplyOptions filters and resolves the provided list of options.
// Only options that implement CommandExecutorOption are applied.
func (o *CommandExecutorOptions) ApplyOptions(opts []Option, rest ...Option) {
	for item := range options.Apply[CommandExecutorOption](opts, rest...) {
		item.commandExecutor(o)
	}
}

// commandExecutor lets a whole options struct act as an option: it
// overwrites dst with a copy of o. A nil receiver is a no-op.
func (o *CommandExecutorOptions) commandExecutor(dst *CommandExecutorOptions) {
	if o == nil {
		return
	}
	*dst = *o
}

func (*CommandExecutorOptions) option() {}

// commandExecutor sets the Idempotent flag on the resolved options.
func (o WithIdempotent) commandExecutor(dst *CommandExecutorOptions) {
	dst.Idempotent = bool(o)
}

func (WithIdempotent) option() {}

// Apply resolves the provided list of options.
func (o *RespondOptions) Apply(
	opts []RespondOption,
	rest ...RespondOption,
) {
	for item := range options.Apply[RespondOption](opts, rest...) {
		item.respond(o)
	}
}

// respond lets a whole RespondOptions act as an option: it overwrites dst
// with a copy of o. A nil receiver is a no-op.
func (o *RespondOptions) respond(dst *RespondOptions) {
	if o == nil {
		return
	}
	*dst = *o
}