secretmgmt/secrets.go (96 lines of code) (raw):
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package secretmgmt
import (
"context"
"fmt"
"time"
"github.com/Azure/acr-builder/tokenutil"
"github.com/Azure/acr-builder/vaults"
"github.com/pkg/errors"
)
// DefaultSecretResolveTimeout is the default timeout for resolving a secret: two minutes.
const DefaultSecretResolveTimeout time.Duration = 2 * time.Minute
// secretResolveChannel groups the channels ResolveSecrets waits on for a single secret.
type secretResolveChannel struct {
// resolvedChan is the secret's ResolvedChan; a receive on it is treated as a successful resolve.
resolvedChan chan bool
// timeoutChan returns the Done channel of the timeout context the secret is being resolved under.
timeoutChan func() <-chan struct{}
}
// ResolveSecretFunc is a function that resolves the secret to its value and signals success through the ResolvedChan of the secret. Any errors during resolve are sent through errorChan.
// ResolveSecrets runs it in its own goroutine and waits for one of those two signals per secret.
type ResolveSecretFunc func(ctx context.Context, secret *Secret, errorChan chan error)
// SecretResolver defines how a secret is resolved.
type SecretResolver struct {
// Resolve is the function ResolveSecrets starts, in its own goroutine, for every non-nil secret.
Resolve ResolveSecretFunc
// resolveTimeout is the timeout applied to the context that is passed to Resolve.
resolveTimeout time.Duration
}
// NewSecretResolver builds a SecretResolver that resolves secrets with
// resolveFunc and bounds each resolve by resolveTimeout. A nil resolveFunc
// selects the package's default resolveSecret implementation. The returned
// error is always nil.
func NewSecretResolver(resolveFunc ResolveSecretFunc, resolveTimeout time.Duration) (*SecretResolver, error) {
	resolver := &SecretResolver{
		Resolve:        resolveFunc,
		resolveTimeout: resolveTimeout,
	}
	if resolver.Resolve == nil {
		// No custom function supplied: fall back to the built-in resolver.
		resolver.Resolve = resolveSecret
	}
	return resolver, nil
}
// ResolveSecrets resolves all the Secrets, or returns an error if there is any
// failure in resolving a secret.
//
// Secrets are resolved concurrently in batches of five to avoid throttling
// errors on the vault providers; a batch has to finish before the next one is
// started. Nil entries in secrets are skipped. An empty or nil slice is a no-op.
//
// The returned error is the first of: an error reported by the resolve
// function, ctx.Err() if ctx ends, or a timeout error if a batch does not
// complete within the resolver's timeout.
func (secretResolver *SecretResolver) ResolveSecrets(ctx context.Context, secrets []*Secret) error {
	// We will resolve in batches of 5 to avoid throttling errors on the vault providers.
	const batchSize = 5

	for start := 0; start < len(secrets); start += batchSize {
		end := start + batchSize
		if end > len(secrets) {
			end = len(secrets)
		}
		if err := secretResolver.resolveBatch(ctx, secrets[start:end]); err != nil {
			return err
		}
	}
	return nil
}

// resolveBatch starts one resolve goroutine per non-nil secret in batch and
// blocks until every one of them has signalled success, or until the first
// error, timeout or cancellation of ctx.
func (secretResolver *SecretResolver) resolveBatch(ctx context.Context, batch []*Secret) error {
	// One timeout context per batch, released as soon as the batch is done
	// instead of being deferred until every batch has been processed.
	batchCtx, cancel := context.WithTimeout(ctx, secretResolver.resolveTimeout)
	defer cancel()

	// Buffered with one slot per goroutine so that a resolver reporting an
	// error after this function has already returned is not blocked forever.
	errorChan := make(chan error, len(batch))

	secretChannels := make([]secretResolveChannel, 0, len(batch))
	for _, secret := range batch {
		if secret == nil {
			continue
		}
		if secret.ResolvedChan == nil {
			// Buffered for the same reason as errorChan: a late success
			// signal must not strand the sending goroutine.
			secret.ResolvedChan = make(chan bool, 1)
		}
		secretChannels = append(secretChannels, secretResolveChannel{
			resolvedChan: secret.ResolvedChan,
			timeoutChan:  batchCtx.Done,
		})
		go secretResolver.Resolve(batchCtx, secret, errorChan)
	}

	// Block until either:
	// - timeout in fetching any of the secrets.
	// - The global context expires
	// - Resolving a secret has error
	// - All secrets are resolved successfully
	for _, ch := range secretChannels {
		select {
		case <-ch.timeoutChan():
			// The batch context is also done when the parent ctx ends; report
			// that as the parent's error rather than as a timeout.
			if err := ctx.Err(); err != nil {
				return err
			}
			return errors.New("timeout in fetching secrets. please check permissions are valid")
		case <-ctx.Done():
			return ctx.Err()
		case <-ch.resolvedChan:
		case err := <-errorChan:
			return err
		}
	}
	return nil
}
// resolveSecret is the default ResolveSecretFunc.
//
// When secret.IsKeyVaultSecret() reports true the value is obtained through
// vaults.NewAKVSecretConfig and its GetValue; when secret.IsMsiSecret() reports
// true it is obtained through tokenutil.GetRegistryRefreshToken. On success the
// value is stored in secret.ResolvedValue and true is sent on
// secret.ResolvedChan. A nil secret, a secret of neither kind, or any error
// from the calls above is reported on errorChan instead.
//
// Every channel send is abandoned once ctx is done, so the goroutine running
// this function cannot stay blocked forever after its receiver has given up.
func resolveSecret(ctx context.Context, secret *Secret, errorChan chan error) {
	// fail reports err to the receiver unless ctx ends first.
	fail := func(err error) {
		select {
		case errorChan <- err:
		case <-ctx.Done():
		}
	}

	if secret == nil {
		fail(errors.New("secret cannot be nil"))
		return
	}

	// markResolved signals success to the receiver unless ctx ends first.
	markResolved := func() {
		select {
		case secret.ResolvedChan <- true:
		case <-ctx.Done():
		}
	}

	switch {
	case secret.IsKeyVaultSecret():
		secretConfig, err := vaults.NewAKVSecretConfig(secret.KeyVault, secret.MsiClientID)
		if err != nil {
			fail(err)
			return
		}
		secretValue, err := secretConfig.GetValue(ctx)
		if err != nil {
			fail(err)
			return
		}
		// The value is stored before signalling so the receiver can read it
		// as soon as it observes the signal.
		secret.ResolvedValue = secretValue
		markResolved()
	case secret.IsMsiSecret():
		secretValue, err := tokenutil.GetRegistryRefreshToken(secret.ID, secret.AadResourceID, secret.MsiClientID)
		if err != nil {
			fail(err)
			return
		}
		secret.ResolvedValue = secretValue
		markResolved()
	default:
		fail(fmt.Errorf("cannot resolve secret with ID: %s", secret.ID))
	}
}