internal/database/database.go (392 lines of code) (raw):

// Copyright 2025 Microsoft Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package database import ( "context" "errors" "fmt" "iter" "net/http" "strings" "github.com/Azure/azure-sdk-for-go/sdk/azcore" azcorearm "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" "github.com/Azure/ARO-HCP/internal/api/arm" ) const ( billingContainer = "Billing" locksContainer = "Locks" resourcesContainer = "Resources" operationTimeToLive = 604800 // 7 days ) var ErrNotFound = errors.New("not found") func isResponseError(err error, statusCode int) bool { var responseError *azcore.ResponseError return errors.As(err, &responseError) && responseError.StatusCode == statusCode } // NewPartitionKey creates a partition key from an Azure subscription ID. func NewPartitionKey(subscriptionID string) azcosmos.PartitionKey { return azcosmos.NewPartitionKeyString(strings.ToLower(subscriptionID)) } type DBClientIteratorItem[T DocumentProperties] iter.Seq2[string, *T] type DBClientIterator[T DocumentProperties] interface { Items(ctx context.Context) DBClientIteratorItem[T] GetContinuationToken() string GetError() error } // DBClientListActiveOperationDocsOptions allows for limiting the results of DBClient.ListActiveOperationDocs. type DBClientListActiveOperationDocsOptions struct { // Request matches the type of asynchronous operation requested Request *OperationRequest // ExternalID matches (case-insensitively) the Azure resource ID of the cluster or node pool ExternalID *azcorearm.ResourceID } // DBClient provides a customized interface to the Cosmos DB containers used by the // ARO-HCP resource provider. type DBClient interface { // DBConnectionTest verifies the database is reachable. Intended for use in health checks. DBConnectionTest(ctx context.Context) error // GetLockClient returns a LockClient, or nil if the DBClient does not support a LockClient. GetLockClient() *LockClient // GetResourceDoc queries the "Resources" container for a cluster or node pool document with a // matching resourceID. GetResourceDoc(ctx context.Context, resourceID *azcorearm.ResourceID) (*ResourceDocument, error) // CreateResourceDoc creates a new cluster or node pool document in the "Resources" container. CreateResourceDoc(ctx context.Context, doc *ResourceDocument) error // UpdateResourceDoc updates a cluster or node pool document in the "Resources" container by // first fetching the document and passing it to the provided callback for modifications to be // applied. It then attempts to replace the existing document with the modified document and an // "etag" precondition. Upon a precondition failure the function repeats for a limited number of // times before giving up. // // The callback function should return true if modifications were applied, signaling to proceed // with the document replacement. The boolean return value reflects this: returning true if the // document was successfully replaced, or false with or without an error to indicate no change. UpdateResourceDoc(ctx context.Context, resourceID *azcorearm.ResourceID, callback func(*ResourceDocument) bool) (bool, error) // DeleteResourceDoc deletes a cluster or node pool document in the "Resources" container. // If no matching document is found, DeleteResourceDoc returns nil as though it had succeeded. DeleteResourceDoc(ctx context.Context, resourceID *azcorearm.ResourceID) error // ListResourceDocs returns an iterator that searches for cluster or node pool documents in // the "Resources" container that match the given resource ID prefix. The prefix must include // a subscription ID so the correct partition key can be inferred. // // Note that ListResourceDocs does not perform the search, but merely prepares an iterator to // do so. Hence the lack of a Context argument. The search is performed by calling Items() on // the iterator in a ranged for loop. // // maxItems can limit the number of items returned at once. A negative value will cause the // returned iterator to yield all matching documents. A positive value will cause the returned // iterator to include a continuation token if additional items are available. The continuation // token can be supplied on a subsequent call to obtain those additional items. ListResourceDocs(prefix *azcorearm.ResourceID, maxItems int32, continuationToken *string) DBClientIterator[ResourceDocument] // GetOperationDoc retrieves an asynchronous operation document from the "Resources" container. GetOperationDoc(ctx context.Context, pk azcosmos.PartitionKey, operationID string) (*OperationDocument, error) // CreateResourceDoc creates a new asynchronous operation document in the "Resources" container. CreateOperationDoc(ctx context.Context, doc *OperationDocument) (string, error) // UpdateOperationDoc updates an asynchronous operation document in the "Resources" container // by first fetching the document and passing it to the provided callback for modifications to // be applied. It then attempts to replace the existing document with the modified document and // an "etag" precondition. Upon a precondition failure the function repeats for a limited number // of times before giving up. // // The callback function should return true if modifications were applied, signaling to proceed // with the document replacement. The boolean return value reflects this: returning true if the // document was successfully replaced, or false with or without an error to indicate no change. UpdateOperationDoc(ctx context.Context, pk azcosmos.PartitionKey, operationID string, callback func(*OperationDocument) bool) (bool, error) // ListActiveOperationDocs returns an iterator that searches for asynchronous operation documents // with a non-terminal status in the "Resources" container under the given partition key. The // options argument can further limit the search to documents that match the provided values. // // Note that ListActiveOperationDocs does not perform the search, but merely prepares an iterator // to do so. Hence the lack of a Context argument. The search is performed by calling Items() on // the iterator in a ranged for loop. ListActiveOperationDocs(pk azcosmos.PartitionKey, options *DBClientListActiveOperationDocsOptions) DBClientIterator[OperationDocument] // GetSubscriptionDoc retrieves a subscription document from the "Resources" container. GetSubscriptionDoc(ctx context.Context, subscriptionID string) (*arm.Subscription, error) // CreateSubscriptionDoc creates a new subscription document in the "Resources" container. CreateSubscriptionDoc(ctx context.Context, subscriptionID string, subscription *arm.Subscription) error // UpdateSubscriptionDoc updates a subscription document in the "Resources" container by first // fetching the document and passing it to the provided callback for modifications to be applied. // It then attempts to replace the existing document with the modified document an an "etag" // precondition. Upon a precondition failure the function repeats for a limited number of times // before giving up. // // The callback function should return true if modifications were applied, signaling to proceed // with the document replacement. The boolean return value reflects this: returning true if the // document was successfully replaced, or false with or without an error to indicate no change. UpdateSubscriptionDoc(ctx context.Context, subscriptionID string, callback func(*arm.Subscription) bool) (bool, error) // ListAllSubscriptionDocs() returns an iterator that searches for all subscription documents in // the "Resources" container. Since the "Resources" container is partitioned by subscription ID, // there will only be one subscription document per logical partition. Thus, this method enables // iterating over all the logical partitions in the "Resources" container. // // Note that ListAllSubscriptionDocs does not perform the search, but merely prepares an iterator // to do so. Hence the lack of a Context argument. The search is performed by calling Items() on // the iterator in a ranged for loop. ListAllSubscriptionDocs() DBClientIterator[arm.Subscription] } var _ DBClient = &cosmosDBClient{} // cosmosDBClient defines the needed values to perform CRUD operations against Cosmos DB. type cosmosDBClient struct { database *azcosmos.DatabaseClient resources *azcosmos.ContainerClient lockClient *LockClient } // NewDBClient instantiates a DBClient from a Cosmos DatabaseClient instance // targeting the Frontends async database. func NewDBClient(ctx context.Context, database *azcosmos.DatabaseClient) (DBClient, error) { // NewContainer only fails if the container ID argument is // empty, so we can safely disregard the error return value. resources, _ := database.NewContainer(resourcesContainer) locks, _ := database.NewContainer(locksContainer) lockClient, err := NewLockClient(ctx, locks) if err != nil { return nil, err } return &cosmosDBClient{ database: database, resources: resources, lockClient: lockClient, }, nil } func (d *cosmosDBClient) DBConnectionTest(ctx context.Context) error { if _, err := d.database.Read(ctx, nil); err != nil { return fmt.Errorf("failed to read Cosmos database information during healthcheck: %v", err) } return nil } func (d *cosmosDBClient) GetLockClient() *LockClient { return d.lockClient } func (d *cosmosDBClient) getResourceDoc(ctx context.Context, resourceID *azcorearm.ResourceID) (*typedDocument, *ResourceDocument, error) { pk := NewPartitionKey(resourceID.SubscriptionID) const query = "SELECT * FROM c WHERE STRINGEQUALS(c.resourceType, @resourceType, true) AND STRINGEQUALS(c.properties.resourceId, @resourceId, true)" opt := azcosmos.QueryOptions{ PageSizeHint: 1, QueryParameters: []azcosmos.QueryParameter{ { Name: "@resourceType", Value: resourceID.ResourceType.String(), }, { Name: "@resourceId", Value: resourceID.String(), }, }, } queryPager := d.resources.NewQueryItemsPager(query, pk, &opt) for queryPager.More() { queryResponse, err := queryPager.NextPage(ctx) if err != nil { return nil, nil, fmt.Errorf("failed to advance page while querying Resources container for '%s': %w", resourceID, err) } for _, item := range queryResponse.Items { typedDoc, innerDoc, err := typedDocumentUnmarshal[ResourceDocument](item) if err != nil { return nil, nil, fmt.Errorf("failed to unmarshal Resources container item for '%s': %w", resourceID, err) } return typedDoc, innerDoc, nil } } return nil, nil, fmt.Errorf("failed to read Resources container item for '%s': %w", resourceID, ErrNotFound) } func (d *cosmosDBClient) GetResourceDoc(ctx context.Context, resourceID *azcorearm.ResourceID) (*ResourceDocument, error) { _, innerDoc, err := d.getResourceDoc(ctx, resourceID) if err != nil { return nil, err } // Replace the key field from Cosmos with the given resourceID, // which typically comes from the URL. This helps preserve the // casing of the resource group and resource name from the URL // to meet RPC requirements: // // Put Resource | Arguments // // The resource group names and resource names should be matched // case insensitively. ... Additionally, the Resource Provier must // preserve the casing provided by the user. The service must return // the most recently specified casing to the client and must not // normalize or return a toupper or tolower form of the resource // group or resource name. The resource group name and resource // name must come from the URL and not the request body. innerDoc.ResourceID = resourceID return innerDoc, nil } func (d *cosmosDBClient) CreateResourceDoc(ctx context.Context, doc *ResourceDocument) error { typedDoc := newTypedDocument(doc.ResourceID.SubscriptionID, doc.ResourceID.ResourceType) data, err := typedDocumentMarshal(typedDoc, doc) if err != nil { return fmt.Errorf("failed to marshal Resources container item for '%s': %w", doc.ResourceID, err) } _, err = d.resources.CreateItem(ctx, typedDoc.getPartitionKey(), data, nil) if err != nil { return fmt.Errorf("failed to create Resources container item for '%s': %w", doc.ResourceID, err) } return nil } func (d *cosmosDBClient) UpdateResourceDoc(ctx context.Context, resourceID *azcorearm.ResourceID, callback func(*ResourceDocument) bool) (bool, error) { var err error options := &azcosmos.ItemOptions{} for try := 0; try < 5; try++ { var typedDoc *typedDocument var innerDoc *ResourceDocument var data []byte typedDoc, innerDoc, err = d.getResourceDoc(ctx, resourceID) if err != nil { return false, err } if !callback(innerDoc) { return false, nil } data, err = typedDocumentMarshal(typedDoc, innerDoc) if err != nil { return false, fmt.Errorf("failed to marshal Resources container item for '%s': %w", resourceID, err) } options.IfMatchEtag = &typedDoc.CosmosETag _, err = d.resources.ReplaceItem(ctx, typedDoc.getPartitionKey(), typedDoc.ID, data, options) if err == nil { return true, nil } var responseError *azcore.ResponseError err = fmt.Errorf("failed to replace Resources container item for '%s': %w", resourceID, err) if !errors.As(err, &responseError) || responseError.StatusCode != http.StatusPreconditionFailed { return false, err } } return false, err } func (d *cosmosDBClient) DeleteResourceDoc(ctx context.Context, resourceID *azcorearm.ResourceID) error { typedDoc, _, err := d.getResourceDoc(ctx, resourceID) if err != nil { if errors.Is(err, ErrNotFound) { return nil } return err } _, err = d.resources.DeleteItem(ctx, typedDoc.getPartitionKey(), typedDoc.ID, nil) if err != nil { return fmt.Errorf("failed to delete Resources container item for '%s': %w", resourceID, err) } return nil } func (d *cosmosDBClient) ListResourceDocs(prefix *azcorearm.ResourceID, maxItems int32, continuationToken *string) DBClientIterator[ResourceDocument] { pk := NewPartitionKey(prefix.SubscriptionID) // XXX The Cosmos DB REST API gives special meaning to -1 for "x-ms-max-item-count" // but it's not clear if it treats all negative values equivalently. The Go SDK // passes the PageSizeHint value as provided so normalize negative values to -1 // to be safe. maxItems = max(maxItems, -1) const query = "SELECT * FROM c WHERE STARTSWITH(c.properties.resourceId, @prefix, true)" opt := azcosmos.QueryOptions{ PageSizeHint: maxItems, ContinuationToken: continuationToken, QueryParameters: []azcosmos.QueryParameter{ { Name: "@prefix", Value: prefix.String() + "/", }, }, } pager := d.resources.NewQueryItemsPager(query, pk, &opt) if maxItems > 0 { return newQueryItemsSinglePageIterator[ResourceDocument](pager) } else { return newQueryItemsIterator[ResourceDocument](pager) } } func (d *cosmosDBClient) getOperationDoc(ctx context.Context, pk azcosmos.PartitionKey, operationID string) (*typedDocument, *OperationDocument, error) { // Make sure lookup keys are lowercase. operationID = strings.ToLower(operationID) response, err := d.resources.ReadItem(ctx, pk, operationID, nil) if err != nil { if isResponseError(err, http.StatusNotFound) { err = ErrNotFound } return nil, nil, fmt.Errorf("failed to read Operations container item for '%s': %w", operationID, err) } typedDoc, innerDoc, err := typedDocumentUnmarshal[OperationDocument](response.Value) if err != nil { return nil, nil, fmt.Errorf("failed to unmarshal Operations container item for '%s': %w", operationID, err) } return typedDoc, innerDoc, nil } func (d *cosmosDBClient) GetOperationDoc(ctx context.Context, pk azcosmos.PartitionKey, operationID string) (*OperationDocument, error) { _, innerDoc, err := d.getOperationDoc(ctx, pk, operationID) return innerDoc, err } func (d *cosmosDBClient) CreateOperationDoc(ctx context.Context, doc *OperationDocument) (string, error) { // Make sure partition key is lowercase. subscriptionID := strings.ToLower(doc.ExternalID.SubscriptionID) typedDoc := newTypedDocument(subscriptionID, OperationResourceType) typedDoc.TimeToLive = operationTimeToLive data, err := typedDocumentMarshal(typedDoc, doc) if err != nil { return "", fmt.Errorf("failed to marshal Operations container item for '%s': %w", typedDoc.ID, err) } _, err = d.resources.CreateItem(ctx, typedDoc.getPartitionKey(), data, nil) if err != nil { return "", fmt.Errorf("failed to create Operations container item for '%s': %w", typedDoc.ID, err) } return typedDoc.ID, nil } func (d *cosmosDBClient) UpdateOperationDoc(ctx context.Context, pk azcosmos.PartitionKey, operationID string, callback func(*OperationDocument) bool) (bool, error) { var err error options := &azcosmos.ItemOptions{} for try := 0; try < 5; try++ { var typedDoc *typedDocument var innerDoc *OperationDocument var data []byte typedDoc, innerDoc, err = d.getOperationDoc(ctx, pk, operationID) if err != nil { return false, err } if !callback(innerDoc) { return false, nil } data, err = typedDocumentMarshal(typedDoc, innerDoc) if err != nil { return false, fmt.Errorf("failed to marshal Operations container item for '%s': %w", operationID, err) } options.IfMatchEtag = &typedDoc.CosmosETag _, err = d.resources.ReplaceItem(ctx, pk, typedDoc.ID, data, options) if err == nil { return true, nil } var responseError *azcore.ResponseError err = fmt.Errorf("failed to replace Operations container item for '%s': %w", operationID, err) if !errors.As(err, &responseError) || responseError.StatusCode != http.StatusPreconditionFailed { return false, err } } return false, err } func (d *cosmosDBClient) ListActiveOperationDocs(pk azcosmos.PartitionKey, options *DBClientListActiveOperationDocsOptions) DBClientIterator[OperationDocument] { var queryOptions azcosmos.QueryOptions query := fmt.Sprintf( "SELECT * FROM c WHERE STRINGEQUALS(c.resourceType, %q, true) "+ "AND NOT ARRAYCONTAINS([%q, %q, %q], c.properties.status)", OperationResourceType.String(), arm.ProvisioningStateSucceeded, arm.ProvisioningStateFailed, arm.ProvisioningStateCanceled) if options != nil { if options.Request != nil { query += " AND c.properties.request == @request" queryParameter := azcosmos.QueryParameter{ Name: "@request", Value: string(*options.Request), } queryOptions.QueryParameters = append(queryOptions.QueryParameters, queryParameter) } if options.ExternalID != nil { query += " AND STRINGEQUALS(c.properties.externalId, @externalId, true)" queryParameter := azcosmos.QueryParameter{ Name: "@externalId", Value: options.ExternalID.String(), } queryOptions.QueryParameters = append(queryOptions.QueryParameters, queryParameter) } } pager := d.resources.NewQueryItemsPager(query, pk, &queryOptions) return newQueryItemsIterator[OperationDocument](pager) } func (d *cosmosDBClient) getSubscriptionDoc(ctx context.Context, subscriptionID string) (*typedDocument, *arm.Subscription, error) { // Make sure lookup keys are lowercase. subscriptionID = strings.ToLower(subscriptionID) pk := NewPartitionKey(subscriptionID) response, err := d.resources.ReadItem(ctx, pk, subscriptionID, nil) if err != nil { if isResponseError(err, http.StatusNotFound) { err = ErrNotFound } return nil, nil, fmt.Errorf("failed to read Subscriptions container item for '%s': %w", subscriptionID, err) } typedDoc, innerDoc, err := typedDocumentUnmarshal[arm.Subscription](response.Value) if err != nil { return nil, nil, fmt.Errorf("failed to unmarshal Subscriptions container item for '%s': %w", subscriptionID, err) } // Expose the "_ts" field for metics reporting. innerDoc.LastUpdated = typedDoc.CosmosTimestamp return typedDoc, innerDoc, nil } func (d *cosmosDBClient) GetSubscriptionDoc(ctx context.Context, subscriptionID string) (*arm.Subscription, error) { _, innerDoc, err := d.getSubscriptionDoc(ctx, subscriptionID) return innerDoc, err } func (d *cosmosDBClient) CreateSubscriptionDoc(ctx context.Context, subscriptionID string, subscription *arm.Subscription) error { typedDoc := newTypedDocument(subscriptionID, azcorearm.SubscriptionResourceType) typedDoc.ID = strings.ToLower(subscriptionID) data, err := typedDocumentMarshal(typedDoc, subscription) if err != nil { return fmt.Errorf("failed to marshal Subscriptions container item for '%s': %w", subscriptionID, err) } _, err = d.resources.CreateItem(ctx, typedDoc.getPartitionKey(), data, nil) if err != nil { return fmt.Errorf("failed to create Subscriptions container item for '%s': %w", subscriptionID, err) } return nil } func (d *cosmosDBClient) UpdateSubscriptionDoc(ctx context.Context, subscriptionID string, callback func(*arm.Subscription) bool) (bool, error) { var err error options := &azcosmos.ItemOptions{} for try := 0; try < 5; try++ { var typedDoc *typedDocument var innerDoc *arm.Subscription var data []byte typedDoc, innerDoc, err = d.getSubscriptionDoc(ctx, subscriptionID) if err != nil { return false, err } if !callback(innerDoc) { return false, nil } data, err = typedDocumentMarshal(typedDoc, innerDoc) if err != nil { return false, fmt.Errorf("failed to marshal Subscriptions container item for '%s': %w", subscriptionID, err) } options.IfMatchEtag = &typedDoc.CosmosETag _, err = d.resources.ReplaceItem(ctx, typedDoc.getPartitionKey(), typedDoc.ID, data, options) if err == nil { return true, nil } var responseError *azcore.ResponseError err = fmt.Errorf("failed to replace Subscriptions container item for '%s': %w", subscriptionID, err) if !errors.As(err, &responseError) || responseError.StatusCode != http.StatusPreconditionFailed { return false, err } } return false, err } func (d *cosmosDBClient) ListAllSubscriptionDocs() DBClientIterator[arm.Subscription] { const query = "SELECT * FROM c WHERE STRINGEQUALS(c.resourceType, @resourceType, true)" opt := azcosmos.QueryOptions{ QueryParameters: []azcosmos.QueryParameter{ { Name: "@resourceType", Value: azcorearm.SubscriptionResourceType.String(), }, }, } // Empty partition key triggers a cross-partition query. pager := d.resources.NewQueryItemsPager(query, azcosmos.NewPartitionKey(), &opt) return newQueryItemsIterator[arm.Subscription](pager) } // NewCosmosDatabaseClient instantiates a generic Cosmos database client. func NewCosmosDatabaseClient(url string, dbName string, clientOptions azcore.ClientOptions) (*azcosmos.DatabaseClient, error) { credential, err := azidentity.NewDefaultAzureCredential( &azidentity.DefaultAzureCredentialOptions{ ClientOptions: clientOptions, }) if err != nil { return nil, err } client, err := azcosmos.NewClient( url, credential, &azcosmos.ClientOptions{ ClientOptions: clientOptions, }) if err != nil { return nil, err } return client.NewDatabase(dbName) }