sdk/messaging/azservicebus/internal/namespace.go (374 lines of code) (raw):
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package internal
import (
"context"
"crypto/tls"
"fmt"
"net"
"runtime"
"sync"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
"github.com/Azure/azure-sdk-for-go/sdk/internal/telemetry"
"github.com/Azure/azure-sdk-for-go/sdk/internal/uuid"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/amqpwrap"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/auth"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/conn"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/exported"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/sbauth"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils"
"github.com/Azure/go-amqp"
)
// rootUserAgent is the base user-agent value for this package, produced by
// telemetry.Format from the "azservicebus" module name and Version. getUserAgent
// returns it as-is, or prefixed with the caller-supplied user agent when one is set.
var rootUserAgent = telemetry.Format("azservicebus", Version)
type (
	// Namespace is an abstraction over an amqp.Client, allowing us to hold onto a single
	// instance of a connection per ServiceBusClient.
	//
	// It lazily creates the AMQP client on first use (see updateClientWithoutLock), caches it,
	// and replaces it when Recover is called with the current connection revision.
	Namespace struct {
		// NOTE: values need to be 64-bit aligned. Simplest way to make sure this happens
		// is just to make it the first value in the struct
		// See:
		// Godoc: https://pkg.go.dev/sync/atomic#pkg-note-BUG
		// PR: https://github.com/Azure/azure-sdk-for-go/pull/16847
		//
		// connID is the revision of the cached client: it is incremented each time a new
		// client is created and is compared against the caller's revision in Recover.
		connID uint64
		// FQDN is the fully qualified namespace host. It is used as the AMQP HostName and
		// as the host portion of the amqp(s)://, wss:// and https:// URIs built by this type.
		FQDN string
		// TokenProvider supplies tokens for claim negotiation. Its InsecureDisableTLS field
		// selects between the amqp:// and amqps:// schemes in getAMQPHostURI.
		TokenProvider *sbauth.TokenProvider
		// tlsConfig, when non-nil, is copied into the connection options of new clients.
		tlsConfig *tls.Config
		// userAgent is an optional caller-supplied value placed in front of rootUserAgent.
		userAgent string
		// newWebSocketConn, when non-nil, is used to create the net.Conn that the AMQP
		// connection runs over, instead of dialing the AMQP host URI directly.
		newWebSocketConn func(ctx context.Context, args exported.NewWebSocketConnArgs) (net.Conn, error)
		// NOTE: exported only so it can be checked in a test
		//
		// RetryOptions is passed to utils.Retry by the background claim refresh loop.
		RetryOptions exported.RetryOptions
		// clientMu is held while client and closedPermanently are read or modified; connID
		// is only read or incremented while it is held for writing.
		clientMu sync.RWMutex
		// client is the cached AMQP client, or nil if none has been created yet (or it was closed).
		client amqpwrap.AMQPClient
		// negotiateClaimMu serializes claim negotiation, since only one $cbs link may be
		// open on a connection at a time.
		negotiateClaimMu sync.Mutex
		// indicates that the client was closed permanently, and not just
		// for recovery.
		closedPermanently bool
		// newClientFn exists so we can stub out newClient for unit tests.
		// NewNamespace defaults it to newClientImpl.
		newClientFn func(ctx context.Context) (amqpwrap.AMQPClient, error)
		// customEndpoint, when non-empty, replaces FQDN as the address that is dialed.
		// It is not used when computing entity audiences.
		customEndpoint string
	}
	// NamespaceOption provides structure for configuring a new Service Bus namespace
	// It is applied to the Namespace by NewNamespace; a non-nil error aborts construction.
	NamespaceOption func(h *Namespace) error
)
// NamespaceForAMQPLinks is the Namespace surface needed for the internals of AMQPLinks.
// The per-method notes below describe the behavior of the *Namespace implementation in this file.
type NamespaceForAMQPLinks interface {
	// Check returns an error if the namespace cannot be used (ie, closed permanently), or nil otherwise.
	Check() error
	// NegotiateClaim performs the initial authentication for entityPath and starts a background
	// refresh of the credentials. It returns a function that cancels the background refresh and a
	// channel that is closed when the refresh has stopped.
	NegotiateClaim(ctx context.Context, entityPath string) (context.CancelFunc, <-chan struct{}, error)
	// NewAMQPSession creates a new session on the cached AMQP client, returning the session and
	// the revision of the client it was created from.
	NewAMQPSession(ctx context.Context) (amqpwrap.AMQPSession, uint64, error)
	// NewRPCLink creates an RPC link addressed at managementPath using the cached AMQP client.
	NewRPCLink(ctx context.Context, managementPath string) (amqpwrap.RPCLink, error)
	// GetEntityAudience returns the audience URI for entityPath (the AMQP host URI followed by the path).
	GetEntityAudience(entityPath string) string
	// Recover destroys the currently held AMQP connection and recreates it, if needed.
	//
	// If a new client is actually created (rather than just cached) then the returned bool
	// will be true. Any links that were created from the original connection will need to
	// be recreated.
	//
	// NOTE: cancelling the context only cancels the initialization of a new AMQP
	// connection - the previous connection is always closed.
	Recover(ctx context.Context, clientRevision uint64) (bool, error)
	// Close closes the currently cached client. When permanently is true the namespace is
	// also marked as unusable, after which Check reports an error.
	Close(permanently bool) error
}
// NamespaceWithConnectionString returns an option that parses connStr and uses the result
// to configure the namespace: the FQDN is taken from the connection string (when it carries
// one) and the TokenProvider is built from the parsed properties.
//
// The option fails if the connection string cannot be parsed or a token provider cannot be
// created from it.
func NamespaceWithConnectionString(connStr string) NamespaceOption {
	return func(ns *Namespace) error {
		parsed, err := conn.ParseConnectionString(connStr)
		if err != nil {
			return err
		}

		// only overwrite the FQDN when the connection string actually supplies one.
		if fqdn := parsed.FullyQualifiedNamespace; fqdn != "" {
			ns.FQDN = fqdn
		}

		tokenProvider, err := sbauth.NewTokenProviderWithConnectionString(parsed)
		if err != nil {
			return err
		}

		ns.TokenProvider = tokenProvider
		return nil
	}
}
// NamespaceWithCustomEndpoint returns an option that records a custom endpoint, which is
// useful when connecting through a TCP proxy. The custom endpoint is the address used when
// establishing the TCP connection. The audience is still derived from the
// fullyQualifiedNamespace passed to NamespaceWithTokenCredential, or from the endpoint in the
// connection string passed to NamespaceWithConnectionString.
func NamespaceWithCustomEndpoint(customEndpoint string) NamespaceOption {
	apply := func(target *Namespace) error {
		target.customEndpoint = customEndpoint
		return nil
	}

	return apply
}
// NamespaceWithTLSConfig returns an option that stores tlsConfig on the namespace. A non-nil
// config is handed to the AMQP connection options whenever a new client is created.
func NamespaceWithTLSConfig(tlsConfig *tls.Config) NamespaceOption {
	apply := func(target *Namespace) error {
		target.tlsConfig = tlsConfig
		return nil
	}

	return apply
}
// NamespaceWithUserAgent returns an option that stores a caller-supplied user-agent value.
// When non-empty, it is placed in front of the package's root user-agent (separated by a
// space) in the "user-agent" connection property.
func NamespaceWithUserAgent(userAgent string) NamespaceOption {
	apply := func(target *Namespace) error {
		target.userAgent = userAgent
		return nil
	}

	return apply
}
// NamespaceWithWebSocket returns an option that makes the namespace and all entities connect
// over wss:// rather than amqps://. newWebSocketConn is called to create the net.Conn that the
// AMQP connection is layered on.
func NamespaceWithWebSocket(newWebSocketConn func(ctx context.Context, args exported.NewWebSocketConnArgs) (net.Conn, error)) NamespaceOption {
	apply := func(target *Namespace) error {
		target.newWebSocketConn = newWebSocketConn
		return nil
	}

	return apply
}
// NamespaceWithTokenCredential returns an option that authenticates with tokenCredential.
// It installs a token provider wrapping the credential and records fullyQualifiedNamespace,
// the Service Bus namespace name (ex: myservicebus.servicebus.windows.net), as the FQDN.
func NamespaceWithTokenCredential(fullyQualifiedNamespace string, tokenCredential azcore.TokenCredential) NamespaceOption {
	apply := func(target *Namespace) error {
		target.TokenProvider = sbauth.NewTokenProvider(tokenCredential)
		target.FQDN = fullyQualifiedNamespace
		return nil
	}

	return apply
}
// NamespaceWithRetryOptions returns an option that stores retryOptions on the namespace.
// These options are the ones handed to the retry helper by the background claim refresh.
func NamespaceWithRetryOptions(retryOptions exported.RetryOptions) NamespaceOption {
	apply := func(target *Namespace) error {
		target.RetryOptions = retryOptions
		return nil
	}

	return apply
}
// NamespaceWithNewClientFn returns an option that replaces the function used to construct
// new AMQP clients. Useful for tests that need to inject a stub client.
func NamespaceWithNewClientFn(fn func(ctx context.Context) (amqpwrap.AMQPClient, error)) NamespaceOption {
	apply := func(target *Namespace) error {
		target.newClientFn = fn
		return nil
	}

	return apply
}
// NewNamespace builds a Namespace and applies each NamespaceOption to it, in order.
// The client factory defaults to the real AMQP implementation before options run, so an
// option may override it. The first option that fails aborts construction and its error
// is returned.
func NewNamespace(opts ...NamespaceOption) (*Namespace, error) {
	namespace := new(Namespace)
	namespace.newClientFn = namespace.newClientImpl

	for i := range opts {
		if err := opts[i](namespace); err != nil {
			return nil, err
		}
	}

	return namespace, nil
}
// newClientImpl creates a brand new AMQP connection for this namespace and wraps it in an
// amqpwrap.AMQPClientWrapper tagged with a freshly generated UUID.
//
// If a websocket factory was configured the connection is layered on the net.Conn it
// returns; otherwise the AMQP host URI (honoring any custom endpoint) is dialed directly.
//
// On any failure the returned client is nil, so callers can never end up holding a wrapper
// around a nil connection.
func (ns *Namespace) newClientImpl(ctx context.Context) (amqpwrap.AMQPClient, error) {
	connOptions := amqp.ConnOptions{
		SASLType:    amqp.SASLTypeAnonymous(),
		MaxSessions: 65535,
		Properties: map[string]any{
			"product":    "MSGolangClient",
			"version":    Version,
			"platform":   runtime.GOOS,
			"framework":  runtime.Version(),
			"user-agent": ns.getUserAgent(),
		},
		// the AMQP hostname is always the namespace FQDN, even when we physically
		// connect through a websocket or a custom endpoint.
		HostName: ns.FQDN,
	}

	if ns.tlsConfig != nil {
		connOptions.TLSConfig = ns.tlsConfig
	}

	id, err := uuid.New()
	if err != nil {
		return nil, err
	}

	if ns.newWebSocketConn != nil {
		nConn, err := ns.newWebSocketConn(ctx, exported.NewWebSocketConnArgs{
			Host: ns.getWSSHostURI() + "$servicebus/websocket",
		})
		if err != nil {
			return nil, err
		}

		client, err := amqp.NewConn(ctx, nConn, &connOptions)
		if err != nil {
			return nil, err
		}

		return &amqpwrap.AMQPClientWrapper{Inner: client, ID: id.String()}, nil
	}

	client, err := amqp.Dial(ctx, ns.getAMQPHostURI(true), &connOptions)
	if err != nil {
		return nil, err
	}

	return &amqpwrap.AMQPClientWrapper{Inner: client, ID: id.String()}, nil
}
// NewAMQPSession opens a new AMQP session on the cached client, creating the client first
// if there isn't one yet.
// It returns the session together with the revision of the client it belongs to. On failure
// the session is nil and the revision is 0.
func (ns *Namespace) NewAMQPSession(ctx context.Context) (amqpwrap.AMQPSession, uint64, error) {
	amqpClient, revision, err := ns.GetAMQPClientImpl(ctx)
	if err != nil {
		return nil, 0, err
	}

	newSession, err := amqpClient.NewSession(ctx, nil)
	if err != nil {
		return nil, 0, err
	}

	return newSession, revision, nil
}
// NewRPCLink builds an RPC link addressed at managementPath on top of the cached AMQP
// client, creating the client first if there isn't one yet.
func (ns *Namespace) NewRPCLink(ctx context.Context, managementPath string) (amqpwrap.RPCLink, error) {
	amqpClient, _, err := ns.GetAMQPClientImpl(ctx)
	if err != nil {
		return nil, err
	}

	linkArgs := RPCLinkArgs{
		Client:   amqpClient,
		Address:  managementPath,
		LogEvent: exported.EventReceiver,
	}

	return NewRPCLink(ctx, linkArgs)
}
// Close shuts down the currently cached client, if any. When permanently is true the
// namespace is additionally marked as closed for good. A failure to close the underlying
// connection is only logged; Close itself always returns nil.
func (ns *Namespace) Close(permanently bool) error {
	ns.clientMu.Lock()
	defer ns.clientMu.Unlock()

	// the flag is sticky: a later Close(false) never clears it.
	if permanently {
		ns.closedPermanently = true
	}

	current := ns.client
	if current == nil {
		return nil
	}

	closeErr := current.Close()
	ns.client = nil

	if closeErr != nil {
		log.Writef(exported.EventConn, "Failed when closing AMQP connection: %s", closeErr)
	}

	return nil
}
// Check reports whether the namespace is still usable: it returns ErrClientClosed once the
// namespace has been closed permanently, and nil otherwise.
func (ns *Namespace) Check() error {
	ns.clientMu.RLock()
	defer ns.clientMu.RUnlock()

	if !ns.closedPermanently {
		return nil
	}

	return ErrClientClosed
}
// ErrClientClosed is the error returned by Check, Recover and GetAMQPClientImpl once the
// namespace has been closed permanently. It is created via NewErrNonRetriable.
var ErrClientClosed = NewErrNonRetriable("client has been closed by user")
// Recover tears down the AMQP connection currently held by the namespace and creates a
// replacement, but only if the caller's view of the connection (theirConnID) is still the
// current one.
//
// The returned bool is true when a new client was actually created, in which case every link
// created from the previous connection must be recreated. It is false when the namespace is
// permanently closed, when another caller already recovered (the revisions differ), or when
// creating the new client failed.
//
// NOTE: cancelling the context only cancels the initialization of the new AMQP connection -
// the previous connection is always closed.
func (ns *Namespace) Recover(ctx context.Context, theirConnID uint64) (bool, error) {
	if checkErr := ns.Check(); checkErr != nil {
		return false, checkErr
	}

	ns.clientMu.Lock()
	defer ns.clientMu.Unlock()

	// re-check under the write lock, the namespace may have been closed in between.
	if ns.closedPermanently {
		return false, ErrClientClosed
	}

	if theirConnID != ns.connID {
		// somebody else has already recovered since this caller last looked.
		log.Writef(exported.EventConn, "Skipping connection recovery, already recovered: %d vs %d", ns.connID, theirConnID)
		return false, nil
	}

	if stale := ns.client; stale != nil {
		ns.client = nil
		// the error on close isn't critical
		_ = stale.Close()
	}

	log.Writef(exported.EventConn, "Creating a new client (rev:%d)", ns.connID)

	_, _, err := ns.updateClientWithoutLock(ctx)
	if err != nil {
		return false, err
	}

	return true, nil
}
// NegotiateClaim performs the initial authentication for entityPath and starts the periodic
// refresh of its credentials, using the package's NegotiateClaim function and the default
// refresh schedule.
// The returned func cancels the refresh goroutine; the returned channel is closed once the
// refresh has stopped.
func (ns *Namespace) NegotiateClaim(ctx context.Context, entityPath string) (context.CancelFunc, <-chan struct{}, error) {
	cancelRenewal, renewalStopped, err := ns.startNegotiateClaimRenewer(
		ctx,
		entityPath,
		NegotiateClaim,
		nextClaimRefreshDuration,
	)

	return cancelRenewal, renewalStopped, err
}
// startNegotiateClaimRenewer does an initial claim request and then starts a goroutine that
// continues to automatically refresh in the background.
//
// Parameters:
//   - ctx is used for the initial, synchronous claim negotiation only. The background loop
//     runs on its own context derived from context.Background().
//   - entityPath is the entity being authorized; it is used to compute the audience and
//     appears in log messages.
//   - cbsNegotiateClaim performs the actual claim negotiation (injectable for tests).
//   - nextClaimRefreshDurationFn computes how long to wait before the next refresh, given
//     the token's expiration time and the current time.
//
// Returns a func() that can be used to cancel the background renewal, a channel that will be closed
// when the background renewal stops or an error.
// The returned func blocks until the background goroutine has exited. If the initial
// negotiation fails, both the func and the channel are nil.
func (ns *Namespace) startNegotiateClaimRenewer(ctx context.Context,
	entityPath string,
	cbsNegotiateClaim func(ctx context.Context, audience string, conn amqpwrap.AMQPClient, provider auth.TokenProvider, contextWithTimeoutFn contextWithTimeoutFn) error,
	nextClaimRefreshDurationFn func(expirationTime time.Time, currentTime time.Time) time.Duration) (func(), <-chan struct{}, error) {
	audience := ns.GetEntityAudience(entityPath)

	// refreshClaim does a single round of: get the client, get a token, negotiate the claim.
	// It returns the expiration time of the token that was used.
	refreshClaim := func(ctx context.Context) (time.Time, error) {
		log.Writef(exported.EventAuth, "(%s) refreshing claim", entityPath)
		// clientRevision is remembered so a connection recovery (below) targets the same
		// connection the failed negotiation ran on.
		amqpClient, clientRevision, err := ns.GetAMQPClientImpl(ctx)
		if err != nil {
			return time.Time{}, err
		}
		token, expiration, err := ns.TokenProvider.GetTokenAsTokenProvider(audience)
		if err != nil {
			log.Writef(exported.EventAuth, "(%s) negotiate claim, failed getting token: %s", entityPath, err.Error())
			return time.Time{}, err
		}
		log.Writef(exported.EventAuth, "(%s) negotiate claim, token expires on %s", entityPath, expiration.Format(time.RFC3339))
		// You're not allowed to have multiple $cbs links open in a single connection.
		// The current cbs.NegotiateClaim implementation automatically creates and shuts
		// down its own link so we have to guard against that here.
		ns.negotiateClaimMu.Lock()
		err = cbsNegotiateClaim(ctx, audience, amqpClient, token, context.WithTimeout)
		ns.negotiateClaimMu.Unlock()
		if err != nil {
			// Note we only handle connection recovery here since (currently)
			// the negotiateClaim code creates its own link each time.
			if GetRecoveryKind(err) == RecoveryKindConn {
				// the recovery error is deliberately scoped to this if statement: it is only
				// logged, and the original negotiation error is what gets returned below.
				if _, err := ns.Recover(ctx, clientRevision); err != nil {
					log.Writef(exported.EventAuth, "(%s) negotiate claim, failed in connection recovery: %s", entityPath, err)
				}
			}
			log.Writef(exported.EventAuth, "(%s) negotiate claim, failed: %s", entityPath, err.Error())
			return time.Time{}, err
		}
		return expiration, nil
	}

	// the initial negotiation is synchronous, so the caller learns about auth failures immediately.
	expiresOn, err := refreshClaim(ctx)
	if err != nil {
		return nil, nil, err
	}
	// start the periodic refresh of credentials
	// (on a context detached from the caller's ctx, so the refresh outlives the initial call)
	refreshCtx, cancelRefreshCtx := context.WithCancel(context.Background())
	refreshStoppedCh := make(chan struct{})
	// connection strings with embedded SAS tokens will return a zero expiration time since they can't be renewed.
	if expiresOn.IsZero() {
		log.Writef(exported.EventAuth, "Token does not have an expiration date, no background renewal needed.")
		// cancel everything related to the claims refresh loop.
		cancelRefreshCtx()
		close(refreshStoppedCh)
		// nothing is running, so the cancel func is a no-op and the channel is already closed.
		return func() {}, refreshStoppedCh, nil
	}
	go func() {
		// on exit: signal that the refresh has stopped, then release the refresh context.
		defer cancelRefreshCtx()
		defer close(refreshStoppedCh)
	TokenRefreshLoop:
		for {
			nextClaimAt := nextClaimRefreshDurationFn(expiresOn, time.Now())
			log.Writef(exported.EventAuth, "(%s) next refresh in %s", entityPath, nextClaimAt)
			select {
			case <-refreshCtx.Done():
				return
			case <-time.After(nextClaimAt):
				// keep retrying until the refresh succeeds or the failure is classified as fatal.
				// NOTE(review): leaving this loop after refreshCtx is cancelled relies on
				// GetRecoveryKind reporting the resulting error as RecoveryKindFatal - confirm.
				for {
					err := utils.Retry(refreshCtx, exported.EventAuth, "NegotiateClaimRefresh", func(ctx context.Context, args *utils.RetryFnArgs) error {
						tmpExpiresOn, err := refreshClaim(ctx)
						if err != nil {
							return err
						}
						// only advance the expiration once a refresh has actually succeeded.
						expiresOn = tmpExpiresOn
						return nil
					}, IsFatalSBError, ns.RetryOptions)
					if err == nil {
						break
					}
					if GetRecoveryKind(err) == RecoveryKindFatal {
						log.Writef(exported.EventAuth, "[%s] fatal error, stopping token refresh loop: %s", entityPath, err.Error())
						break TokenRefreshLoop
					}
				}
			}
		}
	}()
	return func() {
		// cancel the loop and wait for the goroutine to finish before returning.
		cancelRefreshCtx()
		<-refreshStoppedCh
	}, refreshStoppedCh, nil
}
// GetAMQPClientImpl returns the cached AMQP client and its revision, creating the client
// first if there isn't one yet. It fails with ErrClientClosed if the namespace has been
// closed permanently.
func (ns *Namespace) GetAMQPClientImpl(ctx context.Context) (amqpwrap.AMQPClient, uint64, error) {
	if checkErr := ns.Check(); checkErr != nil {
		return nil, 0, checkErr
	}

	ns.clientMu.Lock()
	defer ns.clientMu.Unlock()

	// re-check under the write lock, the namespace may have been closed in between.
	if permanentlyClosed := ns.closedPermanently; permanentlyClosed {
		return nil, 0, ErrClientClosed
	}

	return ns.updateClientWithoutLock(ctx)
}
// updateClientWithoutLock returns the cached client and its connection ID, creating (and
// caching) a new client when none exists. Creating a client bumps the connection ID.
// It does not lock clientMu itself - callers are expected to already hold it.
func (ns *Namespace) updateClientWithoutLock(ctx context.Context) (amqpwrap.AMQPClient, uint64, error) {
	if cached := ns.client; cached != nil {
		return cached, ns.connID, nil
	}

	startedAt := time.Now()
	log.Writef(exported.EventConn, "Creating new client, current rev: %d", ns.connID)

	created, err := ns.newClientFn(ctx)
	if err != nil {
		return nil, 0, err
	}

	ns.client = created
	ns.connID++

	elapsed := time.Since(startedAt) / time.Millisecond
	log.Writef(exported.EventConn, "Client created, new rev: %d, took %dms", ns.connID, elapsed)

	return created, ns.connID, nil
}
// getWSSHostURI returns the websocket base URI for the namespace ("wss://<FQDN>/").
func (ns *Namespace) getWSSHostURI() string {
	host := ns.FQDN
	return fmt.Sprintf("wss://%s/", host)
}
// getAMQPHostURI returns the AMQP base URI for the namespace. The host is the FQDN unless
// useCustomEndpoint is true and a custom endpoint was configured, in which case the custom
// endpoint is used. The scheme is "amqp" when the token provider has TLS disabled and
// "amqps" otherwise.
func (ns *Namespace) getAMQPHostURI(useCustomEndpoint bool) string {
	host := ns.FQDN

	if useCustomEndpoint && ns.customEndpoint != "" {
		host = ns.customEndpoint
	}

	if ns.TokenProvider.InsecureDisableTLS {
		return fmt.Sprintf("amqp://%s/", host)
	}

	return fmt.Sprintf("amqps://%s/", host)
}
// GetHTTPSHostURI returns the HTTPS base URI for the namespace ("https://<FQDN>/").
func (ns *Namespace) GetHTTPSHostURI() string {
	host := ns.FQDN
	return fmt.Sprintf("https://%s/", host)
}
// GetEntityAudience returns the audience for entityPath: the namespace's AMQP host URI
// (always based on the FQDN, never the custom endpoint) followed by the entity path.
func (ns *Namespace) GetEntityAudience(entityPath string) string {
	base := ns.getAMQPHostURI(false)
	return base + entityPath
}
// getUserAgent returns the user-agent for the connection: the root user-agent on its own,
// or "<custom> <root>" when a custom user-agent was configured.
func (ns *Namespace) getUserAgent() string {
	if ns.userAgent == "" {
		return rootUserAgent
	}

	return fmt.Sprintf("%s %s", ns.userAgent, rootUserAgent)
}
// nextClaimRefreshDuration computes how long to wait before the next authorization refresh.
//
// The interval is the time remaining until expirationTime (measured from currentTime) with
// these real world adjustments applied:
//   - 10 minutes are subtracted to allow for clock drift, so the refresh happens ahead of
//     the reported expiration.
//   - The interval is never shorter than 2 minutes.
//   - The interval is never longer than 49 days.
//
// This logic is from here:
// https://github.com/Azure/azure-sdk-for-net/blob/bfd3109d0f9afa763131731d78a31e39c81101b3/sdk/servicebus/Azure.Messaging.ServiceBus/src/Amqp/AmqpConnectionScope.cs#L998
func nextClaimRefreshDuration(expirationTime time.Time, currentTime time.Time) time.Duration {
	const (
		shortestInterval = 2 * time.Minute
		longestInterval  = 49 * 24 * time.Hour
		driftAllowance   = 10 * time.Minute
	)

	interval := expirationTime.Sub(currentTime) - driftAllowance

	switch {
	case interval < shortestInterval:
		return shortestInterval
	case interval > longestInterval:
		return longestInterval
	default:
		return interval
	}
}