sdk/storage/azblob/internal/testcommon/clients_auth.go (331 lines of code) (raw):
//go:build go1.18
// +build go1.18
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.
// Contains common helpers for TESTS ONLY
package testcommon
import (
"context"
"errors"
"fmt"
"strings"
"testing"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/internal/recording"
"github.com/Azure/azure-sdk-for-go/sdk/internal/test/credential"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/appendblob"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/sas"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/service"
"github.com/stretchr/testify/require"
)
// TestAccountType is a prefix that is prepended to the account-name and
// account-key environment variable names (AccountNameEnvVar and
// AccountKeyEnvVar) to select which test storage account's settings are read.
type TestAccountType string
// Prefix values accepted by the helpers in this file. Each one yields variable
// names of the form <prefix>AZURE_STORAGE_ACCOUNT_NAME and
// <prefix>AZURE_STORAGE_ACCOUNT_KEY.
const (
// TestAccountDefault has an empty prefix, so the unprefixed variables are used.
TestAccountDefault TestAccountType = ""
// TestAccountSecondary selects the SECONDARY_-prefixed variables.
TestAccountSecondary TestAccountType = "SECONDARY_"
// TestAccountPremium selects the PREMIUM_-prefixed variables.
TestAccountPremium TestAccountType = "PREMIUM_"
// TestAccountSoftDelete selects the SOFT_DELETE_-prefixed variables.
TestAccountSoftDelete TestAccountType = "SOFT_DELETE_"
// TestAccountDatalake selects the DATALAKE_-prefixed variables.
TestAccountDatalake TestAccountType = "DATALAKE_"
// TestAccountImmutable selects the IMMUTABLE_-prefixed variables.
TestAccountImmutable TestAccountType = "IMMUTABLE_"
)
// Endpoint suffixes and the names of environment variables read by the test
// helpers.
const (
// DefaultEndpointSuffix is a storage endpoint host suffix; note the trailing slash.
DefaultEndpointSuffix = "core.windows.net/"
// DefaultBlobEndpointSuffix is the blob endpoint host suffix; note the trailing slash.
DefaultBlobEndpointSuffix = "blob.core.windows.net/"
// AccountNameEnvVar is the base name of the variable holding the storage
// account name; a TestAccountType prefix is prepended before lookup.
AccountNameEnvVar = "AZURE_STORAGE_ACCOUNT_NAME"
// AccountKeyEnvVar is the base name of the variable holding the storage
// account key; a TestAccountType prefix is prepended before lookup.
AccountKeyEnvVar = "AZURE_STORAGE_ACCOUNT_KEY"
// DefaultEndpointSuffixEnvVar is not referenced in this part of the file.
// NOTE(review): presumably names the variable that overrides the endpoint
// suffix — confirm against its users.
DefaultEndpointSuffixEnvVar = "AZURE_STORAGE_ENDPOINT_SUFFIX"
// SubscriptionID is the NAME of the environment variable holding the
// subscription ID (not the ID itself); read by DeleteContainerUsingManagementClient.
SubscriptionID = "SUBSCRIPTION_ID"
// ResourceGroupName is the NAME of the environment variable holding the
// resource group name; read by DeleteContainerUsingManagementClient.
ResourceGroupName = "RESOURCE_GROUP_NAME"
)
// Placeholder values used in place of real account details.
const (
// FakeStorageAccount is the account name returned by GetGenericAccountInfo
// when the recording framework is in playback mode.
FakeStorageAccount = "fakestorage"
// FakeStorageURL is the blob endpoint URL matching FakeStorageAccount.
// NOTE(review): not referenced in this part of the file — confirm its users.
FakeStorageURL = "https://fakestorage.blob.core.windows.net"
// FakeToken is a placeholder token value.
// NOTE(review): not referenced in this part of the file — confirm its users.
FakeToken = "faketoken"
)
// Sample HTTP header values for blobs. They are variables rather than
// constants so that BasicHeaders can take their addresses.
var (
BlobContentType = "my_type"
BlobContentDisposition = "my_disposition"
BlobCacheControl = "control"
BlobContentLanguage = "my_language"
BlobContentEncoding = "my_encoding"
)
// BasicHeaders is a ready-made blob.HTTPHeaders value whose fields point at the
// sample header variables above. BlobContentMD5 is deliberately left nil.
// Because the fields are pointers to package-level variables, changing one of
// those variables also changes what BasicHeaders refers to.
var BasicHeaders = blob.HTTPHeaders{
BlobContentType: &BlobContentType,
BlobContentDisposition: &BlobContentDisposition,
BlobCacheControl: &BlobCacheControl,
BlobContentMD5: nil,
BlobContentLanguage: &BlobContentLanguage,
BlobContentEncoding: &BlobContentEncoding,
}
// BasicMetadata is a single-entry metadata map ("Foo" -> "bar") with pointer
// values.
var BasicMetadata = map[string]*string{"Foo": to.Ptr("bar")}
// BasicBlobTagsMap is a small set of plain alphanumeric blob tags.
var BasicBlobTagsMap = map[string]string{
"azure": "blob",
"blob": "sdk",
"sdk": "go",
}
// SpecialCharBlobTagsMap is a set of blob tags whose keys and values contain
// spaces and the punctuation characters + - . / : = _ (including keys with a
// trailing space).
var SpecialCharBlobTagsMap = map[string]string{
"+-./:=_ ": "firsttag",
"tag2": "+-./:=_",
"+-./:=_1": "+-./:=_",
"Microsoft Azure": "Azure Storage",
"Storage+SDK": "SDK/GO",
"GO ": ".Net",
"go": "written in golang",
}
// SetClientOptions prepares opts for use with the test recording framework.
//
// It appends the request-mismatch headers to the list of headers allowed in
// logs and then replaces opts.Transport with a recording HTTP client bound to
// t. The test is failed immediately if the recording client cannot be created.
// opts is modified in place and must not be nil.
func SetClientOptions(t *testing.T, opts *azcore.ClientOptions) {
	mismatchHeaders := []string{"X-Request-Mismatch", "X-Request-Mismatch-Error"}
	opts.Logging.AllowedHeaders = append(opts.Logging.AllowedHeaders, mismatchHeaders...)

	recordingTransport, err := recording.NewRecordingHTTPClient(t, nil)
	require.NoError(t, err)
	opts.Transport = recordingTransport
}
// GetClient builds an azblob.Client for the storage account selected by
// accountType, authenticated with that account's shared key.
//
// A nil options is replaced by a zero-value options struct. The options
// (caller-supplied or fresh) are passed through SetClientOptions, so a
// caller-supplied value is modified in place. An error is returned when the
// shared key credential cannot be obtained or the client cannot be constructed.
func GetClient(t *testing.T, accountType TestAccountType, options *azblob.ClientOptions) (*azblob.Client, error) {
	if options == nil {
		options = new(azblob.ClientOptions)
	}
	SetClientOptions(t, &options.ClientOptions)

	sharedKey, err := GetGenericSharedKeyCredential(accountType)
	if err != nil {
		return nil, err
	}

	serviceURL := fmt.Sprintf("https://%s.blob.core.windows.net/", sharedKey.AccountName())
	return azblob.NewClientWithSharedKeyCredential(serviceURL, sharedKey, options)
}
// GetServiceClient builds a service.Client for the storage account selected by
// accountType, authenticated with that account's shared key.
//
// A nil options is replaced by a zero-value options struct. The options are
// passed through SetClientOptions, so a caller-supplied value is modified in
// place. An error is returned when the shared key credential cannot be
// obtained or the client cannot be constructed.
func GetServiceClient(t *testing.T, accountType TestAccountType, options *service.ClientOptions) (*service.Client, error) {
	if options == nil {
		options = new(service.ClientOptions)
	}
	SetClientOptions(t, &options.ClientOptions)

	sharedKey, err := GetGenericSharedKeyCredential(accountType)
	if err != nil {
		return nil, err
	}

	accountName := sharedKey.AccountName()
	endpoint := "https://" + accountName + ".blob.core.windows.net/"
	return service.NewClientWithSharedKeyCredential(endpoint, sharedKey, options)
}
// GetServiceClientNoCredential builds a service.Client for sasUrl without
// attaching any credential to the client.
//
// A nil options is replaced by a zero-value options struct. The options are
// passed through SetClientOptions, so a caller-supplied value is modified in
// place.
func GetServiceClientNoCredential(t *testing.T, sasUrl string, options *service.ClientOptions) (*service.Client, error) {
	if options == nil {
		options = new(service.ClientOptions)
	}
	SetClientOptions(t, &options.ClientOptions)

	return service.NewClientWithNoCredential(sasUrl, options)
}
// GetGenericTokenCredential returns the token credential produced by the
// shared test credential package with default (nil) options.
func GetGenericTokenCredential() (azcore.TokenCredential, error) {
	tokenCred, err := credential.New(nil)
	return tokenCred, err
}
// GetGenericAccountInfo returns the storage account name and key for the given
// account type.
//
// In playback mode no environment is consulted: the fake account name and a
// placeholder base64 key are returned. Otherwise the values are read from the
// environment variables named <accountType>AZURE_STORAGE_ACCOUNT_NAME and
// <accountType>AZURE_STORAGE_ACCOUNT_KEY.
func GetGenericAccountInfo(accountType TestAccountType) (string, string) {
	if recording.GetRecordMode() == recording.PlaybackMode {
		return FakeStorageAccount, "ZmFrZQ=="
	}

	prefix := string(accountType)

	// Lookup errors are intentionally discarded here; callers such as
	// GetGenericSharedKeyCredential treat an empty name or key as missing
	// configuration and report it themselves.
	name, _ := GetRequiredEnv(prefix + AccountNameEnvVar)
	key, _ := GetRequiredEnv(prefix + AccountKeyEnvVar)
	return name, key
}
// GetGenericSharedKeyCredential creates a shared key credential from the
// account name and key of the given account type.
//
// It returns an error naming the two expected environment variables when
// either value is empty.
func GetGenericSharedKeyCredential(accountType TestAccountType) (*azblob.SharedKeyCredential, error) {
	name, key := GetGenericAccountInfo(accountType)
	if name != "" && key != "" {
		return azblob.NewSharedKeyCredential(name, key)
	}

	prefix := string(accountType)
	return nil, errors.New(prefix + AccountNameEnvVar + " and/or " + prefix + AccountKeyEnvVar + " environment variables not specified.")
}
// GetGenericConnectionString builds an HTTPS storage connection string from
// the account name and key of the given account type and returns a pointer to
// it.
//
// It returns an error naming the two expected environment variables when
// either value is empty.
func GetGenericConnectionString(accountType TestAccountType) (*string, error) {
	name, key := GetGenericAccountInfo(accountType)
	if name != "" && key != "" {
		connStr := fmt.Sprintf("DefaultEndpointsProtocol=https;AccountName=%s;AccountKey=%s;EndpointSuffix=core.windows.net/",
			name, key)
		return &connStr, nil
	}

	prefix := string(accountType)
	return nil, errors.New(prefix + AccountNameEnvVar + " and/or " + prefix + AccountKeyEnvVar + " environment variables not specified.")
}
// GetServiceClientFromConnectionString builds a service.Client for the storage
// account selected by accountType using a connection string assembled from
// that account's name and key.
//
// A nil options is replaced by a zero-value options struct. The options are
// passed through SetClientOptions, so a caller-supplied value is modified in
// place. An error is returned when the connection string cannot be built or
// the client cannot be constructed.
func GetServiceClientFromConnectionString(t *testing.T, accountType TestAccountType, options *service.ClientOptions) (*service.Client, error) {
	if options == nil {
		options = &service.ClientOptions{}
	}

	// SetClientOptions already creates the recording HTTP client and assigns
	// it to options.Transport, so no second transport is created here.
	SetClientOptions(t, &options.ClientOptions)

	connectionString, err := GetGenericConnectionString(accountType)
	if err != nil {
		return nil, err
	}

	return service.NewClientFromConnectionString(*connectionString, options)
}
// GetContainerClient returns a container client for containerName derived from
// the service client s. It only constructs the client object.
func GetContainerClient(containerName string, s *service.Client) *container.Client {
	containerClient := s.NewContainerClient(containerName)
	return containerClient
}
// CreateNewContainer creates the container named containerName through
// serviceClient with default options and returns its client. A creation error
// is reported through _require.
func CreateNewContainer(ctx context.Context, _require *require.Assertions, containerName string, serviceClient *service.Client) *container.Client {
	cc := serviceClient.NewContainerClient(containerName)

	_, createErr := cc.Create(ctx, nil)
	_require.NoError(createErr)

	return cc
}
// DeleteContainer deletes the container behind containerClient with default
// options. A deletion error is reported through _require.
func DeleteContainer(ctx context.Context, _require *require.Assertions, containerClient *container.Client) {
	_, deleteErr := containerClient.Delete(ctx, nil)
	_require.NoError(deleteErr)
}
// GetBlobClient returns a generic blob client for blockBlobName inside the
// container behind containerClient. It only constructs the client object.
func GetBlobClient(blockBlobName string, containerClient *container.Client) *blob.Client {
	blobClient := containerClient.NewBlobClient(blockBlobName)
	return blobClient
}
// CreateNewBlobs creates one block blob per entry of blobNames, in order, in
// the container behind containerClient. Each blob is created with
// CreateNewBlockBlob, which reports failures through _require.
func CreateNewBlobs(ctx context.Context, _require *require.Assertions, blobNames []string, containerClient *container.Client) {
	for i := range blobNames {
		CreateNewBlockBlob(ctx, _require, blobNames[i], containerClient)
	}
}
// CreateNewBlobsListTier creates one block blob per entry of blobNames, in
// order, and sets the access tier of each one to *tier.
//
// tier must be non-nil whenever blobNames is non-empty; a nil tier is reported
// as an assertion failure through _require instead of a nil-pointer panic.
// Upload and SetTier errors are likewise reported through _require.
func CreateNewBlobsListTier(ctx context.Context, _require *require.Assertions, blobNames []string, containerClient *container.Client, tier *blob.AccessTier) {
	// Nothing to create, so the tier is never needed.
	if len(blobNames) == 0 {
		return
	}

	_require.NotNil(tier, "tier must not be nil")
	accessTier := *tier

	for _, blobName := range blobNames {
		bbClient := CreateNewBlockBlob(ctx, _require, blobName, containerClient)
		_, err := bbClient.SetTier(ctx, accessTier, nil)
		_require.NoError(err)
	}
}
// GetBlockBlobClient returns a block blob client for blockBlobName inside the
// container behind containerClient. It only constructs the client object.
func GetBlockBlobClient(blockBlobName string, containerClient *container.Client) *blockblob.Client {
	blockBlobClient := containerClient.NewBlockBlobClient(blockBlobName)
	return blockBlobClient
}
// CreateNewBlockBlob uploads BlockBlobDefaultData as a block blob named
// blockBlobName into the container behind containerClient, using default
// upload options, and returns the blob's client. An upload error is reported
// through _require.
func CreateNewBlockBlob(ctx context.Context, _require *require.Assertions, blockBlobName string, containerClient *container.Client) *blockblob.Client {
	blockBlobClient := containerClient.NewBlockBlobClient(blockBlobName)

	payload := streaming.NopCloser(strings.NewReader(BlockBlobDefaultData))
	_, uploadErr := blockBlobClient.Upload(ctx, payload, nil)
	_require.NoError(uploadErr)

	return blockBlobClient
}
// CreateNewBlockBlobWithCPK uploads BlockBlobDefaultData as a block blob named
// blockBlobName, passing cpkInfo and cpkScopeInfo (either may be nil) as
// upload options, and returns the blob's client.
//
// It asserts through _require that the upload succeeded and that the response
// reports server-side encryption. Outside playback mode it additionally
// asserts that the response echoes the key SHA-256 from cpkInfo and the
// encryption scope from cpkScopeInfo, for whichever of the two is non-nil.
func CreateNewBlockBlobWithCPK(ctx context.Context, _require *require.Assertions, blockBlobName string, containerClient *container.Client, cpkInfo *blob.CPKInfo, cpkScopeInfo *blob.CPKScopeInfo) (bbClient *blockblob.Client) {
	bbClient = GetBlockBlobClient(blockBlobName, containerClient)

	uploadOptions := &blockblob.UploadOptions{
		CPKInfo:      cpkInfo,
		CPKScopeInfo: cpkScopeInfo,
	}
	resp, err := bbClient.Upload(ctx, streaming.NopCloser(strings.NewReader(BlockBlobDefaultData)), uploadOptions)
	_require.NoError(err)

	// Guard the dereference so a missing field is an assertion failure, not a panic.
	_require.NotNil(resp.IsServerEncrypted)
	_require.True(*resp.IsServerEncrypted)

	// The echo checks below are skipped in playback mode.
	// NOTE(review): presumably because recorded responses do not carry the
	// live key hash / scope — confirm.
	if recording.GetRecordMode() == recording.PlaybackMode {
		return bbClient
	}

	// testify expects (expected, actual); the request values are the expectation.
	if cpkInfo != nil {
		_require.EqualValues(cpkInfo.EncryptionKeySHA256, resp.EncryptionKeySHA256)
	}
	if cpkScopeInfo != nil {
		_require.EqualValues(cpkScopeInfo.EncryptionScope, resp.EncryptionScope)
	}
	return bbClient
}
// GetAppendBlobClient returns an append blob client for appendBlobName inside
// the container behind containerClient. It only constructs the client object.
func GetAppendBlobClient(appendBlobName string, containerClient *container.Client) *appendblob.Client {
	appendBlobClient := containerClient.NewAppendBlobClient(appendBlobName)
	return appendBlobClient
}
// CreateNewAppendBlob creates an append blob named appendBlobName in the
// container behind containerClient with default options and returns its
// client. A creation error is reported through _require.
func CreateNewAppendBlob(ctx context.Context, _require *require.Assertions, appendBlobName string, containerClient *container.Client) *appendblob.Client {
	appendBlobClient := containerClient.NewAppendBlobClient(appendBlobName)

	_, createErr := appendBlobClient.Create(ctx, nil)
	_require.NoError(createErr)

	return appendBlobClient
}
// RunTestRequiringServiceProperties runs testImplFunc with a service property
// temporarily enabled.
//
// Some tests require setting service properties, and it can take up to 30
// seconds for new properties to be reflected across all front ends. This
// helper therefore:
//
//  1. calls enableServicePropertyFunc and defers disableServicePropertyFunc;
//  2. runs testImplFunc;
//  3. if the returned error's message equals code (the error expected while
//     the property change has not yet propagated), sleeps 30 seconds and runs
//     testImplFunc once more;
//  4. asserts through _require that the final error is nil.
//
// It is the responsibility of testImplFunc to decide which error string
// indicates that a retry is warranted; there can be only one such string.
// Errors that cannot be caused by slow propagation should be asserted inside
// testImplFunc rather than returned. Any error that is returned and does not
// match code fails the test instead of being silently dropped.
func RunTestRequiringServiceProperties(ctx context.Context, _require *require.Assertions, svcClient *service.Client, code string,
	enableServicePropertyFunc func(context.Context, *require.Assertions, *service.Client),
	testImplFunc func(context.Context, *require.Assertions, *service.Client) error,
	disableServicePropertyFunc func(context.Context, *require.Assertions, *service.Client)) {

	enableServicePropertyFunc(ctx, _require, svcClient)
	defer disableServicePropertyFunc(ctx, _require, svcClient)

	err := testImplFunc(ctx, _require, svcClient)

	// The comparison is on the error text because the error indicating a slow
	// update is not necessarily a StorageError (e.g. for ListBlobs).
	if err != nil && err.Error() == code {
		time.Sleep(30 * time.Second)
		err = testImplFunc(ctx, _require, svcClient)
	}

	// Covers both the retried attempt and a first attempt that failed with an
	// error other than code.
	_require.NoError(err)
}
// EnableSoftDelete sets the service's delete retention policy to enabled with
// a retention period of one day. A failure is reported through _require.
func EnableSoftDelete(ctx context.Context, _require *require.Assertions, client *service.Client) {
	retention := &service.RetentionPolicy{
		Enabled: to.Ptr(true),
		Days:    to.Ptr(int32(1)),
	}
	props := &service.SetPropertiesOptions{DeleteRetentionPolicy: retention}

	_, setErr := client.SetProperties(ctx, props)
	_require.NoError(setErr)
}
// DisableSoftDelete sets the service's delete retention policy to disabled.
// A failure is reported through _require.
func DisableSoftDelete(ctx context.Context, _require *require.Assertions, client *service.Client) {
	retention := &service.RetentionPolicy{Enabled: to.Ptr(false)}
	props := &service.SetPropertiesOptions{DeleteRetentionPolicy: retention}

	_, setErr := client.SetProperties(ctx, props)
	_require.NoError(setErr)
}
// ListBlobsCount drains listPager, collecting the blob items of every page,
// and asserts through _require that exactly ctr items were listed in total.
// A paging error is reported through _require and stops the iteration.
func ListBlobsCount(ctx context.Context, _require *require.Assertions, listPager *runtime.Pager[container.ListBlobsFlatResponse], ctr int) {
	collected := []*container.BlobItem{}

	for listPager.More() {
		page, pageErr := listPager.NextPage(ctx)
		_require.NoError(pageErr)
		if pageErr != nil {
			// Stop paging if the assertion above did not already end the test.
			break
		}
		collected = append(collected, page.Segment.BlobItems...)
	}

	_require.Len(collected, ctr)
}
// GetServiceSAS returns the encoded query string of an HTTPS-only service SAS
// for containerName with the given permissions, signed with the default test
// account's shared key. The SAS starts at the current UTC time and expires two
// hours later.
func GetServiceSAS(containerName string, permissions sas.BlobPermissions) (string, error) {
	sharedKey, err := GetGenericSharedKeyCredential(TestAccountDefault)
	if err != nil {
		return "", err
	}

	signatureValues := sas.BlobSignatureValues{
		Protocol:      sas.ProtocolHTTPS,
		StartTime:     time.Now().UTC(),
		ExpiryTime:    time.Now().UTC().Add(2 * time.Hour),
		Permissions:   permissions.String(),
		ContainerName: containerName,
	}

	queryParams, err := signatureValues.SignWithSharedKey(sharedKey)
	if err != nil {
		return "", err
	}
	return queryParams.Encode(), nil
}
// GetUserDelegationSAS returns the encoded query string of an HTTPS-only SAS
// for containerName with the given permissions, signed with a user delegation
// credential obtained from svcClient.
//
// The delegation key is requested for a window starting 10 seconds in the past
// and lasting two hours. The SAS itself starts 10 seconds in the past and
// expires 15 minutes from now. The key request uses context.Background().
func GetUserDelegationSAS(svcClient *service.Client, containerName string, permissions sas.BlobPermissions) (string, error) {
	// Validity window of the user delegation key.
	keyStart := time.Now().UTC().Add(-10 * time.Second)
	keyExpiry := keyStart.Add(2 * time.Hour)

	keyStartText := keyStart.UTC().Format(sas.TimeFormat)
	keyExpiryText := keyExpiry.UTC().Format(sas.TimeFormat)
	keyInfo := service.KeyInfo{
		Start:  to.Ptr(keyStartText),
		Expiry: to.Ptr(keyExpiryText),
	}

	delegationCred, err := svcClient.GetUserDelegationCredential(context.Background(), keyInfo, nil)
	if err != nil {
		return "", err
	}

	// Describe the SAS and sign it with the user delegation credential.
	signatureValues := sas.BlobSignatureValues{
		Protocol:      sas.ProtocolHTTPS,
		StartTime:     time.Now().UTC().Add(time.Second * -10),
		ExpiryTime:    time.Now().UTC().Add(15 * time.Minute),
		Permissions:   permissions.String(),
		ContainerName: containerName,
	}

	queryParams, err := signatureValues.SignWithUserDelegation(delegationCred)
	if err != nil {
		return "", err
	}
	return queryParams.Encode(), nil
}
// GetAccountSAS returns the encoded query string of an HTTPS-only account SAS
// with the given permissions and resource types, signed with the default test
// account's shared key. The SAS starts at the current UTC time and expires one
// hour later.
func GetAccountSAS(permissions sas.AccountPermissions, resourceTypes sas.AccountResourceTypes) (string, error) {
	sharedKey, err := GetGenericSharedKeyCredential(TestAccountDefault)
	if err != nil {
		return "", err
	}

	signatureValues := sas.AccountSignatureValues{
		Protocol:      sas.ProtocolHTTPS,
		StartTime:     time.Now().UTC(),
		ExpiryTime:    time.Now().UTC().Add(1 * time.Hour),
		Permissions:   permissions.String(),
		ResourceTypes: resourceTypes.String(),
	}

	queryParams, err := signatureValues.SignWithSharedKey(sharedKey)
	if err != nil {
		return "", err
	}
	return queryParams.Encode(), nil
}
// DeleteContainerUsingManagementClient deletes containerName from the storage
// account selected by accountType through the resource-manager blob containers
// client.
//
// It does nothing in playback mode. Otherwise it reads the account name,
// subscription ID and resource group name from the environment, creates a test
// token credential, and issues the delete with context.Background(). Every
// failure along the way is reported through _require.
func DeleteContainerUsingManagementClient(_require *require.Assertions, accountType TestAccountType, containerName string) {
	if recording.GetRecordMode() == recording.PlaybackMode {
		return
	}

	// mustEnv reads a required environment variable and asserts it was found.
	mustEnv := func(name string) string {
		value, lookupErr := GetRequiredEnv(name)
		_require.NoError(lookupErr)
		return value
	}

	accountName := mustEnv(string(accountType) + AccountNameEnvVar)
	subscriptionID := mustEnv(SubscriptionID)
	resourceGroupName := mustEnv(ResourceGroupName)

	tokenCred, err := credential.New(nil)
	_require.NoError(err)

	containersClient, err := armstorage.NewBlobContainersClient(subscriptionID, tokenCred, nil)
	_require.NoError(err)

	_, err = containersClient.Delete(context.Background(), resourceGroupName, accountName, containerName, nil)
	_require.NoError(err)
}