router/pkg/pubsub/nats/adapter.go (249 lines of code) (raw):
package nats
import (
"context"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/cespare/xxhash/v2"
"github.com/nats-io/nats.go"
"github.com/nats-io/nats.go/jetstream"
"github.com/wundergraph/cosmo/router/pkg/pubsub/datasource"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
"go.uber.org/zap"
)
// Adapter defines the methods that a NATS adapter should implement.
type Adapter interface {
	// Subscribe subscribes to the given event and sends the received messages to the updater.
	Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error
	// Publish publishes the data of the given event to the subject specified by the event.
	Publish(ctx context.Context, event PublishAndRequestEventConfiguration) error
	// Request sends a request to the subject specified by the event and writes the response to the given writer.
	Request(ctx context.Context, event PublishAndRequestEventConfiguration, w io.Writer) error
	// Startup initializes the adapter.
	Startup(ctx context.Context) error
	// Shutdown gracefully shuts down the adapter.
	Shutdown(ctx context.Context) error
}
// ProviderAdapter implements the Adapter interface for NATS pub/sub.
type ProviderAdapter struct {
	// ctx is the adapter-wide context; the subscription goroutines stop when it is done.
	ctx context.Context
	// client is the NATS connection. It is nil until Startup has connected.
	client *nats.Conn
	// js is the JetStream handle created on top of client in Startup.
	js jetstream.JetStream
	// logger is the base logger; every operation derives its own child logger from it.
	logger *zap.Logger
	// closeWg tracks the subscription goroutines so that Shutdown can wait for them.
	closeWg sync.WaitGroup
	// hostName and routerListenAddr together form the instance identifier
	// that is hashed into the durable consumer name.
	hostName         string
	routerListenAddr string
	// url and opts are handed to nats.Connect in Startup.
	url  string
	opts []nats.Option
	// flushTimeout bounds the flush when the context passed to it carries no deadline.
	flushTimeout time.Duration
}
// getInstanceIdentifier builds the identifier of this router instance as "<hostName>-<routerListenAddr>".
// The host name combined with the listen address is a good approximation of what a client that started
// a subscription perceives as "the same instance" when it wants to resume the subscription after a
// failure on either the client or the router side.
func (p *ProviderAdapter) getInstanceIdentifier() string {
	host, listenAddr := p.hostName, p.routerListenAddr
	return fmt.Sprintf("%s-%s", host, listenAddr)
}
// getDurableConsumerName derives the name of the durable consumer for this instance and the given subjects.
// The result is "<durableName>-<hash>", where the hash (lower-case hex) covers the instance identifier
// followed by every subject in the given order. The name has to be unique per instance and subject set:
// if several routers shared one consumer they would keep modifying it, which leads to lost messages and
// to messages being delivered to the wrong subscribers.
func (p *ProviderAdapter) getDurableConsumerName(durableName string, subjects []string) (string, error) {
	hasher := xxhash.New()

	// The instance identifier goes in first, then the subjects in their original order
	parts := append([]string{p.getInstanceIdentifier()}, subjects...)
	for _, part := range parts {
		if _, err := hasher.WriteString(part); err != nil {
			return "", err
		}
	}

	return fmt.Sprintf("%s-%x", durableName, hasher.Sum64()), nil
}
// Subscribe starts delivering the messages for the subjects of the given event to the updater.
//
// When event.StreamConfiguration is set, a durable JetStream consumer is created (or updated) on the
// configured stream and polled from a background goroutine; each message is acknowledged after it has
// been passed to the updater. Otherwise one core NATS subscription per subject is created and the
// messages are forwarded from a background goroutine.
//
// The background goroutine stops as soon as the adapter context or the subscription context ctx is done.
// Subscribe returns an error when the adapter has not been started or when the consumer or one of the
// subscriptions cannot be created. Errors that occur later in the background goroutine are only logged.
func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error {
	log := p.logger.With(
		zap.String("provider_id", event.ProviderID),
		zap.String("method", "subscribe"),
		zap.Strings("subjects", event.Subjects),
	)

	if p.client == nil {
		return datasource.NewError("nats client not initialized", nil)
	}

	if p.js == nil {
		return datasource.NewError("nats jetstream not initialized", nil)
	}

	if event.StreamConfiguration != nil {
		durableConsumerName, err := p.getDurableConsumerName(event.StreamConfiguration.Consumer, event.Subjects)
		if err != nil {
			return err
		}
		consumerConfig := jetstream.ConsumerConfig{
			Durable:        durableConsumerName,
			FilterSubjects: event.Subjects,
		}
		// Durable consumers are removed automatically only if the InactiveThreshold value is set
		if event.StreamConfiguration.ConsumerInactiveThreshold > 0 {
			consumerConfig.InactiveThreshold = time.Duration(event.StreamConfiguration.ConsumerInactiveThreshold) * time.Second
		}
		consumer, err := p.js.CreateOrUpdateConsumer(ctx, event.StreamConfiguration.StreamName, consumerConfig)
		if err != nil {
			log.Error("creating or updating consumer", zap.Error(err))
			return datasource.NewError(fmt.Sprintf(`failed to create or update consumer for stream "%s"`, event.StreamConfiguration.StreamName), err)
		}

		p.closeWg.Add(1)

		go func() {
			defer p.closeWg.Done()

			for {
				select {
				case <-p.ctx.Done():
					// When the application context is done, we stop the subscription
					return

				case <-ctx.Done():
					// When the subscription context is done, we stop the subscription
					return

				default:
					// NOTE(review): there is no backoff between polls, so an idle consumer is
					// polled again right away — confirm this is intended.
					msgBatch, consumerFetchErr := consumer.FetchNoWait(300)
					if consumerFetchErr != nil {
						log.Error("error fetching messages", zap.Error(consumerFetchErr))
						return
					}

					for msg := range msgBatch.Messages() {
						log.Debug("subscription update", zap.String("message_subject", msg.Subject()), zap.ByteString("data", msg.Data()))

						updater.Update(msg.Data())

						// Acknowledge the message after it has been processed
						ackErr := msg.Ack()
						if ackErr != nil {
							log.Error("error acknowledging message", zap.String("message_subject", msg.Subject()), zap.Error(ackErr))
							return
						}
					}
					// NOTE(review): msgBatch.Error() is not inspected after the batch has been
					// drained — confirm whether such errors should be surfaced.
				}
			}
		}()

		return nil
	}

	msgChan := make(chan *nats.Msg)
	// Only successfully created subscriptions are stored, so the slice never holds nil entries
	subscriptions := make([]*nats.Subscription, 0, len(event.Subjects))

	// unsubscribeAll removes every subscription created so far. Failures are logged and not
	// returned, the teardown continues with the remaining subscriptions.
	unsubscribeAll := func(logMessage, subjectField string) {
		for _, subscription := range subscriptions {
			if err := subscription.Unsubscribe(); err != nil {
				log.Error(logMessage, zap.Error(err), zap.String(subjectField, subscription.Subject))
			}
		}
	}

	for _, subject := range event.Subjects {
		subscription, err := p.client.ChanSubscribe(subject, msgChan)
		if err != nil {
			log.Error("subscribing to NATS subject", zap.Error(err), zap.String("subscription_subject", subject))
			// No goroutine has been started yet that could clean up, so the subscriptions
			// created for the previous subjects have to be released here to not leak them
			unsubscribeAll("unsubscribing from NATS subject after subscribe failure", "subscription_subject")
			return datasource.NewError(fmt.Sprintf(`failed to subscribe to NATS subject "%s"`, subject), err)
		}
		subscriptions = append(subscriptions, subscription)
	}

	p.closeWg.Add(1)

	go func() {
		defer p.closeWg.Done()

		for {
			select {
			case msg := <-msgChan:
				log.Debug("subscription update", zap.String("message_subject", msg.Subject), zap.ByteString("data", msg.Data))
				updater.Update(msg.Data)

			case <-p.ctx.Done():
				// When the application context is done, we stop the subscriptions
				unsubscribeAll("unsubscribing from NATS subject after application context cancellation", "subject")
				return

			case <-ctx.Done():
				// When the subscription context is done, we stop the subscription
				unsubscribeAll("unsubscribing from NATS subject after subscription context cancellation", "subscription_subject")
				return
			}
		}
	}()

	return nil
}
// Publish sends event.Data to the NATS subject event.Subject.
// It returns an error when the adapter has not been started or when the client rejects the message.
// The context is not used.
func (p *ProviderAdapter) Publish(_ context.Context, event PublishAndRequestEventConfiguration) error {
	fields := []zap.Field{
		zap.String("provider_id", event.ProviderID),
		zap.String("method", "publish"),
		zap.String("subject", event.Subject),
	}
	logger := p.logger.With(fields...)

	if p.client == nil {
		return datasource.NewError("nats client not initialized", nil)
	}

	logger.Debug("publish", zap.ByteString("data", event.Data))

	if publishErr := p.client.Publish(event.Subject, event.Data); publishErr != nil {
		logger.Error("publish error", zap.Error(publishErr))
		return datasource.NewError(fmt.Sprintf("error publishing to NATS subject %s", event.Subject), publishErr)
	}

	return nil
}
// Request sends event.Data as a request to the NATS subject event.Subject and writes the payload
// of the reply to w. The request is bound to ctx.
// It returns an error when the adapter has not been started, when the request fails, or when
// writing the reply to w fails; in the last case the writer's error is returned as is.
func (p *ProviderAdapter) Request(ctx context.Context, event PublishAndRequestEventConfiguration, w io.Writer) error {
	fields := []zap.Field{
		zap.String("provider_id", event.ProviderID),
		zap.String("method", "request"),
		zap.String("subject", event.Subject),
	}
	logger := p.logger.With(fields...)

	if p.client == nil {
		return datasource.NewError("nats client not initialized", nil)
	}

	logger.Debug("request", zap.ByteString("data", event.Data))

	reply, requestErr := p.client.RequestWithContext(ctx, event.Subject, event.Data)
	if requestErr != nil {
		logger.Error("request error", zap.Error(requestErr))
		return datasource.NewError(fmt.Sprintf("error requesting from NATS subject %s", event.Subject), requestErr)
	}

	if _, writeErr := w.Write(reply.Data); writeErr != nil {
		logger.Error("error writing response to writer", zap.Error(writeErr))
		return writeErr
	}

	return nil
}
// flush flushes the NATS connection. It is a no-op when the adapter has no client.
// A context that already carries a deadline is used unchanged; otherwise the flush is
// bounded by the adapter's flushTimeout.
func (p *ProviderAdapter) flush(ctx context.Context) error {
	if p.client == nil {
		return nil
	}

	if _, hasDeadline := ctx.Deadline(); hasDeadline {
		return p.client.FlushWithContext(ctx)
	}

	boundedCtx, cancel := context.WithTimeout(ctx, p.flushTimeout)
	defer cancel()

	return p.client.FlushWithContext(boundedCtx)
}
// Startup connects to the NATS server at the configured url using the configured options and
// creates the JetStream handle on top of that connection.
// The adapter's client and JetStream handle are only set when both steps succeed. When the
// JetStream handle cannot be created, the freshly opened connection is closed again so that a
// failed startup does not leak it. The context is currently not used.
func (p *ProviderAdapter) Startup(ctx context.Context) error {
	client, err := nats.Connect(p.url, p.opts...)
	if err != nil {
		return err
	}

	js, err := jetstream.New(client)
	if err != nil {
		// Do not leave an open connection behind on a failed startup
		client.Close()
		return err
	}

	p.client, p.js = client, js

	return nil
}
// Shutdown flushes and drains the NATS connection and then waits until all subscription
// goroutines have finished. It is a no-op when the adapter has no client or the connection is
// already closed. Flush and drain errors do not abort the shutdown; they are collected and
// returned together, wrapped with the "nats pubsub shutdown" prefix.
func (p *ProviderAdapter) Shutdown(ctx context.Context) error {
	// Nothing to do when we never connected, already disconnected or failed to connect
	if p.client == nil || p.client.IsClosed() {
		return nil
	}

	var shutdownErr error

	if flushErr := p.flush(ctx); flushErr != nil {
		shutdownErr = errors.Join(shutdownErr, flushErr)
	}

	if drainErr := p.client.Drain(); drainErr != nil {
		shutdownErr = errors.Join(shutdownErr, drainErr)
	}

	// Wait for all subscriptions to be closed
	p.closeWg.Wait()

	if shutdownErr == nil {
		return nil
	}

	return fmt.Errorf("nats pubsub shutdown: %w", shutdownErr)
}
// NewAdapter creates a NATS adapter that is not yet connected; the connection is established by Startup.
// ctx is the adapter-wide context that stops all subscription goroutines once it is done. A nil logger
// is replaced by a no-op logger. url and opts are used to connect to NATS, hostName and routerListenAddr
// identify this router instance. The flush timeout is set to 10 seconds. The returned error is always nil.
func NewAdapter(ctx context.Context, logger *zap.Logger, url string, opts []nats.Option, hostName string, routerListenAddr string) (Adapter, error) {
	if logger == nil {
		logger = zap.NewNop()
	}

	// closeWg is left at its zero value, which is ready to use
	adapter := new(ProviderAdapter)
	adapter.ctx = ctx
	adapter.logger = logger.With(zap.String("pubsub", "nats"))
	adapter.hostName = hostName
	adapter.routerListenAddr = routerListenAddr
	adapter.url = url
	adapter.opts = opts
	adapter.flushTimeout = 10 * time.Second

	return adapter, nil
}