sdk/data/azcosmos/shared_key_credential.go (97 lines of code) (raw):
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package azcosmos
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"fmt"
"net/http"
"net/url"
"strings"
"sync/atomic"
"time"
azlog "github.com/Azure/azure-sdk-for-go/sdk/azcore/log"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
)
// NewKeyCredential creates a KeyCredential from the account's primary or
// secondary key.
//
// accountKey must be the base64-encoded key. When it cannot be decoded, the
// returned error is non-nil and the returned credential has no key stored.
func NewKeyCredential(accountKey string) (KeyCredential, error) {
	var cred KeyCredential
	err := cred.Update(accountKey)
	return cred, err
}
// KeyCredential holds an account's primary or secondary key, which is used to
// compute the HMAC-SHA256 signature placed on outgoing requests.
//
// The decoded key lives in an atomic.Value so that Update can replace it while
// other goroutines load it for signing. The credential is therefore not
// immutable: Update swaps the key in place.
//
// A KeyCredential on which no key has been stored (for example the zero value)
// must not be used for signing: computeHMACSHA256 type-asserts the loaded value
// to []byte without checking that anything was stored.
type KeyCredential struct {
// Only Update should store to this field; all other methods should treat it as read-only.
accountKey atomic.Value // []byte
}
// Update replaces the existing account key with the specified account key.
//
// accountKey is expected in standard base64 encoding. If decoding fails, the
// previously stored key (if any) is left untouched and a wrapped error is
// returned.
func (c *KeyCredential) Update(accountKey string) error {
	key, decodeErr := base64.StdEncoding.DecodeString(accountKey)
	if decodeErr != nil {
		return fmt.Errorf("decode account key: %w", decodeErr)
	}

	// Store the raw key bytes; signing loads them back as []byte.
	c.accountKey.Store(key)
	return nil
}
// computeHMACSHA256 signs s with the stored account key using HMAC-SHA256 and
// returns the digest in standard base64 encoding.
//
// It panics if no key has been stored on the credential, because the loaded
// value is asserted to be a []byte.
func (c *KeyCredential) computeHMACSHA256(s string) (base64String string) {
	key := c.accountKey.Load().([]byte)
	mac := hmac.New(sha256.New, key)

	// hash.Hash documents that Write never returns an error, so it is safe to discard.
	_, _ = mac.Write([]byte(s))

	digest := mac.Sum(nil)
	return base64.StdEncoding.EncodeToString(digest)
}
// buildCanonicalizedAuthHeaderFromRequest derives the authorization header
// value for req from the pipelineRequestOptions attached to it.
//
// It returns an empty string and a nil error when the request carries no
// pipelineRequestOptions, and a non-nil error when the resource type cannot be
// mapped to a resource path. The header is always built with token type
// "master" and version "1.0", using the request's method and its x-ms-date
// header.
func (c *KeyCredential) buildCanonicalizedAuthHeaderFromRequest(req *policy.Request) (string, error) {
	var opts pipelineRequestOptions
	if !req.OperationValue(&opts) {
		// Nothing to sign against: no operation options were attached.
		return "", nil
	}

	typePath, err := getResourcePath(opts.resourceType)
	if err != nil {
		return "", err
	}

	// RID-based addresses are signed in lower case.
	address := opts.resourceAddress
	if opts.isRidBased {
		address = strings.ToLower(address)
	}

	raw := req.Raw()
	header := c.buildCanonicalizedAuthHeader(
		opts.resourceType == resourceTypeDatabaseAccount,
		raw.Method,
		typePath,
		address,
		raw.Header.Get(headerXmsDate),
		"master",
		"1.0",
	)
	return header, nil
}
// buildCanonicalizedAuthHeader computes the URL-encoded authorization header
// value ("type=<tokenType>&ver=<version>&sig=<signature>") for a request.
//
// Parameters:
//   - isDatabaseAccount: true when the request targets the database account,
//     in which case an empty resourceTypePath is allowed.
//   - method: the HTTP verb; lower-cased before signing.
//   - resourceTypePath: the resource type segment; lower-cased before signing.
//   - resourceAddress: the resource link; percent-escapes are decoded before
//     signing and the case is preserved.
//   - xmsDate: the value of the x-ms-date header, where the date is like
//     time.RFC1123 but hard-codes GMT as the time zone (http.TimeFormat);
//     lower-cased before signing.
//   - tokenType, version: copied verbatim into the header value.
//
// It returns an empty string when method is empty, or when resourceTypePath is
// empty and the request is not a database account request.
func (c *KeyCredential) buildCanonicalizedAuthHeader(isDatabaseAccount bool, method, resourceTypePath, resourceAddress, xmsDate, tokenType, version string) string {
	if method == "" || (resourceTypePath == "" && !isDatabaseAccount) {
		return ""
	}

	// url.PathUnescape returns "" alongside its error on a malformed escape
	// sequence. Discarding that error would silently sign an empty resource
	// address, so keep the address as given when it cannot be unescaped.
	if unescaped, err := url.PathUnescape(resourceAddress); err == nil {
		resourceAddress = unescaped
	}

	// https://docs.microsoft.com/rest/api/cosmos-db/access-control-on-cosmosdb-resources#constructkeytoken
	// Payload layout: verb \n resourceType \n resourceLink \n date \n "" \n
	stringToSign := join(strings.ToLower(method), "\n", strings.ToLower(resourceTypePath), "\n", resourceAddress, "\n", strings.ToLower(xmsDate), "\n", "", "\n")
	signature := c.computeHMACSHA256(stringToSign)

	return url.QueryEscape("type=" + tokenType + "&ver=" + version + "&sig=" + signature)
}
// sharedKeyCredPolicy is a pipeline policy that adds an x-ms-date header (when
// missing) and an authorization header computed from cred to each request
// before passing it to the next policy.
type sharedKeyCredPolicy struct {
// cred is held by value, i.e. as a copy of the KeyCredential passed to
// newSharedKeyCredPolicy.
// NOTE(review): because this is a copy, a later Update on the caller's
// KeyCredential presumably does not reach this policy — confirm that is intended.
cred KeyCredential
}
// newSharedKeyCredPolicy returns a policy that signs requests with cred.
// The credential is copied into the policy.
func newSharedKeyCredPolicy(cred KeyCredential) *sharedKeyCredPolicy {
	return &sharedKeyCredPolicy{cred: cred}
}
// Do stamps the request with an x-ms-date header (unless one is already set),
// computes the authorization header from the policy's credential, and forwards
// the request to the next policy in the pipeline.
//
// It returns a nil response and the error when the authorization header cannot
// be built; otherwise it returns whatever the next policy returns. When the
// computed header is empty, no authorization header is set.
func (s *sharedKeyCredPolicy) Do(req *policy.Request) (*http.Response, error) {
	// The date takes part in the signature, so it must be present before signing.
	headers := req.Raw().Header
	if headers.Get(headerXmsDate) == "" {
		headers.Set(headerXmsDate, time.Now().UTC().Format(http.TimeFormat))
	}

	authHeader, err := s.cred.buildCanonicalizedAuthHeaderFromRequest(req)
	if err != nil {
		return nil, err
	}
	if authHeader != "" {
		headers.Set(headerAuthorization, authHeader)
	}

	resp, nextErr := req.Next()

	// Service failed to authenticate request, log it.
	// NOTE(review): this only fires when Next returns BOTH an error and a
	// response with status 403; a 403 delivered with a nil error is not logged
	// — confirm that is the intended condition.
	forbidden := nextErr != nil && resp != nil && resp.StatusCode == http.StatusForbidden
	if forbidden {
		log.Write(azlog.EventResponse, "===== HTTP Forbidden status, Authorization:\n"+authHeader+"\n=====\n")
	}
	return resp, nextErr
}
// join concatenates strs in order, with no separator, and returns the result.
// It returns the empty string when called with no arguments.
func join(strs ...string) string {
	return strings.Join(strs, "")
}