registry/storage/driver/azure/v2/auth.go (231 lines of code) (raw):

package v2

import (
	"context"
	"fmt"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
	azlog "github.com/Azure/azure-sdk-for-go/sdk/azcore/log"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
	"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/sas"
	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/service"
	"github.com/docker/distribution/registry/storage/driver/azure/common"
	"github.com/docker/distribution/registry/storage/driver/base"
	"github.com/docker/distribution/version"
)

const (
	// NOTE(prozlach): Time period in the past, starting from time.Now(), for
	// which the signed URL is already valid. Meant to minimize issues caused
	// by small clock synchronization differences. Value follows the one set
	// by upstream.
	clockSkewTolerance = 10 * time.Second

	// UDCGracePeriod is the minimum remaining validity of the cached user
	// delegation credential. refreshUDC requests a new credential once the
	// cached one is closer to its expiry than this period.
	UDCGracePeriod = 30.0 * time.Minute
	// UDCExpiryTime is the validity period requested by refreshUDC for a
	// newly issued user delegation credential, counted from its start time.
	UDCExpiryTime = 48.0 * time.Hour
)

// userAgent is a string that is going to be appended to the User-Agent string
// generated by Azure SDK, it already contains golang version, so we only add
// Container-registry version here.
var userAgent = fmt.Sprintf("cr/%s", version.Version)

// ensureUseragentMaxLengthNoSpaces makes the user agent conform to the
// requirements of the Azure SDK with regard to user agent length and
// composition: spaces are replaced with underscores and the result is cut to
// at most 24 bytes. It is necessary as some of the user-agent information is
// passed from the build scripts.
func ensureUseragentMaxLengthNoSpaces(s string) string { noSpaces := strings.ReplaceAll(s, " ", "_") if len(noSpaces) > 24 { return noSpaces[:24] } return noSpaces } // signer abstracts the specifics of a blob SAS and is specialized // for the different authentication credentials type signer interface { Sign(context.Context, *sas.BlobSignatureValues) (sas.QueryParameters, error) } type urlSigner interface { SignBlobURL(context.Context, string, time.Time) (string, error) } var _ signer = (*sharedKeySigner)(nil) type sharedKeySigner struct { cred *azblob.SharedKeyCredential } func (s *sharedKeySigner) Sign(_ context.Context, signatureValues *sas.BlobSignatureValues) (sas.QueryParameters, error) { return signatureValues.SignWithSharedKey(s.cred) } var _ signer = (*clientTokenSigner)(nil) type clientTokenSigner struct { client *azblob.Client cred azcore.TokenCredential udc *service.UserDelegationCredential udcMutex sync.Mutex udcExpiry time.Time } func (s *clientTokenSigner) refreshUDC(ctx context.Context) (*service.UserDelegationCredential, error) { s.udcMutex.Lock() defer s.udcMutex.Unlock() now := time.Now().UTC() if s.udc == nil || s.udcExpiry.Sub(now) < UDCGracePeriod { // reissue user delegation credential startTime := now.Add(-10 * time.Second) expiryTime := startTime.Add(UDCExpiryTime) info := service.KeyInfo{ Start: to.Ptr(startTime.UTC().Format(sas.TimeFormat)), Expiry: to.Ptr(expiryTime.UTC().Format(sas.TimeFormat)), } udc, err := s.client.ServiceClient().GetUserDelegationCredential(ctx, info, nil) if err != nil { return nil, fmt.Errorf("creating user delegation credentials: %w", err) } s.udc = udc s.udcExpiry = expiryTime } return s.udc, nil } func (s *clientTokenSigner) Sign(ctx context.Context, signatureValues *sas.BlobSignatureValues) (sas.QueryParameters, error) { udc, err := s.refreshUDC(ctx) if err != nil { return sas.QueryParameters{}, fmt.Errorf("refreshing UDC credentials: %w", err) } return signatureValues.SignWithUserDelegation(udc) } var _ 
urlSigner = (*urlSignerImpl)(nil) type urlSignerImpl struct { si signer } func (s *urlSignerImpl) SignBlobURL(ctx context.Context, blobURL string, expires time.Time) (string, error) { urlParts, err := sas.ParseURL(blobURL) if err != nil { return "", fmt.Errorf("parsing url %q to be signed: %w", urlParts.String(), err) } perms := sas.BlobPermissions{Read: true} signatureValues := &sas.BlobSignatureValues{ Protocol: sas.ProtocolHTTPS, StartTime: time.Now().UTC().Add(-1 * clockSkewTolerance), ExpiryTime: expires, Permissions: perms.String(), ContainerName: urlParts.ContainerName, BlobName: urlParts.BlobName, } urlParts.SAS, err = s.si.Sign(ctx, signatureValues) if err != nil { return "", fmt.Errorf("signing URL %q: %w", urlParts.String(), err) } return urlParts.String(), nil } func newSharedKeyCredentialsClient(params *DriverParameters) (*Driver, error) { cred, err := azblob.NewSharedKeyCredential(params.AccountName, params.AccountKey) if err != nil { return nil, fmt.Errorf("creating shared key credentials: %w", err) } opts := azcore.ClientOptions{ PerRetryPolicies: []policy.Policy{newRetryNotificationPolicy()}, Logging: policy.LogOptions{ AllowedHeaders: []string{ "x-ms-error-code", "Retry-After", "Retry-After-Ms", "If-Match", "x-ms-blob-condition-appendpos", }, AllowedQueryParams: []string{"comp"}, }, Telemetry: policy.TelemetryOptions{ ApplicationID: ensureUseragentMaxLengthNoSpaces(userAgent), }, } if params.Transport != nil { opts.Transport = &http.Client{ Transport: params.Transport, } } client, err := azblob.NewClientWithSharedKeyCredential( params.ServiceURL, cred, &azblob.ClientOptions{ ClientOptions: opts, }, ) if err != nil { return nil, fmt.Errorf("creating client using shared key credentials: %w", err) } d := &driver{ client: client.ServiceClient().NewContainerClient(params.Container), signer: &urlSignerImpl{ si: &sharedKeySigner{ cred: cred, }, }, } commonClientSetup(params, d) return &Driver{baseEmbed: baseEmbed{Base: base.Base{StorageDriver: d}}}, nil 
} func newTokenClient(params *DriverParameters) (*Driver, error) { var cred azcore.TokenCredential var err error opts := azcore.ClientOptions{ PerRetryPolicies: []policy.Policy{newRetryNotificationPolicy()}, Logging: policy.LogOptions{ AllowedHeaders: []string{ "x-ms-error-code", "Retry-After", "Retry-After-Ms", "If-Match", "x-ms-blob-condition-appendpos", }, AllowedQueryParams: []string{"comp"}, }, Telemetry: policy.TelemetryOptions{ ApplicationID: ensureUseragentMaxLengthNoSpaces(userAgent), }, } if params.Transport != nil { opts.Transport = &http.Client{ Transport: params.Transport, } } if params.CredentialsType == common.CredentialsTypeClientSecret { cred, err = azidentity.NewClientSecretCredential( params.TenantID, params.ClientID, params.Secret, nil, ) if err != nil { return nil, fmt.Errorf("creating new client-secret credential: %w", err) } } else { // params.credentialsType == credentialsTypeDefaultCredentials cred, err = azidentity.NewDefaultAzureCredential(nil) if err != nil { return nil, fmt.Errorf("creating default azure credentials: %w", err) } } client, err := azblob.NewClient( params.ServiceURL, cred, &azblob.ClientOptions{ ClientOptions: opts, }, ) if err != nil { return nil, fmt.Errorf("creating azure client: %w", err) } d := &driver{ client: client.ServiceClient().NewContainerClient(params.Container), signer: &urlSignerImpl{ si: &clientTokenSigner{ cred: cred, client: client, }, }, } commonClientSetup(params, d) return &Driver{baseEmbed: baseEmbed{Base: base.Base{StorageDriver: d}}}, nil } func commonClientSetup(params *DriverParameters, d *driver) { d.Pather = common.NewPather(params.Root, !params.TrimLegacyRootPrefix) d.poolInitialInterval = params.PoolInitialInterval d.poolMaxInterval = params.PoolMaxInterval d.poolMaxElapsedTime = params.PoolMaxElapsedTime d.maxRetries = params.MaxRetries d.retryTryTimeout = params.RetryTryTimeout d.retryDelay = params.RetryDelay d.maxRetryDelay = params.MaxRetryDelay if params.DebugLog { if 
len(params.DebugLogEvents) > 0 { azlog.SetEvents(params.DebugLogEvents...) } logger := params.Logger azlog.SetListener(func(cls azlog.Event, msg string) { logger.WithField("event_type", cls).Debug(msg) }) } }