vaults/azurekeyvault.go (100 lines of code) (raw):
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package vaults
import (
"context"
"fmt"
"net/url"
"strings"
"github.com/Azure/acr-builder/tokenutil"
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.0/keyvault"
"github.com/Azure/go-autorest/autorest"
"github.com/pkg/errors"
)
// AKVSecretConfig describes one secret stored in Azure Key Vault together with
// the identity settings used to read it (MSI, per GetValue).
type AKVSecretConfig struct {
	// Where the secret lives: the vault endpoint, the secret name and an
	// optional version (an empty version is passed through unchanged).
	VaultURL, SecretName, SecretVersion string

	// How to authenticate: an optional MSI client ID and the AAD resource the
	// token is requested for (e.g. https://vault.azure.net).
	MSIClientID, AADResourceURL string
}
// GetValue reads the secret value described by the config from Azure Key
// Vault, authenticating with a bearer token obtained for MSIClientID and
// AADResourceURL.
//
// VaultURL, SecretName and AADResourceURL are required. MSIClientID and
// SecretVersion may be empty. Errors from building the client or fetching the
// secret are returned unchanged.
func (secretConfig *AKVSecretConfig) GetValue(ctx context.Context) (string, error) {
	switch {
	case secretConfig == nil:
		return "", errors.New("secret config is required")
	case secretConfig.VaultURL == "",
		secretConfig.SecretName == "",
		secretConfig.AADResourceURL == "":
		return "", errors.New("missing required properties VaultURL, SecretName, and AADResourceURL")
	}

	vault, clientErr := newKeyVaultClient(secretConfig.VaultURL, secretConfig.MSIClientID, secretConfig.AADResourceURL)
	if clientErr != nil {
		return "", clientErr
	}

	value, fetchErr := vault.getSecret(ctx, secretConfig.SecretName, secretConfig.SecretVersion)
	if fetchErr != nil {
		return "", fetchErr
	}
	return value, nil
}
// NewAKVSecretConfig builds an AKVSecretConfig from a Key Vault secret URL of
// the form https://<vault>.<domain>/secrets/<name>[/<version>].
//
// The input is lower-cased and a single trailing "/" is removed before it is
// parsed. The returned config has:
//   - VaultURL:       scheme and host, e.g. https://myvault.vault.azure.net
//   - SecretName:     the <name> path segment (must be non-empty)
//   - SecretVersion:  the <version> path segment, or "" when absent
//   - MSIClientID:    msiClientID, unchanged
//   - AADResourceURL: scheme plus the host name with its first label (and
//     any port) removed, e.g. https://vault.azure.net
//
// An error is returned when vaultURL is empty, cannot be parsed, is not
// https, or does not match the shape above.
func NewAKVSecretConfig(vaultURL, msiClientID string) (*AKVSecretConfig, error) {
	if vaultURL == "" {
		return nil, errors.New("missing azure keyvault URL")
	}

	// NOTE(review): lower-casing also applies to the secret name and version;
	// this assumes Key Vault object names are case-insensitive.
	normalizedVaultURL := strings.TrimSuffix(strings.ToLower(vaultURL), "/")
	parsedURL, err := url.Parse(normalizedVaultURL)
	if err != nil {
		return nil, errors.Wrap(err, "failed to parse the azure keyvault secret URL")
	}
	if parsedURL.Scheme != "https" {
		return nil, errors.New("invalid azure keyvault secret URL scheme. Expected Https")
	}

	// The path is "/secrets/<name>" or "/secrets/<name>/<version>", so the
	// split yields a leading empty segment followed by 2 or 3 parts.
	urlSegments := strings.Split(parsedURL.Path, "/")
	if len(urlSegments) != 3 && len(urlSegments) != 4 {
		return nil, fmt.Errorf("invalid azure keyvault secret URL. Bad number of URL segments: %d", len(urlSegments))
	}
	if urlSegments[1] != "secrets" {
		return nil, fmt.Errorf("invalid azure keyvault secret URL. Expected 'secrets' collection, but found: %s", urlSegments[1])
	}
	secretName := urlSegments[2]
	if secretName == "" {
		// e.g. https://myvault.vault.azure.net/secrets//<version>
		return nil, errors.New("invalid azure keyvault secret URL. Missing secret name")
	}
	secretVersion := ""
	if len(urlSegments) == 4 {
		secretVersion = urlSegments[3]
	}

	vaultHostWithScheme := fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host)

	// The AAD resource is the vault host name minus its first label:
	// myvault.vault.azure.net -> vault.azure.net. Hostname() drops any port
	// so it does not leak into the token resource. Both the vault label and
	// the remaining domain must be non-empty.
	hostName := parsedURL.Hostname()
	dot := strings.Index(hostName, ".")
	if dot <= 0 || dot == len(hostName)-1 {
		return nil, fmt.Errorf("extracted vault resource %s from vault URL %s is invalid", vaultHostWithScheme, vaultURL)
	}
	vaultAADResourceURL := fmt.Sprintf("%s://%s", parsedURL.Scheme, hostName[dot+1:])

	return &AKVSecretConfig{
		VaultURL:       vaultHostWithScheme,
		SecretName:     secretName,
		SecretVersion:  secretVersion,
		MSIClientID:    msiClientID,
		AADResourceURL: vaultAADResourceURL,
	}, nil
}
// keyVault bundles an authorized Key Vault SDK client with the vault endpoint
// that every request made through it targets.
type keyVault struct {
	client   *keyvault.BaseClient // SDK client whose Authorizer is set by newKeyVaultClient
	vaultURL string               // vault base URL passed to each GetSecret call
}
// newKeyVaultClient returns a keyVault for vaultURL whose SDK client sends a
// bearer token obtained from tokenutil for vaultAADResourceURL and clientID.
// Token acquisition errors are returned unchanged.
func newKeyVaultClient(vaultURL, clientID, vaultAADResourceURL string) (*keyVault, error) {
	token, tokenErr := tokenutil.GetServicePrincipalToken(vaultAADResourceURL, clientID)
	if tokenErr != nil {
		return nil, tokenErr
	}

	sdkClient := keyvault.New()
	sdkClient.Authorizer = autorest.NewBearerAuthorizer(token)

	return &keyVault{client: &sdkClient, vaultURL: vaultURL}, nil
}
// getSecret reads the named secret from k.vaultURL and returns its value.
// An empty secretVersion is passed to the SDK unchanged. SDK errors are
// returned as-is so callers can still inspect their concrete type.
func (k *keyVault) getSecret(ctx context.Context, secretName, secretVersion string) (string, error) {
	secretBundle, err := k.client.GetSecret(ctx, k.vaultURL, secretName, secretVersion)
	if err != nil {
		return "", err
	}
	// SecretBundle.Value is a pointer; report a missing value as an error
	// instead of panicking on a nil dereference.
	if secretBundle.Value == nil {
		return "", fmt.Errorf("secret %s in vault %s has no value", secretName, k.vaultURL)
	}
	return *secretBundle.Value, nil
}