sdk/data/azcosmos/cosmos_client.go (510 lines of code) (raw):
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package azcosmos
import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
azlog "github.com/Azure/azure-sdk-for-go/sdk/azcore/log"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing"
"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
)
const (
apiVersion = "2020-11-05"
)
// Client is used to interact with the Azure Cosmos DB database service.
type Client struct {
endpoint string
internal *azcore.Client
gem *globalEndpointManager
endpointUrl *url.URL
}
// Endpoint used to create the client.
func (c *Client) Endpoint() string {
return c.endpoint
}
// NewClientWithKey creates a new instance of Cosmos client with shared key authentication. It uses the default pipeline configuration.
// endpoint - The cosmos service endpoint to use.
// cred - The credential used to authenticate with the cosmos service.
// options - Optional Cosmos client options. Pass nil to accept default values.
func NewClientWithKey(endpoint string, cred KeyCredential, o *ClientOptions) (*Client, error) {
endpointUrl, err := url.Parse(endpoint)
if err != nil {
return nil, err
}
preferredRegions := []string{}
enableCrossRegionRetries := true
if o != nil {
preferredRegions = o.PreferredRegions
}
gem, err := newGlobalEndpointManager(endpoint, newInternalPipeline(newSharedKeyCredPolicy(cred), o), preferredRegions, 0, enableCrossRegionRetries)
if err != nil {
return nil, err
}
internalClient, err := newClient(newSharedKeyCredPolicy(cred), gem, o)
if err != nil {
return nil, err
}
return &Client{endpoint: endpoint, endpointUrl: endpointUrl, internal: internalClient, gem: gem}, nil
}
// NewClient creates a new instance of Cosmos client with Azure AD access token authentication. It uses the default pipeline configuration.
// endpoint - The cosmos service endpoint to use.
// cred - The credential used to authenticate with the cosmos service.
// options - Optional Cosmos client options. Pass nil to accept default values.
func NewClient(endpoint string, cred azcore.TokenCredential, o *ClientOptions) (*Client, error) {
endpointUrl, err := url.Parse(endpoint)
if err != nil {
return nil, err
}
scope, err := createScopeFromEndpoint(endpointUrl)
if err != nil {
return nil, err
}
preferredRegions := []string{}
enableCrossRegionRetries := true
if o != nil {
preferredRegions = o.PreferredRegions
}
gem, err := newGlobalEndpointManager(endpoint, newInternalPipeline(newCosmosBearerTokenPolicy(cred, scope, nil), o), preferredRegions, 0, enableCrossRegionRetries)
if err != nil {
return nil, err
}
internalClient, err := newClient(newCosmosBearerTokenPolicy(cred, scope, nil), gem, o)
if err != nil {
return nil, err
}
return &Client{endpoint: endpoint, endpointUrl: endpointUrl, internal: internalClient, gem: gem}, nil
}
// NewClientFromConnectionString creates a new instance of Cosmos client from connection string. It uses the default pipeline configuration.
// connectionString - The cosmos service connection string.
// options - Optional Cosmos client options. Pass nil to accept default values.
func NewClientFromConnectionString(connectionString string, o *ClientOptions) (*Client, error) {
const (
accountEndpoint = "AccountEndpoint"
accountKey = "AccountKey"
)
splits := strings.SplitN(connectionString, ";", 2)
if len(splits) < 2 {
return nil, errors.New("failed parsing connection string due to it not consist of two parts separated by ';'")
}
var endpoint string
var cred KeyCredential
for _, split := range splits {
keyVal := strings.SplitN(split, "=", 2)
if len(keyVal) < 2 {
return nil, fmt.Errorf("failed parsing connection string due to unmatched key value separated by '='")
}
switch {
case strings.EqualFold(accountEndpoint, keyVal[0]):
endpoint = keyVal[1]
case strings.EqualFold(accountKey, keyVal[0]):
c, err := NewKeyCredential(strings.TrimSuffix(keyVal[1], ";"))
if err != nil {
return nil, err
}
cred = c
}
}
return NewClientWithKey(endpoint, cred, o)
}
func newClient(authPolicy policy.Policy, gem *globalEndpointManager, options *ClientOptions) (*azcore.Client, error) {
if options == nil {
options = &ClientOptions{}
}
return azcore.NewClient(moduleName, serviceLibVersion,
azruntime.PipelineOptions{
AllowedHeaders: getAllowedHeaders(),
PerCall: []policy.Policy{
&headerPolicies{
enableContentResponseOnWrite: options.EnableContentResponseOnWrite,
},
&globalEndpointManagerPolicy{gem: gem},
},
PerRetry: []policy.Policy{
authPolicy,
&clientRetryPolicy{gem: gem},
},
Tracing: azruntime.TracingOptions{
Namespace: "Microsoft.DocumentDB",
},
},
&options.ClientOptions)
}
func newInternalPipeline(authPolicy policy.Policy, options *ClientOptions) azruntime.Pipeline {
if options == nil {
options = &ClientOptions{}
}
return azruntime.NewPipeline(moduleName, serviceLibVersion,
azruntime.PipelineOptions{
AllowedHeaders: getAllowedHeaders(),
PerRetry: []policy.Policy{
authPolicy,
},
},
&options.ClientOptions)
}
func createScopeFromEndpoint(endpoint *url.URL) ([]string, error) {
return []string{fmt.Sprintf("%s://%s/.default", endpoint.Scheme, endpoint.Hostname())}, nil
}
// NewDatabase returns a struct that represents a database and allows database level operations.
// id - The id of the database.
func (c *Client) NewDatabase(id string) (*DatabaseClient, error) {
if id == "" {
return nil, errors.New("id is required")
}
return newDatabase(id, c)
}
// NewContainer returns a struct that represents a container and allows container level operations.
// databaseId - The id of the database.
// containerId - The id of the container.
func (c *Client) NewContainer(databaseId string, containerId string) (*ContainerClient, error) {
if databaseId == "" {
return nil, errors.New("databaseId is required")
}
if containerId == "" {
return nil, errors.New("containerId is required")
}
db, err := newDatabase(databaseId, c)
if err != nil {
return nil, err
}
return db.NewContainer(containerId)
}
// CreateDatabase creates a new database.
// ctx - The context for the request.
// databaseProperties - The definition of the database
// o - Options for the create database operation.
func (c *Client) CreateDatabase(
ctx context.Context,
databaseProperties DatabaseProperties,
o *CreateDatabaseOptions) (DatabaseResponse, error) {
var err error
spanName, err := getSpanNameForDatabases(c.accountEndpointUrl(), operationTypeCreate, resourceTypeDatabase, databaseProperties.ID)
if err != nil {
return DatabaseResponse{}, err
}
ctx, endSpan := azruntime.StartSpan(ctx, spanName.name, c.internal.Tracer(), &spanName.options)
defer func() { endSpan(err) }()
if o == nil {
o = &CreateDatabaseOptions{}
}
returnResponse := true
h := &headerOptionsOverride{
enableContentResponseOnWrite: &returnResponse,
}
operationContext := pipelineRequestOptions{
resourceType: resourceTypeDatabase,
resourceAddress: "",
isWriteOperation: true,
headerOptionsOverride: h,
}
path, err := generatePathForNameBased(resourceTypeDatabase, "", true)
if err != nil {
return DatabaseResponse{}, err
}
azResponse, err := c.sendPostRequest(
path,
ctx,
databaseProperties,
operationContext,
nil,
o.ThroughputProperties.addHeadersToRequest)
if err != nil {
return DatabaseResponse{}, err
}
response, err := newDatabaseResponse(azResponse)
return response, err
}
// NewQueryDatabasesPager executes query for databases.
// query - The SQL query to execute.
// o - Options for the operation.
func (c *Client) NewQueryDatabasesPager(query string, o *QueryDatabasesOptions) *azruntime.Pager[QueryDatabasesResponse] {
queryOptions := &QueryDatabasesOptions{}
if o != nil {
originalOptions := *o
queryOptions = &originalOptions
}
operationContext := pipelineRequestOptions{
resourceType: resourceTypeDatabase,
resourceAddress: "",
}
path, _ := generatePathForNameBased(resourceTypeDatabase, operationContext.resourceAddress, true)
return azruntime.NewPager(azruntime.PagingHandler[QueryDatabasesResponse]{
More: func(page QueryDatabasesResponse) bool {
return page.ContinuationToken != nil
},
Fetcher: func(ctx context.Context, page *QueryDatabasesResponse) (QueryDatabasesResponse, error) {
var err error
spanName, err := getSpanNameForClient(c.accountEndpointUrl(), operationTypeQuery, resourceTypeDatabase, c.accountEndpointUrl().Hostname())
if err != nil {
return QueryDatabasesResponse{}, err
}
ctx, endSpan := azruntime.StartSpan(ctx, spanName.name, c.internal.Tracer(), &spanName.options)
defer func() { endSpan(err) }()
if page != nil {
if page.ContinuationToken != nil {
// Use the previous page continuation if available
queryOptions.ContinuationToken = page.ContinuationToken
}
}
azResponse, err := c.sendQueryRequest(
path,
ctx,
query,
queryOptions.QueryParameters,
operationContext,
queryOptions,
nil)
if err != nil {
return QueryDatabasesResponse{}, err
}
return newDatabasesQueryResponse(azResponse)
},
})
}
func (c *Client) sendPostRequest(
path string,
ctx context.Context,
content interface{},
operationContext pipelineRequestOptions,
requestOptions cosmosRequestOptions,
requestEnricher func(*policy.Request)) (*http.Response, error) {
req, err := c.createRequest(path, ctx, http.MethodPost, operationContext, requestOptions, requestEnricher)
if err != nil {
return nil, err
}
err = c.attachContent(content, req)
if err != nil {
return nil, err
}
return c.executeAndEnsureSuccessResponse(ctx, req)
}
func (c *Client) sendQueryRequest(
path string,
ctx context.Context,
query string,
parameters []QueryParameter,
operationContext pipelineRequestOptions,
requestOptions cosmosRequestOptions,
requestEnricher func(*policy.Request)) (*http.Response, error) {
req, err := c.createRequest(path, ctx, http.MethodPost, operationContext, requestOptions, requestEnricher)
if err != nil {
return nil, err
}
err = azruntime.MarshalAsJSON(req, queryBody{
Query: query,
Parameters: parameters,
})
if err != nil {
return nil, err
}
req.Raw().Header.Add(cosmosHeaderQuery, "True")
// Override content type for query
req.Raw().Header.Set(headerContentType, cosmosHeaderValuesQuery)
return c.executeAndEnsureSuccessResponse(ctx, req)
}
func (c *Client) sendPutRequest(
path string,
ctx context.Context,
content interface{},
operationContext pipelineRequestOptions,
requestOptions cosmosRequestOptions,
requestEnricher func(*policy.Request)) (*http.Response, error) {
req, err := c.createRequest(path, ctx, http.MethodPut, operationContext, requestOptions, requestEnricher)
if err != nil {
return nil, err
}
err = c.attachContent(content, req)
if err != nil {
return nil, err
}
return c.executeAndEnsureSuccessResponse(ctx, req)
}
func (c *Client) sendGetRequest(
path string,
ctx context.Context,
operationContext pipelineRequestOptions,
requestOptions cosmosRequestOptions,
requestEnricher func(*policy.Request)) (*http.Response, error) {
req, err := c.createRequest(path, ctx, http.MethodGet, operationContext, requestOptions, requestEnricher)
if err != nil {
return nil, err
}
return c.executeAndEnsureSuccessResponse(ctx, req)
}
func (c *Client) sendDeleteRequest(
path string,
ctx context.Context,
operationContext pipelineRequestOptions,
requestOptions cosmosRequestOptions,
requestEnricher func(*policy.Request)) (*http.Response, error) {
req, err := c.createRequest(path, ctx, http.MethodDelete, operationContext, requestOptions, requestEnricher)
if err != nil {
return nil, err
}
return c.executeAndEnsureSuccessResponse(ctx, req)
}
func (c *Client) sendBatchRequest(
ctx context.Context,
path string,
batch []batchOperation,
operationContext pipelineRequestOptions,
requestOptions cosmosRequestOptions,
requestEnricher func(*policy.Request)) (*http.Response, error) {
req, err := c.createRequest(path, ctx, http.MethodPost, operationContext, requestOptions, requestEnricher)
if err != nil {
return nil, err
}
err = c.attachContent(batch, req)
if err != nil {
return nil, err
}
return c.executeAndEnsureSuccessResponse(ctx, req)
}
func (c *Client) sendPatchRequest(
path string,
ctx context.Context,
content interface{},
operationContext pipelineRequestOptions,
requestOptions cosmosRequestOptions,
requestEnricher func(*policy.Request)) (*http.Response, error) {
req, err := c.createRequest(path, ctx, http.MethodPatch, operationContext, requestOptions, requestEnricher)
if err != nil {
return nil, err
}
err = c.attachContent(content, req)
if err != nil {
return nil, err
}
return c.executeAndEnsureSuccessResponse(ctx, req)
}
func (c *Client) createRequest(
path string,
ctx context.Context,
method string,
operationContext pipelineRequestOptions,
requestOptions cosmosRequestOptions,
requestEnricher func(*policy.Request)) (*policy.Request, error) {
// todo: endpoint will be set originally by globalendpointmanager
finalURL := c.endpoint
if path != "" {
finalURL = azruntime.JoinPaths(c.endpoint, path)
}
req, err := azruntime.NewRequest(ctx, method, finalURL)
if err != nil {
return nil, err
}
if requestOptions != nil {
headers := requestOptions.toHeaders()
if headers != nil {
for k, v := range *headers {
req.Raw().Header.Set(k, v)
}
}
}
req.Raw().Header.Set(headerXmsDate, time.Now().UTC().Format(http.TimeFormat))
req.Raw().Header.Set(headerXmsVersion, apiVersion)
req.Raw().Header.Set(cosmosHeaderSDKSupportedCapabilities, supportedCapabilitiesHeaderValue)
req.SetOperationValue(operationContext)
if requestEnricher != nil {
requestEnricher(req)
}
return req, nil
}
func (c *Client) attachContent(content interface{}, req *policy.Request) error {
var err error
switch v := content.(type) {
case []byte:
// If its a raw byte array, we can just set the body
err = req.SetBody(streaming.NopCloser(bytes.NewReader(v)), "application/json")
default:
// Otherwise, we need to marshal it
err = azruntime.MarshalAsJSON(req, content)
}
if err != nil {
return err
}
return nil
}
func (c *Client) executeAndEnsureSuccessResponse(ctx context.Context, request *policy.Request) (*http.Response, error) {
log.Write(azlog.EventResponse, fmt.Sprintf("\n===== Client preferred regions:\n%v\n=====\n", c.gem.preferredLocations))
response, err := c.internal.Pipeline().Do(request)
if err != nil {
return nil, err
}
c.addResponseValuesToSpan(ctx, response)
successResponse := (response.StatusCode >= 200 && response.StatusCode < 300) || response.StatusCode == 304
if successResponse {
return response, nil
}
return nil, azruntime.NewResponseErrorWithErrorCode(response, response.Status)
}
func (c *Client) accountEndpointUrl() *url.URL {
return c.endpointUrl
}
func (c *Client) addResponseValuesToSpan(ctx context.Context, resp *http.Response) {
span := c.internal.Tracer().SpanFromContext(ctx)
span.SetAttributes(
tracing.Attribute{Key: "db.cosmosdb.request_charge", Value: newResponse(resp).RequestCharge},
tracing.Attribute{Key: "db.cosmosdb.status_code", Value: resp.StatusCode},
)
}
type pipelineRequestOptions struct {
headerOptionsOverride *headerOptionsOverride
resourceType resourceType
resourceAddress string
isRidBased bool
isWriteOperation bool
}
func getAllowedHeaders() []string {
return []string{
cosmosHeaderRequestCharge,
cosmosHeaderActivityId,
cosmosHeaderEtag,
cosmosHeaderSubstatus,
cosmosHeaderPopulateQuotaInfo,
cosmosHeaderPreTriggerInclude,
cosmosHeaderPostTriggerInclude,
cosmosHeaderIndexingDirective,
cosmosHeaderSessionToken,
cosmosHeaderConsistencyLevel,
cosmosHeaderPrefer,
cosmosHeaderIsUpsert,
cosmosHeaderOfferThroughput,
cosmosHeaderOfferAutoscale,
cosmosHeaderQuery,
cosmosHeaderOfferReplacePending,
cosmosHeaderOfferMinimumThroughput,
cosmosHeaderResponseContinuationTokenLimitInKb,
cosmosHeaderEnableScanInQuery,
cosmosHeaderMaxItemCount,
cosmosHeaderContinuationToken,
cosmosHeaderPopulateIndexMetrics,
cosmosHeaderPopulateQueryMetrics,
cosmosHeaderQueryMetrics,
cosmosHeaderIndexUtilization,
cosmosHeaderCorrelatedActivityId,
cosmosHeaderIsBatchRequest,
cosmosHeaderIsBatchAtomic,
cosmosHeaderIsBatchOrdered,
cosmosHeaderSDKSupportedCapabilities,
headerXmsDate,
headerContentType,
headerIfMatch,
headerIfNoneMatch,
headerXmsVersion,
headerContentLocation,
headerXmsGatewayVersion,
headerLsn,
headerXmsCosmosLlsn,
headerXmsCosmosItemLlsn,
headerXmsItemLsn,
headerXmsCosmosQuorumAckedLlsn,
headerXmsCurrentReplicaSetSize,
headerXmsCurrentWriteQuorum,
headerXmsGlobalCommittedLsn,
headerXmsLastStateChangeUtc,
headerXmsNumberOfReadRegions,
headerXmsQuorumAckedLsn,
headerXmsRequestDurationMs,
headerXmsResourceQuota,
headerXmsResourceUsage,
headerXmsSchemaVersion,
headerXmsServiceVersion,
headerXmsTransportRequestId,
headerXmsXpRole,
headerCollectionPartitionIndex,
headerCollectionServiceIndex,
headerXmsDocumentDbPartitionKeyRangeId,
cosmosHeaderPhysicalPartitionId,
headerStrictTransportSecurity,
headerXmsDatabaseAccountConsumedMb,
headerXmsDatabaseAccountProvisionedMb,
headerXmsDatabaseAccountReservedMb,
headerXmsMaxMediaStorageUsageMb,
headerXmsMediaStorageUsageMb,
headerXmsContentPath,
headerXmsAltContentPath,
cosmosHeaderMaxContentLength,
cosmosHeaderIsPartitionKeyDeletePending,
cosmosHeaderQueryExecutionInfo,
headerXmsItemCount,
}
}