v2/sender.go (206 lines of code) (raw):

package shuttle import ( "context" "fmt" "reflect" "sync" "time" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus" "github.com/Azure/go-shuttle/v2/metrics/sender" ) const ( msgTypeField = "type" defaultSendTimeout = 30 * time.Second ) // MessageBody is a type to represent that an input message body can be of any type type MessageBody any // AzServiceBusSender is satisfied by *azservicebus.Sender type AzServiceBusSender interface { SendMessage(ctx context.Context, message *azservicebus.Message, options *azservicebus.SendMessageOptions) error SendMessageBatch(ctx context.Context, batch *azservicebus.MessageBatch, options *azservicebus.SendMessageBatchOptions) error NewMessageBatch(ctx context.Context, options *azservicebus.MessageBatchOptions) (*azservicebus.MessageBatch, error) Close(ctx context.Context) error } // Sender contains an SBSender used to send the message to the ServiceBus queue and a Marshaller used to marshal any struct into a ServiceBus message type Sender struct { // sbSender is responsible for message sending. It is protected by a mutex to prevent race conditions. // Any usage of sbSender must consider the mutex and take the appropriate lock. sbSender AzServiceBusSender // mu is used to prevent race conditions when changing the sbSender instance. // Use mu.Lock() only during FailOver and mu.RLock() for all other operations mu sync.RWMutex options *SenderOptions } type SenderOptions struct { // Marshaller will be used to marshall the messageBody to the azservicebus.Message Body property // defaults to DefaultJSONMarshaller Marshaller Marshaller // EnableTracingPropagation automatically applies WithTracePropagation option on all message sent through this sender EnableTracingPropagation bool // SendTimeout is the timeout value used on the context that sends messages // Defaults to 30 seconds if not set or 0 // Disabled when set to a negative value SendTimeout time.Duration } // NewSender takes in a Sender and a Marshaller to create a new object that can send messages to the ServiceBus queue func NewSender(sender AzServiceBusSender, options *SenderOptions) *Sender { if options == nil { options = &SenderOptions{} } if options.Marshaller == nil { options.Marshaller = &DefaultJSONMarshaller{} } if options.SendTimeout == 0 { options.SendTimeout = defaultSendTimeout } return &Sender{sbSender: sender, options: options} } // SendMessage sends a payload on the bus. // the MessageBody is marshalled and set as the message body. func (d *Sender) SendMessage(ctx context.Context, mb MessageBody, options ...func(msg *azservicebus.Message) error) error { // Check if there is a context error before doing anything since // we rely on context failures to detect if the sender is dead. if ctx.Err() != nil { return fmt.Errorf("failed to send message: %w", ctx.Err()) } msg, err := d.ToServiceBusMessage(ctx, mb, options...) if err != nil { return err } if d.options.SendTimeout > 0 { var cancel func() ctx, cancel = context.WithTimeout(ctx, d.options.SendTimeout) defer cancel() } errChan := make(chan error) go func() { err := d.sendMessage(ctx, msg, nil) // sendMessageOptions currently does nothing if err != nil { errChan <- fmt.Errorf("failed to send message: %w", err) } else { errChan <- nil } }() select { case <-ctx.Done(): sender.Metric.IncSendMessageFailureCount() return fmt.Errorf("failed to send message: %w", ctx.Err()) case err := <-errChan: if err == nil { sender.Metric.IncSendMessageSuccessCount() } else { sender.Metric.IncSendMessageFailureCount() } return err } } // ToServiceBusMessage transform a MessageBody into an azservicebus.Message. // It marshals the body using the sender's configured marshaller, // and set the bytes as the message.Body. // the sender's configured options are applied to the azservicebus.Message before // returning it. func (d *Sender) ToServiceBusMessage( ctx context.Context, mb MessageBody, options ...func(msg *azservicebus.Message) error) (*azservicebus.Message, error) { // uses a marshaller to marshal the message into a service bus message msg, err := d.options.Marshaller.Marshal(mb) if err != nil { return nil, fmt.Errorf("failed to marshal original struct into ServiceBus message: %w", err) } msgType := getMessageType(mb) msg.ApplicationProperties = map[string]interface{}{msgTypeField: msgType} if d.options.EnableTracingPropagation { options = append(options, WithTracePropagation(ctx)) } for _, option := range options { if err := option(msg); err != nil { return nil, fmt.Errorf("failed to run message options: %w", err) } } return msg, nil } // SendMessageBatch sends the array of azservicebus messages as a batch. func (d *Sender) SendMessageBatch(ctx context.Context, messages []*azservicebus.Message) error { // Check if there is a context error before doing anything since // we rely on context failures to detect if the sender is dead. if ctx.Err() != nil { return fmt.Errorf("failed to send message: %w", ctx.Err()) } batch, err := d.newMessageBatch(ctx, &azservicebus.MessageBatchOptions{}) if err != nil { return err } for _, msg := range messages { if err := batch.AddMessage(msg, nil); err != nil { return err } } if d.options.SendTimeout > 0 { var cancel func() ctx, cancel = context.WithTimeout(ctx, d.options.SendTimeout) defer cancel() } errChan := make(chan error) go func() { if err := d.sendMessageBatch(ctx, batch, nil); err != nil { errChan <- fmt.Errorf("failed to send message batch: %w", err) } else { errChan <- nil } }() select { case <-ctx.Done(): sender.Metric.IncSendMessageFailureCount() return fmt.Errorf("failed to send message batch: %w", ctx.Err()) case err := <-errChan: if err == nil { sender.Metric.IncSendMessageSuccessCount() } else { sender.Metric.IncSendMessageFailureCount() } return err } } func (d *Sender) sendMessage(ctx context.Context, msg *azservicebus.Message, options *azservicebus.SendMessageOptions) error { d.mu.RLock() defer d.mu.RUnlock() return d.sbSender.SendMessage(ctx, msg, options) } func (d *Sender) sendMessageBatch(ctx context.Context, batch *azservicebus.MessageBatch, options *azservicebus.SendMessageBatchOptions) error { d.mu.RLock() defer d.mu.RUnlock() return d.sbSender.SendMessageBatch(ctx, batch, options) } func (d *Sender) newMessageBatch(ctx context.Context, options *azservicebus.MessageBatchOptions) (*azservicebus.MessageBatch, error) { d.mu.RLock() defer d.mu.RUnlock() return d.sbSender.NewMessageBatch(ctx, options) } // AzSender returns the underlying azservicebus.Sender instance. func (d *Sender) AzSender() AzServiceBusSender { d.mu.RLock() defer d.mu.RUnlock() return d.sbSender } // SetAzSender sets the underlying azservicebus.Sender instance to the provided one. // All ongoing send operations will continue to use the old sender instance, // while new send operations will use the new sender instance. func (d *Sender) SetAzSender(sender AzServiceBusSender) { d.mu.Lock() defer d.mu.Unlock() d.sbSender = sender } // Deprecated: use SetAzSender. // FailOver sets the underlying azservicebus.Sender instance to the provided one. func (d *Sender) FailOver(sender AzServiceBusSender) { d.SetAzSender(sender) } // SetMessageId sets the ServiceBus message's ID to a user-specified value func SetMessageId(messageId *string) func(msg *azservicebus.Message) error { return func(msg *azservicebus.Message) error { msg.MessageID = messageId return nil } } // SetCorrelationId sets the ServiceBus message's correlation ID to a user-specified value func SetCorrelationId(correlationId *string) func(msg *azservicebus.Message) error { return func(msg *azservicebus.Message) error { msg.CorrelationID = correlationId return nil } } // SetScheduleAt schedules a message to be enqueued in the future func SetScheduleAt(t time.Time) func(msg *azservicebus.Message) error { return func(msg *azservicebus.Message) error { msg.ScheduledEnqueueTime = &t return nil } } // SetMessageDelay schedules a message in the future func SetMessageDelay(delay time.Duration) func(msg *azservicebus.Message) error { return func(msg *azservicebus.Message) error { newTime := time.Now().Add(delay) msg.ScheduledEnqueueTime = &newTime return nil } } // SetMessageTTL sets the ServiceBus message's TimeToLive to a user-specified value func SetMessageTTL(ttl time.Duration) func(msg *azservicebus.Message) error { return func(msg *azservicebus.Message) error { msg.TimeToLive = &ttl return nil } } func getMessageType(mb MessageBody) string { var msgType string vo := reflect.ValueOf(mb) if vo.Kind() == reflect.Ptr { msgType = reflect.Indirect(vo).Type().Name() } else { msgType = vo.Type().Name() } return msgType }