Microsoft.Azure.Cosmos/src/Routing/GatewayAddressCache.cs (930 lines of code) (raw):

//------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
//------------------------------------------------------------

namespace Microsoft.Azure.Cosmos.Routing
{
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Common;
using Microsoft.Azure.Cosmos.Core.Trace;
using Microsoft.Azure.Cosmos.Query.Core.Monads;
using Microsoft.Azure.Cosmos.Tracing;
using Microsoft.Azure.Cosmos.Tracing.TraceData;
using Microsoft.Azure.Documents;
using Microsoft.Azure.Documents.Client;
using Microsoft.Azure.Documents.Collections;
using Microsoft.Azure.Documents.Rntbd;
using Microsoft.Azure.Documents.Routing;

/// <summary>
/// Resolves the physical replica addresses of partitions by querying the gateway address endpoint,
/// caches them per partition key range (plus a separate entry for the master partition), and keeps
/// track of which partition key ranges are served by each backend server.
/// </summary>
internal class GatewayAddressCache : IAddressCache, IDisposable
{
// Composite format of the filter expression ("{property} eq {protocol}") sent as the filter
// query string on every address request.
private const string protocolFilterFormat = "{0} eq {1}";

// Application setting key that, when present and parseable as an integer, overrides the number of
// partition key ranges resolved per gateway call in OpenConnectionsAsync.
private const string AddressResolutionBatchSize = "AddressResolutionBatchSize";

// Default number of partition key ranges resolved per gateway address request in OpenConnectionsAsync.
private const int DefaultBatchSize = 50;

// Upper bound on how long OpenConnectionsAsync waits for cache warm-up and connection opening.
// The value is deliberately large so that it effectively means "wait indefinitely": the delay task
// normally ends earlier, either because the caller's cancellation token is cancelled or because the
// warm-up work finishes and the delay is cancelled. The timeout is currently 45 minutes.
private static readonly TimeSpan WarmupCacheAndOpenConnectionTimeout = TimeSpan.FromMinutes(45);

// Endpoints, protocol selection and collaborators fixed at construction time.
private readonly Uri serviceEndpoint;
private readonly Uri addressEndpoint;
private readonly AsyncCacheNonBlocking<PartitionKeyRangeIdentity, PartitionAddressInformation> serverPartitionAddressCache;
private readonly ConcurrentDictionary<PartitionKeyRangeIdentity, DateTime> suboptimalServerPartitionTimestamps;
private readonly ConcurrentDictionary<ServerKey, HashSet<PartitionKeyRangeIdentity>> serverPartitionAddressToPkRangeIdMap;
private readonly IServiceConfigurationReader serviceConfigReader;
private readonly long suboptimalPartitionForceRefreshIntervalInSeconds;
private readonly Protocol protocol;
private readonly string protocolFilter;
private readonly ICosmosAuthorizationTokenProvider tokenProvider;
private readonly bool enableTcpConnectionEndpointRediscovery;
private readonly SemaphoreSlim semaphore;
private readonly CosmosHttpClient httpClient;
private readonly bool isReplicaAddressValidationEnabled;
private readonly IConnectionStateListener connectionStateListener;

// Mutable state updated after construction.
private Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation> masterPartitionAddressCache;
private DateTime suboptimalMasterPartitionTimestamp;
private bool disposedValue;
private bool validateUnknownReplicas;
private IOpenConnectionsHandler openConnectionsHandler;

/// <summary>
/// Creates an address cache that resolves replica addresses for <paramref name="serviceEndpoint"/>
/// through the gateway address endpoint.
/// </summary>
/// <param name="serviceEndpoint">The account endpoint; the address endpoint is derived from it.</param>
/// <param name="protocol">The transport protocol whose addresses are requested and cached.</param>
/// <param name="tokenProvider">Provides the authorization token for address requests.</param>
/// <param name="serviceConfigReader">Supplies the replication policies used to detect suboptimal replica sets.</param>
/// <param name="httpClient">The HTTP client used to call the gateway.</param>
/// <param name="openConnectionsHandler">Optional handler used to open Rntbd channels during warm-up.</param>
/// <param name="connectionStateListener">Listener notified of server keys when TCP endpoint rediscovery is enabled.</param>
/// <param name="suboptimalPartitionForceRefreshIntervalInSeconds">How long a suboptimal replica set is tolerated before a forced refresh.</param>
/// <param name="enableTcpConnectionEndpointRediscovery">Whether to track server-key to partition mappings.</param>
/// <param name="replicaAddressValidationEnabled">Whether refreshed addresses are merged with and validated against the cached ones.</param>
/// <param name="enableAsyncCacheExceptionNoSharing">Passed through to the underlying <see cref="AsyncCacheNonBlocking{TKey, TValue}"/>.</param>
public GatewayAddressCache(
    Uri serviceEndpoint,
    Protocol protocol,
    ICosmosAuthorizationTokenProvider tokenProvider,
    IServiceConfigurationReader serviceConfigReader,
    CosmosHttpClient httpClient,
    IOpenConnectionsHandler openConnectionsHandler,
    IConnectionStateListener connectionStateListener,
    long suboptimalPartitionForceRefreshIntervalInSeconds = 600,
    bool enableTcpConnectionEndpointRediscovery = false,
    bool replicaAddressValidationEnabled = false,
    bool enableAsyncCacheExceptionNoSharing = true)
{
    // Endpoints: the address endpoint is the service endpoint with the address path segment appended.
    this.serviceEndpoint = serviceEndpoint;
    this.addressEndpoint = new Uri(serviceEndpoint + "/" + Paths.AddressPathSegment);

    // Protocol and the filter expression attached to every address request.
    this.protocol = protocol;
    this.protocolFilter = string.Format(
        CultureInfo.InvariantCulture,
        GatewayAddressCache.protocolFilterFormat,
        Constants.Properties.Protocol,
        GatewayAddressCache.ProtocolString(protocol));

    // Collaborators.
    this.tokenProvider = tokenProvider;
    this.serviceConfigReader = serviceConfigReader;
    this.httpClient = httpClient;
    this.openConnectionsHandler = openConnectionsHandler;
    this.connectionStateListener = connectionStateListener;

    // Caches and bookkeeping; MaxValue means "no suboptimal master replica set observed yet".
    this.serverPartitionAddressCache = new AsyncCacheNonBlocking<PartitionKeyRangeIdentity, PartitionAddressInformation>(enableAsyncCacheExceptionNoSharing);
    this.suboptimalServerPartitionTimestamps = new ConcurrentDictionary<PartitionKeyRangeIdentity, DateTime>();
    this.serverPartitionAddressToPkRangeIdMap = new ConcurrentDictionary<ServerKey, HashSet<PartitionKeyRangeIdentity>>();
    this.suboptimalMasterPartitionTimestamp = DateTime.MaxValue;
    this.semaphore = new SemaphoreSlim(1, 1);

    // Feature switches and tuning.
    this.suboptimalPartitionForceRefreshIntervalInSeconds = suboptimalPartitionForceRefreshIntervalInSeconds;
    this.enableTcpConnectionEndpointRediscovery = enableTcpConnectionEndpointRediscovery;
    this.isReplicaAddressValidationEnabled = replicaAddressValidationEnabled;
    this.validateUnknownReplicas = false;
}

/// <summary>
/// Gets the service endpoint this cache was created for.
/// </summary>
public Uri ServiceEndpoint => this.serviceEndpoint;

/// <summary>
/// Gets the address information from the gateway and sets them into the async non blocking cache for later lookup.
/// Additionally attempts to establish Rntbd connections to the backend replicas based on `shouldOpenRntbdChannels`
/// boolean flag.
/// </summary>
/// <param name="databaseName">A string containing the database name.</param>
/// <param name="collection">An instance of <see cref="ContainerProperties"/> containing the collection properties.</param>
/// <param name="partitionKeyRangeIdentities">A read only list containing the partition key range identities.</param>
/// <param name="shouldOpenRntbdChannels">A boolean flag indicating whether Rntbd connections are required to be established
/// to the backend replica nodes. For cosmos client initialization and cache warmups, the Rntbd connections need to be
/// opened deterministically to the backend replicas to reduce latency, thus the <paramref name="shouldOpenRntbdChannels"/>
/// should be set to `true` during cosmos client initialization and cache warmups. The OpenAsync flow from DocumentClient
/// doesn't require the connections to be opened deterministically and thus should set the parameter to `false`.</param>
/// <param name="cancellationToken">An instance of <see cref="CancellationToken"/>.</param>
/// <remarks>
/// Partition key ranges are resolved in batches of <c>DefaultBatchSize</c>, which can be overridden by a positive integer
/// in the <c>AddressResolutionBatchSize</c> application setting (non-positive values are ignored). The method returns when
/// every batch has finished, when <paramref name="cancellationToken"/> is cancelled, or when
/// <c>WarmupCacheAndOpenConnectionTimeout</c> elapses, whichever happens first.
/// </remarks>
public async Task OpenConnectionsAsync(
    string databaseName,
    ContainerProperties collection,
    IReadOnlyList<PartitionKeyRangeIdentity> partitionKeyRangeIdentities,
    bool shouldOpenRntbdChannels,
    CancellationToken cancellationToken)
{
    List<Task> tasks = new ();
    int batchSize = GatewayAddressCache.DefaultBatchSize;

    // By design, the Unknown replicas are validated only when the following two conditions are met:
    // 1) The CosmosClient is initiated using the CreateAndInitializeAsync() flow.
    // 2) The advanced replica selection feature is enabled.
    if (shouldOpenRntbdChannels)
    {
        this.validateUnknownReplicas = true;
    }

#if !(NETSTANDARD15 || NETSTANDARD16)
#if NETSTANDARD20
    // GetEntryAssembly returns null when loaded from native netstandard2.0
    if (System.Reflection.Assembly.GetEntryAssembly() != null)
    {
#endif
        // A zero or negative batch size would keep the batching loop below from ever advancing,
        // so such user-specified values are ignored and the default is kept.
        if (int.TryParse(System.Configuration.ConfigurationManager.AppSettings[GatewayAddressCache.AddressResolutionBatchSize], out int userSpecifiedBatchSize)
            && userSpecifiedBatchSize > 0)
        {
            batchSize = userSpecifiedBatchSize;
        }
#if NETSTANDARD20
    }
#endif
#endif

    string collectionAltLink = string.Format(
        CultureInfo.InvariantCulture,
        "{0}/{1}/{2}/{3}",
        Paths.DatabasesPathSegment,
        Uri.EscapeUriString(databaseName),
        Paths.CollectionsPathSegment,
        Uri.EscapeUriString(collection.Id));

    // The request is shared by every batch task, so it is kept alive until the wait below has finished
    // instead of being disposed as soon as the tasks have been started.
    using DocumentServiceRequest request = DocumentServiceRequest.CreateFromName(
        OperationType.Read,
        collectionAltLink,
        ResourceType.Collection,
        AuthorizationTokenType.PrimaryMasterKey);

    for (int i = 0; i < partitionKeyRangeIdentities.Count; i += batchSize)
    {
        tasks.Add(
            this.WarmupCachesAndOpenConnectionsAsync(
                request: request,
                collectionRid: collection.ResourceId,
                partitionKeyRangeIds: partitionKeyRangeIdentities.Skip(i).Take(batchSize).Select(range => range.PartitionKeyRangeId),
                containerProperties: collection,
                shouldOpenRntbdChannels: shouldOpenRntbdChannels,
                cancellationToken: cancellationToken));
    }

    using CancellationTokenSource linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);

    // The `timeoutTask` is a background task which adds a delay for a period of WarmupCacheAndOpenConnectionTimeout. The task will
    // be cancelled either a) when `linkedTokenSource` is cancelled through the original `cancellationToken`, or
    // b) when `linkedTokenSource.Cancel()` is called after all the batches have completed.
    Task timeoutTask = Task.Delay(GatewayAddressCache.WarmupCacheAndOpenConnectionTimeout, linkedTokenSource.Token);
    Task resultTask = await Task.WhenAny(Task.WhenAll(tasks), timeoutTask);

    if (resultTask == timeoutTask)
    {
        // Operation has been cancelled.
        DefaultTrace.TraceWarning("The open connection task was cancelled because the cancellation token was expired. '{0}'",
            System.Diagnostics.Trace.CorrelationManager.ActivityId);
    }
    else
    {
        // All batches finished first; stop the pending delay task.
        linkedTokenSource.Cancel();
    }
}

/// <inheritdoc/>
public void SetOpenConnectionsHandler(IOpenConnectionsHandler openConnectionsHandler)
{
    this.openConnectionsHandler = openConnectionsHandler;
}

/// <inheritdoc/>
public async Task<PartitionAddressInformation> TryGetAddressesAsync(
    DocumentServiceRequest request,
    PartitionKeyRangeIdentity partitionKeyRangeIdentity,
    ServiceIdentity serviceIdentity,
    bool forceRefreshPartitionAddresses,
    CancellationToken cancellationToken)
{
    if (request == null)
    {
        throw new ArgumentNullException(nameof(request));
    }

    if (partitionKeyRangeIdentity == null)
    {
        throw new ArgumentNullException(nameof(partitionKeyRangeIdentity));
    }

    try
    {
        if (partitionKeyRangeIdentity.PartitionKeyRangeId == PartitionKeyRange.MasterPartitionKeyRangeId)
        {
            return (await this.ResolveMasterAsync(request, forceRefreshPartitionAddresses)).Item2;
        }

        // A range previously seen with fewer replicas than the target replica set size is force refreshed once the
        // configured interval has passed. TryUpdate to DateTime.MaxValue lets only one concurrent caller win.
        if (this.suboptimalServerPartitionTimestamps.TryGetValue(partitionKeyRangeIdentity, out DateTime suboptimalServerPartitionTimestamp))
        {
            bool forceRefreshDueToSuboptimalPartitionReplicaSet =
                DateTime.UtcNow.Subtract(suboptimalServerPartitionTimestamp) > TimeSpan.FromSeconds(this.suboptimalPartitionForceRefreshIntervalInSeconds);

            if (forceRefreshDueToSuboptimalPartitionReplicaSet && this.suboptimalServerPartitionTimestamps.TryUpdate(partitionKeyRangeIdentity, DateTime.MaxValue, suboptimalServerPartitionTimestamp))
            {
                forceRefreshPartitionAddresses = true;
            }
        }

        PartitionAddressInformation addresses;
        PartitionAddressInformation staleAddressInfo = null;
        if (forceRefreshPartitionAddresses || request.ForceCollectionRoutingMapRefresh)
        {
            addresses = await this.serverPartitionAddressCache.GetAsync(
                key: partitionKeyRangeIdentity,
                singleValueInitFunc: (currentCachedValue) =>
                {
                    staleAddressInfo = currentCachedValue;

                    GatewayAddressCache.SetTransportAddressUrisToUnhealthy(
                        currentCachedValue,
                        request?.RequestContext?.FailedEndpoints);

                    return this.GetAddressesForRangeIdAsync(
                        request,
                        cachedAddresses: currentCachedValue,
                        partitionKeyRangeIdentity.CollectionRid,
                        partitionKeyRangeIdentity.PartitionKeyRangeId,
                        forceRefresh: forceRefreshPartitionAddresses);
                },
                forceRefresh: (currentCachedValue) =>
                {
                    int cachedHashCode = request?.RequestContext?.LastPartitionAddressInformationHashCode ?? 0;
                    if (cachedHashCode == 0)
                    {
                        return true;
                    }

                    // If the cached value differs from the one this request last observed, another request has
                    // already refreshed the cache, so a refresh is only forced while the value is unchanged.
                    return currentCachedValue.GetHashCode() == cachedHashCode;
                });

            if (staleAddressInfo != null)
            {
                // RequestContext may be null (it is accessed null-conditionally above); LogPartitionCacheRefresh
                // ignores statistics that are not a trace datum, including null.
                GatewayAddressCache.LogPartitionCacheRefresh(request.RequestContext?.ClientRequestStatistics, staleAddressInfo, addresses);
            }

            this.suboptimalServerPartitionTimestamps.TryRemove(partitionKeyRangeIdentity, out _);
        }
        else
        {
            addresses = await this.serverPartitionAddressCache.GetAsync(
                key: partitionKeyRangeIdentity,
                singleValueInitFunc: (_) => this.GetAddressesForRangeIdAsync(
                    request,
                    cachedAddresses: null,
                    partitionKeyRangeIdentity.CollectionRid,
                    partitionKeyRangeIdentity.PartitionKeyRangeId,
                    forceRefresh: false),
                forceRefresh: (_) => false);
        }

        // Always save the hash code. This is used to determine if another request already updated the cache.
        // This helps reduce latency by avoiding unnecessary cache refreshes.
        if (request?.RequestContext != null)
        {
            request.RequestContext.LastPartitionAddressInformationHashCode = addresses.GetHashCode();
        }

        int targetReplicaSetSize = this.serviceConfigReader.UserReplicationPolicy.MaxReplicaSetSize;
        if (addresses.AllAddresses.Count() < targetReplicaSetSize)
        {
            this.suboptimalServerPartitionTimestamps.TryAdd(partitionKeyRangeIdentity, DateTime.UtcNow);
        }

        // Refresh the cache on-demand, if there were some addresses that remained unhealthy long enough (more than 1 minute)
        // and need to revalidate their status. The reason it is not dependent on 410 to force refresh the addresses is:
        // when an address is marked as unhealthy, the address enumerator deprioritizes it and moves it to the end of the
        // transport uris list. Therefore, no request may land on the unhealthy address for an extended period of time, so
        // a 410 (Gone Exception) that would trigger the forceRefresh workflow may never happen for that particular replica.
        // The TCP address list is accessed null-conditionally, consistent with the rest of this class.
        if (addresses
                .Get(Protocol.Tcp)
                ?.ReplicaTransportAddressUris
                ?.Any(x => x.ShouldRefreshHealthStatus()) == true)
        {
            // Zero timeout: if another refresh holds the semaphore, skip instead of waiting.
            bool slimAcquired = await this.semaphore.WaitAsync(0);
            try
            {
                if (slimAcquired)
                {
                    this.serverPartitionAddressCache.Refresh(
                        key: partitionKeyRangeIdentity,
                        singleValueInitFunc: (currentCachedValue) => this.GetAddressesForRangeIdAsync(
                            request,
                            cachedAddresses: currentCachedValue,
                            partitionKeyRangeIdentity.CollectionRid,
                            partitionKeyRangeIdentity.PartitionKeyRangeId,
                            forceRefresh: true));
                }
                else
                {
                    DefaultTrace.TraceVerbose("Failed to refresh addresses in the background for the collection rid: {0}, partition key range id: {1}, because the semaphore is already acquired. '{2}'",
                        partitionKeyRangeIdentity.CollectionRid,
                        partitionKeyRangeIdentity.PartitionKeyRangeId,
                        System.Diagnostics.Trace.CorrelationManager.ActivityId);
                }
            }
            finally
            {
                if (slimAcquired)
                {
                    this.semaphore.Release();
                }
            }
        }

        return addresses;
    }
    catch (DocumentClientException ex)
    {
        if ((ex.StatusCode == HttpStatusCode.NotFound) ||
            (ex.StatusCode == HttpStatusCode.Gone && ex.GetSubStatus() == SubStatusCodes.PartitionKeyRangeGone))
        {
            // Remove from the suboptimal cache in case the collection + pkRangeId combination is gone.
            this.suboptimalServerPartitionTimestamps.TryRemove(partitionKeyRangeIdentity, out _);

            return null;
        }

        throw;
    }
    catch (Exception)
    {
        if (forceRefreshPartitionAddresses)
        {
            this.suboptimalServerPartitionTimestamps.TryRemove(partitionKeyRangeIdentity, out _);
        }

        throw;
    }
}

/// <summary>
/// Gets the address information from the gateway using the partition key range ids, and warms up the async non blocking cache
/// by inserting them as a key value pair for later lookup. Additionally attempts to establish Rntbd connections to the backend
/// replicas based on `shouldOpenRntbdChannels` boolean flag.
/// </summary>
/// <param name="request">An instance of <see cref="DocumentServiceRequest"/> containing the request payload.</param>
/// <param name="collectionRid">The resource id of the collection whose addresses are resolved.</param>
/// <param name="partitionKeyRangeIds">An instance of <see cref="IEnumerable{T}"/> containing the list of partition key range ids.</param>
/// <param name="containerProperties">An instance of <see cref="ContainerProperties"/> containing the collection properties.</param>
/// <param name="shouldOpenRntbdChannels">A boolean flag indicating whether Rntbd connections are required to be established to the backend replica nodes.</param>
/// <param name="cancellationToken">An instance of <see cref="CancellationToken"/>.</param>
/// <remarks>
/// Best effort: a failed gateway lookup returns silently, and any exception raised while processing the
/// response is traced as a warning and swallowed.
/// </remarks>
private async Task WarmupCachesAndOpenConnectionsAsync(
    DocumentServiceRequest request,
    string collectionRid,
    IEnumerable<string> partitionKeyRangeIds,
    ContainerProperties containerProperties,
    bool shouldOpenRntbdChannels,
    CancellationToken cancellationToken)
{
    TryCatch<DocumentServiceResponse> responseOrFailure = await this.GetAddressesAsync(
        request: request,
        collectionRid: collectionRid,
        partitionKeyRangeIds: partitionKeyRangeIds);

    // Nothing to warm up when the gateway lookup itself failed.
    if (responseOrFailure.Failed)
    {
        return;
    }

    try
    {
        using DocumentServiceResponse response = responseOrFailure.Result;

        FeedResource<Address> feed = response.GetResource<FeedResource<Address>>();
        bool isInNetwork = this.IsInNetworkRequest(response);

        // Deferred query: each range group is converted only when the loop below reaches it,
        // so groups left after a cancellation are never processed.
        IEnumerable<Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation>> rangesWithAddresses = feed
            .Where(entry => GatewayAddressCache.ProtocolFromString(entry.Protocol) == this.protocol)
            .GroupBy(entry => entry.PartitionKeyRangeId, StringComparer.Ordinal)
            .Select(rangeGroup => this.ToPartitionAddressAndRange(containerProperties.ResourceId, rangeGroup.ToList(), isInNetwork));

        List<Task> channelTasks = new ();

        foreach (Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation> rangeWithAddresses in rangesWithAddresses)
        {
            if (cancellationToken.IsCancellationRequested)
            {
                break;
            }

            PartitionAddressInformation addressInformation = rangeWithAddresses.Item2;

            this.serverPartitionAddressCache.Set(
                new PartitionKeyRangeIdentity(containerProperties.ResourceId, rangeWithAddresses.Item1.PartitionKeyRangeId),
                addressInformation);

            // Rntbd channels are opened only on request (the CosmosClient.CreateAndInitializeAsync() flow passes true)
            // and only when a handler has been supplied; every other flow just populates the cache.
            if (shouldOpenRntbdChannels && this.openConnectionsHandler != null)
            {
                channelTasks.Add(
                    this.openConnectionsHandler.TryOpenRntbdChannelsAsync(
                        addresses: addressInformation.Get(Protocol.Tcp)?.ReplicaTransportAddressUris));
            }
        }

        await Task.WhenAll(channelTasks);
    }
    catch (Exception ex)
    {
        DefaultTrace.TraceWarning("Failed to warm-up caches and open connections for the server addresses: {0} with exception: {1}. '{2}'",
            collectionRid,
            ex.Message,
            System.Diagnostics.Trace.CorrelationManager.ActivityId);
    }
}

/// <summary>
/// Marks as unhealthy every TCP replica address of the stale cached entry that also appears in
/// <paramref name="failedEndpoints"/>. Does nothing when either argument is null, when the lazy set
/// has not been materialized, or when the entry has no TCP addresses.
/// </summary>
/// <param name="stalePartitionAddressInformation">The cached entry that is about to be refreshed.</param>
/// <param name="failedEndpoints">Endpoints that failed for the current request.</param>
private static void SetTransportAddressUrisToUnhealthy(
    PartitionAddressInformation stalePartitionAddressInformation,
    Lazy<HashSet<TransportAddressUri>> failedEndpoints)
{
    bool hasRecordedFailures = stalePartitionAddressInformation != null
        && failedEndpoints != null
        && failedEndpoints.IsValueCreated;
    if (!hasRecordedFailures)
    {
        return;
    }

    IReadOnlyList<TransportAddressUri> tcpReplicaUris = stalePartitionAddressInformation.Get(Protocol.Tcp)?.ReplicaTransportAddressUris;
    if (tcpReplicaUris == null)
    {
        return;
    }

    HashSet<TransportAddressUri> failedSet = failedEndpoints.Value;
    foreach (TransportAddressUri replicaUri in tcpReplicaUris.Where(uri => failedSet.Contains(uri)))
    {
        replicaUri.SetUnhealthy();
    }
}

// Overload for failed endpoints exposed as Lazy<ConcurrentDictionary<TransportAddressUri, bool>>. The
// Lazy<HashSet<TransportAddressUri>> overload above will be removed in a future release, once the v3 master
// change has been cherry-picked into an OSS release that the msdata build consumes.
/// <summary>
/// Marks as unhealthy every TCP replica address of the stale cached entry whose key is present in
/// <paramref name="failedEndpoints"/>. Does nothing when either argument is null, when the lazy
/// dictionary has not been materialized, or when the entry has no TCP addresses.
/// </summary>
/// <param name="stalePartitionAddressInformation">The cached entry that is about to be refreshed.</param>
/// <param name="failedEndpoints">Endpoints that failed for the current request, stored as dictionary keys.</param>
private static void SetTransportAddressUrisToUnhealthy(
    PartitionAddressInformation stalePartitionAddressInformation,
    Lazy<ConcurrentDictionary<TransportAddressUri, bool>> failedEndpoints)
{
    if (stalePartitionAddressInformation != null
        && failedEndpoints != null
        && failedEndpoints.IsValueCreated)
    {
        IReadOnlyList<TransportAddressUri> tcpReplicaUris = stalePartitionAddressInformation.Get(Protocol.Tcp)?.ReplicaTransportAddressUris;
        if (tcpReplicaUris != null)
        {
            ConcurrentDictionary<TransportAddressUri, bool> failedMap = failedEndpoints.Value;
            foreach (TransportAddressUri replicaUri in tcpReplicaUris)
            {
                if (failedMap.ContainsKey(replicaUri))
                {
                    replicaUri.SetUnhealthy();
                }
            }
        }
    }
}

/// <summary>
/// Records the previous and refreshed address information on the request diagnostics, but only when
/// those diagnostics are a <see cref="ClientSideRequestStatisticsTraceDatum"/>; otherwise does nothing.
/// </summary>
/// <param name="clientSideRequestStatistics">The request diagnostics; may be null.</param>
/// <param name="old">The address information that was cached before the refresh.</param>
/// <param name="updated">The address information returned by the refresh.</param>
private static void LogPartitionCacheRefresh(
    IClientSideRequestStatistics clientSideRequestStatistics,
    PartitionAddressInformation old,
    PartitionAddressInformation updated)
{
    ClientSideRequestStatisticsTraceDatum traceDatum = clientSideRequestStatistics as ClientSideRequestStatisticsTraceDatum;
    traceDatum?.RecordAddressCachRefreshContent(old, updated);
}

/// <summary>
/// Marks the <see cref="TransportAddressUri"/> to Unhealthy that matches with the faulted
/// server key.
/// </summary> /// <param name="serverKey">An instance of <see cref="ServerKey"/> that contains the host and /// port of the backend replica.</param> public async Task MarkAddressesToUnhealthyAsync( ServerKey serverKey) { if (this.disposedValue) { // Will enable Listener to un-register in-case of un-graceful dispose // <see cref="ConnectionStateMuxListener.NotifyAsync(ServerKey, ConcurrentDictionary{Func{ServerKey, Task}, object})"/> throw new ObjectDisposedException(nameof(GatewayAddressCache)); } if (serverKey == null) { throw new ArgumentNullException(nameof(serverKey)); } if (this.serverPartitionAddressToPkRangeIdMap.TryGetValue(serverKey, out HashSet<PartitionKeyRangeIdentity> pkRangeIds)) { PartitionKeyRangeIdentity[] pkRangeIdsCopy; lock (pkRangeIds) { pkRangeIdsCopy = pkRangeIds.ToArray(); } foreach (PartitionKeyRangeIdentity pkRangeId in pkRangeIdsCopy) { // The forceRefresh flag is set to true for the callback delegate is because, if the GetAsync() from the async // non-blocking cache fails to look up the pkRangeId, then there are some inconsistency present in the cache, and it is // more safe to do a force refresh to fetch the addresses from the gateway, instead of fetching it from the cache itself. // Please note that, the chances of encountering such scenario is highly unlikely. 
PartitionAddressInformation addressInfo = await this.serverPartitionAddressCache.GetAsync( key: pkRangeId, singleValueInitFunc: (_) => this.GetAddressesForRangeIdAsync( null, cachedAddresses: null, pkRangeId.CollectionRid, pkRangeId.PartitionKeyRangeId, forceRefresh: true), forceRefresh: (_) => false); IReadOnlyList<TransportAddressUri> transportAddresses = addressInfo.Get(Protocol.Tcp)?.ReplicaTransportAddressUris; foreach (TransportAddressUri address in from TransportAddressUri transportAddress in transportAddresses where serverKey.Equals(transportAddress.ReplicaServerKey) select transportAddress) { DefaultTrace.TraceInformation("Marking a backend replica to Unhealthy for collectionRid :{0}, pkRangeId: {1}, serviceEndpoint: {2}, transportAddress: {3}", pkRangeId.CollectionRid, pkRangeId.PartitionKeyRangeId, this.serviceEndpoint, address.ToString()); address.SetUnhealthy(); } // Update the health status this.CaptureTransportAddressUriHealthStates(addressInfo, transportAddresses); } } } private async Task<Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation>> ResolveMasterAsync(DocumentServiceRequest request, bool forceRefresh) { Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation> masterAddressAndRange = this.masterPartitionAddressCache; int targetReplicaSetSize = this.serviceConfigReader.SystemReplicationPolicy.MaxReplicaSetSize; forceRefresh = forceRefresh || (masterAddressAndRange != null && masterAddressAndRange.Item2.AllAddresses.Count() < targetReplicaSetSize && DateTime.UtcNow.Subtract(this.suboptimalMasterPartitionTimestamp) > TimeSpan.FromSeconds(this.suboptimalPartitionForceRefreshIntervalInSeconds)); if (forceRefresh || request.ForceCollectionRoutingMapRefresh || this.masterPartitionAddressCache == null) { string entryUrl = PathsHelper.GeneratePath( ResourceType.Database, string.Empty, true); try { using (DocumentServiceResponse response = await this.GetMasterAddressesViaGatewayAsync( request, ResourceType.Database, null, entryUrl, 
forceRefresh, false)) { FeedResource<Address> masterAddresses = response.GetResource<FeedResource<Address>>(); bool inNetworkRequest = this.IsInNetworkRequest(response); masterAddressAndRange = this.ToPartitionAddressAndRange(string.Empty, masterAddresses.ToList(), inNetworkRequest); this.masterPartitionAddressCache = masterAddressAndRange; this.suboptimalMasterPartitionTimestamp = DateTime.MaxValue; } } catch (Exception) { this.suboptimalMasterPartitionTimestamp = DateTime.MaxValue; throw; } } if (masterAddressAndRange.Item2.AllAddresses.Count() < targetReplicaSetSize && this.suboptimalMasterPartitionTimestamp.Equals(DateTime.MaxValue)) { this.suboptimalMasterPartitionTimestamp = DateTime.UtcNow; } return masterAddressAndRange; } private async Task<PartitionAddressInformation> GetAddressesForRangeIdAsync( DocumentServiceRequest request, PartitionAddressInformation cachedAddresses, string collectionRid, string partitionKeyRangeId, bool forceRefresh) { using (DocumentServiceResponse response = await this.GetServerAddressesViaGatewayAsync(request, collectionRid, new[] { partitionKeyRangeId }, forceRefresh)) { FeedResource<Address> addressFeed = response.GetResource<FeedResource<Address>>(); bool inNetworkRequest = this.IsInNetworkRequest(response); IEnumerable<Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation>> addressInfos = addressFeed.Where(addressInfo => ProtocolFromString(addressInfo.Protocol) == this.protocol) .GroupBy(address => address.PartitionKeyRangeId, StringComparer.Ordinal) .Select(group => this.ToPartitionAddressAndRange(collectionRid, @group.ToList(), inNetworkRequest)); Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation> result = addressInfos.SingleOrDefault( addressInfo => StringComparer.Ordinal.Equals(addressInfo.Item1.PartitionKeyRangeId, partitionKeyRangeId)); if (result == null) { string errorMessage = string.Format( CultureInfo.InvariantCulture, RMResources.PartitionKeyRangeNotFound, partitionKeyRangeId, collectionRid); 
throw new PartitionKeyRangeGoneException(errorMessage) { ResourceAddress = collectionRid }; } if (this.isReplicaAddressValidationEnabled) { // The purpose of this step is to merge the new transport addresses with the old one. What this means is - // 1. If a newly returned address from gateway is already a part of the cache, then restore the health state // of the new address with that of the cached one. // 2. If a newly returned address from gateway doesn't exist in the cache, then keep using the new address // with `Unknown` (initial) status. PartitionAddressInformation mergedAddresses = GatewayAddressCache.MergeAddresses(result.Item2, cachedAddresses); IReadOnlyList<TransportAddressUri> transportAddressUris = mergedAddresses.Get(Protocol.Tcp)?.ReplicaTransportAddressUris; // If cachedAddresses are null, that would mean that the returned address from gateway would remain in Unknown // status and there is no cached state that could transition them into Unhealthy. if (cachedAddresses != null) { foreach (TransportAddressUri address in transportAddressUris) { // The main purpose for this step is to move address health status from Unhealthy to UnhealthyPending. 
address.SetRefreshedIfUnhealthy(); } } this.ValidateReplicaAddresses(transportAddressUris); this.CaptureTransportAddressUriHealthStates( partitionAddressInformation: mergedAddresses, transportAddressUris: transportAddressUris); return mergedAddresses; } this.CaptureTransportAddressUriHealthStates( partitionAddressInformation: result.Item2, transportAddressUris: result.Item2.Get(Protocol.Tcp)?.ReplicaTransportAddressUris); return result.Item2; } } private async Task<DocumentServiceResponse> GetMasterAddressesViaGatewayAsync( DocumentServiceRequest request, ResourceType resourceType, string resourceAddress, string entryUrl, bool forceRefresh, bool useMasterCollectionResolver) { INameValueCollection addressQuery = new RequestNameValueCollection { { HttpConstants.QueryStrings.Url, HttpUtility.UrlEncode(entryUrl) } }; INameValueCollection headers = new RequestNameValueCollection(); if (forceRefresh) { headers.Set(HttpConstants.HttpHeaders.ForceRefresh, bool.TrueString); } if (useMasterCollectionResolver) { headers.Set(HttpConstants.HttpHeaders.UseMasterCollectionResolver, bool.TrueString); } if (request.ForceCollectionRoutingMapRefresh) { headers.Set(HttpConstants.HttpHeaders.ForceCollectionRoutingMapRefresh, bool.TrueString); } addressQuery.Add(HttpConstants.QueryStrings.Filter, this.protocolFilter); string resourceTypeToSign = PathsHelper.GetResourcePath(resourceType); headers.Set(HttpConstants.HttpHeaders.XDate, Rfc1123DateTimeCache.UtcNow()); using (ITrace trace = Trace.GetRootTrace(nameof(GetMasterAddressesViaGatewayAsync), TraceComponent.Authorization, TraceLevel.Info)) { string token = await this.tokenProvider.GetUserAuthorizationTokenAsync( resourceAddress, resourceTypeToSign, HttpConstants.HttpMethods.Get, headers, AuthorizationTokenType.PrimaryMasterKey, trace); headers.Set(HttpConstants.HttpHeaders.Authorization, token); Uri targetEndpoint = UrlUtility.SetQuery(this.addressEndpoint, UrlUtility.CreateQuery(addressQuery)); string identifier = 
GatewayAddressCache.LogAddressResolutionStart(request, targetEndpoint); if (this.httpClient.IsFaultInjectionClient) { using (DocumentServiceRequest faultInjectionRequest = DocumentServiceRequest.Create( operationType: OperationType.Read, resourceType: ResourceType.Address, authorizationTokenType: AuthorizationTokenType.PrimaryMasterKey)) { faultInjectionRequest.RequestContext = request.RequestContext; using (HttpResponseMessage httpResponseMessage = await this.httpClient.GetAsync( uri: targetEndpoint, additionalHeaders: headers, resourceType: resourceType, timeoutPolicy: HttpTimeoutPolicyControlPlaneRetriableHotPath.InstanceShouldThrow503OnTimeout, clientSideRequestStatistics: request.RequestContext?.ClientRequestStatistics, cancellationToken: default, documentServiceRequest: faultInjectionRequest)) { DocumentServiceResponse documentServiceResponse = await ClientExtensions.ParseResponseAsync(httpResponseMessage); GatewayAddressCache.LogAddressResolutionEnd(request, identifier); return documentServiceResponse; } } } using (HttpResponseMessage httpResponseMessage = await this.httpClient.GetAsync( uri: targetEndpoint, additionalHeaders: headers, resourceType: resourceType, timeoutPolicy: HttpTimeoutPolicyControlPlaneRetriableHotPath.InstanceShouldThrow503OnTimeout, clientSideRequestStatistics: request.RequestContext?.ClientRequestStatistics, cancellationToken: default)) { DocumentServiceResponse documentServiceResponse = await ClientExtensions.ParseResponseAsync(httpResponseMessage); GatewayAddressCache.LogAddressResolutionEnd(request, identifier); return documentServiceResponse; } } } private async Task<DocumentServiceResponse> GetServerAddressesViaGatewayAsync( DocumentServiceRequest request, string collectionRid, IEnumerable<string> partitionKeyRangeIds, bool forceRefresh) { string entryUrl = PathsHelper.GeneratePath(ResourceType.Document, collectionRid, true); INameValueCollection addressQuery = new RequestNameValueCollection { { HttpConstants.QueryStrings.Url, 
HttpUtility.UrlEncode(entryUrl) } }; INameValueCollection headers = new RequestNameValueCollection(); if (forceRefresh) { headers.Set(HttpConstants.HttpHeaders.ForceRefresh, bool.TrueString); } if (request != null && request.ForceCollectionRoutingMapRefresh) { headers.Set(HttpConstants.HttpHeaders.ForceCollectionRoutingMapRefresh, bool.TrueString); } addressQuery.Add(HttpConstants.QueryStrings.Filter, this.protocolFilter); addressQuery.Add(HttpConstants.QueryStrings.PartitionKeyRangeIds, string.Join(",", partitionKeyRangeIds)); string resourceTypeToSign = PathsHelper.GetResourcePath(ResourceType.Document); headers.Set(HttpConstants.HttpHeaders.XDate, Rfc1123DateTimeCache.UtcNow()); string token = null; using (ITrace trace = Trace.GetRootTrace(nameof(GetMasterAddressesViaGatewayAsync), TraceComponent.Authorization, TraceLevel.Info)) { try { token = await this.tokenProvider.GetUserAuthorizationTokenAsync( collectionRid, resourceTypeToSign, HttpConstants.HttpMethods.Get, headers, AuthorizationTokenType.PrimaryMasterKey, trace); } catch (UnauthorizedException) { } if (token == null && request != null && request.IsNameBased) { // User doesn't have rid based resource token. Maybe he has name based. 
string collectionAltLink = PathsHelper.GetCollectionPath(request.ResourceAddress); token = await this.tokenProvider.GetUserAuthorizationTokenAsync( collectionAltLink, resourceTypeToSign, HttpConstants.HttpMethods.Get, headers, AuthorizationTokenType.PrimaryMasterKey, trace); } headers.Set(HttpConstants.HttpHeaders.Authorization, token); Uri targetEndpoint = UrlUtility.SetQuery(this.addressEndpoint, UrlUtility.CreateQuery(addressQuery)); string identifier = GatewayAddressCache.LogAddressResolutionStart(request, targetEndpoint); if (this.httpClient.IsFaultInjectionClient) { using (DocumentServiceRequest faultInjectionRequest = DocumentServiceRequest.Create( operationType: OperationType.Read, resourceType: ResourceType.Address, authorizationTokenType: AuthorizationTokenType.PrimaryMasterKey)) { faultInjectionRequest.RequestContext = request.RequestContext; using (HttpResponseMessage httpResponseMessage = await this.httpClient.GetAsync( uri: targetEndpoint, additionalHeaders: headers, resourceType: ResourceType.Document, timeoutPolicy: HttpTimeoutPolicyControlPlaneRetriableHotPath.InstanceShouldThrow503OnTimeout, clientSideRequestStatistics: request.RequestContext?.ClientRequestStatistics, cancellationToken: default, documentServiceRequest: faultInjectionRequest)) { DocumentServiceResponse documentServiceResponse = await ClientExtensions.ParseResponseAsync(httpResponseMessage); GatewayAddressCache.LogAddressResolutionEnd(request, identifier); return documentServiceResponse; } } } using (HttpResponseMessage httpResponseMessage = await this.httpClient.GetAsync( uri: targetEndpoint, additionalHeaders: headers, resourceType: ResourceType.Document, timeoutPolicy: HttpTimeoutPolicyControlPlaneRetriableHotPath.InstanceShouldThrow503OnTimeout, clientSideRequestStatistics: request.RequestContext?.ClientRequestStatistics, cancellationToken: default)) { DocumentServiceResponse documentServiceResponse = await ClientExtensions.ParseResponseAsync(httpResponseMessage); 
GatewayAddressCache.LogAddressResolutionEnd(request, identifier);
return documentServiceResponse;
}
}
}

/// <summary>
/// Converts the gateway-returned <paramref name="addresses"/> of a single partition key range into its
/// <see cref="PartitionKeyRangeIdentity"/> and a sorted <see cref="PartitionAddressInformation"/>.
/// When TCP connection endpoint rediscovery is enabled and the range is not the master range, every physical
/// address is also recorded in the server-key to partition-key-range map; a server key whose set is created by
/// this call is registered with the connection state listener.
/// </summary>
/// <param name="collectionRid">The collection resource id the addresses belong to.</param>
/// <param name="addresses">The addresses; the partition key range id is taken from the first entry.</param>
/// <param name="inNetworkRequest">Value passed through to the resulting <see cref="PartitionAddressInformation"/>.</param>
/// <returns>A tuple of the partition key range identity and its sorted address information.</returns>
/// <exception cref="InvalidOperationException">Thrown by <c>First()</c> when <paramref name="addresses"/> is empty.</exception>
internal Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation> ToPartitionAddressAndRange(string collectionRid, IList<Address> addresses, bool inNetworkRequest)
{
    Address address = addresses.First();

    IReadOnlyList<AddressInformation> addressInfosSorted = GatewayAddressCache.GetSortedAddressInformation(addresses);

    PartitionKeyRangeIdentity partitionKeyRangeIdentity = new PartitionKeyRangeIdentity(collectionRid, address.PartitionKeyRangeId);

    if (this.enableTcpConnectionEndpointRediscovery && partitionKeyRangeIdentity.PartitionKeyRangeId != PartitionKeyRange.MasterPartitionKeyRangeId)
    {
        // add serverKey-pkRangeIdentity mapping only for addresses retrieved from gateway
        foreach (AddressInformation addressInfo in addressInfosSorted)
        {
            DefaultTrace.TraceInformation("Added address to serverPartitionAddressToPkRangeIdMap, collectionRid :{0}, pkRangeId: {1}, address: {2}",
                partitionKeyRangeIdentity.CollectionRid,
                partitionKeyRangeIdentity.PartitionKeyRangeId,
                addressInfo.PhysicalUri);

            HashSet<PartitionKeyRangeIdentity> createdValue = null;
            ServerKey serverKey = new ServerKey(new Uri(addressInfo.PhysicalUri));
            HashSet<PartitionKeyRangeIdentity> pkRangeIdSet = this.serverPartitionAddressToPkRangeIdMap.GetOrAdd(
                serverKey,
                (_) =>
                {
                    createdValue = new HashSet<PartitionKeyRangeIdentity>();
                    return createdValue;
                });

            // ConcurrentDictionary.GetOrAdd may run the factory without storing its result when another thread
            // wins the race; reference equality holds only when this call's set is the one that was stored, so
            // only that caller registers the listener.
            if (object.ReferenceEquals(pkRangeIdSet, createdValue))
            {
                this.connectionStateListener.Register(serverKey, this.MarkAddressesToUnhealthyAsync);
            }

            // The set is reachable by other callers through the concurrent map, so mutate it under its own lock.
            lock (pkRangeIdSet)
            {
                pkRangeIdSet.Add(partitionKeyRangeIdentity);
            }
        }
    }

    return Tuple.Create(
        partitionKeyRangeIdentity,
        new PartitionAddressInformation(addressInfosSorted, inNetworkRequest));
}

/// <summary>
/// Maps each gateway <see cref="Address"/> to an <see cref="AddressInformation"/> (always marked public) and
/// returns the results sorted with <see cref="Array.Sort{T}(T[])"/>, i.e. by the default comparer of
/// <see cref="AddressInformation"/>.
/// </summary>
/// <param name="addresses">The addresses returned by the gateway.</param>
/// <returns>The converted addresses, sorted.</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown when an address carries an unsupported protocol string.</exception>
private static IReadOnlyList<AddressInformation> GetSortedAddressInformation(IList<Address> addresses)
{
    AddressInformation[] addressInformationArray = new AddressInformation[addresses.Count];
    for (int i = 0; i < addresses.Count; i++)
    {
        Address addr = addresses[i];
        addressInformationArray[i] = new AddressInformation(
            isPrimary: addr.IsPrimary,
            physicalUri: addr.PhysicalUri,
            protocol: ProtocolFromString(addr.Protocol),
            isPublic: true);
    }

    Array.Sort(addressInformationArray);
    return addressInformationArray;
}

/// <summary>
/// Reads the <c>LocalRegionRequest</c> response header.
/// </summary>
/// <param name="documentServiceResponse">The gateway response whose headers are inspected.</param>
/// <returns>
/// The header's boolean value, or <c>false</c> when the header is missing, empty, or not a valid boolean.
/// </returns>
private bool IsInNetworkRequest(DocumentServiceResponse documentServiceResponse)
{
    string inNetworkHeader = documentServiceResponse.ResponseHeaders.Get(HttpConstants.HttpHeaders.LocalRegionRequest);

    // bool.TryParse yields false both when parsing fails and when the parsed value is false.
    return !string.IsNullOrEmpty(inNetworkHeader)
        && bool.TryParse(inNetworkHeader, out bool inNetworkRequest)
        && inNetworkRequest;
}

/// <summary>
/// Records the start of an address resolution against <paramref name="targetEndpoint"/> in the request's client
/// side statistics, when the request carries statistics.
/// </summary>
/// <param name="request">The originating request; may be <c>null</c>.</param>
/// <param name="targetEndpoint">The address endpoint being queried.</param>
/// <returns>
/// The identifier to pass to <see cref="LogAddressResolutionEnd"/>, or <c>null</c> when there is no request,
/// request context, or statistics object.
/// </returns>
private static string LogAddressResolutionStart(DocumentServiceRequest request, Uri targetEndpoint)
{
    // RequestContext is read null-conditionally elsewhere in this class, so guard it here as well
    // instead of risking a NullReferenceException while only logging.
    return request?.RequestContext?.ClientRequestStatistics?.RecordAddressResolutionStart(targetEndpoint);
}

/// <summary>
/// Records the end of the address resolution identified by <paramref name="identifier"/> in the request's client
/// side statistics, when the request carries statistics.
/// </summary>
/// <param name="request">The originating request; may be <c>null</c>.</param>
/// <param name="identifier">The identifier returned by <see cref="LogAddressResolutionStart"/>.</param>
private static void LogAddressResolutionEnd(DocumentServiceRequest request, string identifier)
{
    request?.RequestContext?.ClientRequestStatistics?.RecordAddressResolutionEnd(identifier);
}

/// <summary>
/// Parses a gateway protocol string (case-insensitively) into a <see cref="Protocol"/>.
/// </summary>
/// <param name="protocol">The protocol string returned by the gateway.</param>
/// <returns><see cref="Protocol.Https"/> or <see cref="Protocol.Tcp"/>.</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown for any other protocol string.</exception>
private static Protocol ProtocolFromString(string protocol)
{
    return protocol.ToLowerInvariant() switch
    {
        RuntimeConstants.Protocols.HTTPS => Protocol.Https,
        RuntimeConstants.Protocols.RNTBD => Protocol.Tcp,
        _ => throw new ArgumentOutOfRangeException(nameof(protocol)),
    };
}

/// <summary>
/// Converts a <see cref="Protocol"/> into the protocol string understood by the gateway.
/// </summary>
/// <param name="protocol">The protocol to convert.</param>
/// <returns>The gateway protocol string.</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown for protocols other than HTTPS and TCP.</exception>
private static string ProtocolString(Protocol protocol)
{
    return protocol switch
    {
        Protocol.Https => RuntimeConstants.Protocols.HTTPS,
        Protocol.Tcp => RuntimeConstants.Protocols.RNTBD,
        _ => throw new ArgumentOutOfRangeException(nameof(protocol)),
    };
}

/// <summary>
/// Utilizes the <see cref="TryCatch{TResult}"/> to get the server addresses. If an
/// exception is thrown during the invocation, it is traced and returned inside
/// a <see cref="TryCatch{TResult}"/> instead of being rethrown.
/// </summary>
/// <param name="request">An instance of <see cref="DocumentServiceRequest"/> containing the request payload.</param>
/// <param name="collectionRid">The resource id (rid) of the collection.</param>
/// <param name="partitionKeyRangeIds">An instance of <see cref="IEnumerable{T}"/> containing the list of partition key range ids.</param>
/// <returns>A task of <see cref="TryCatch{TResult}"/> containing the result.</returns>
private async Task<TryCatch<DocumentServiceResponse>> GetAddressesAsync(
    DocumentServiceRequest request,
    string collectionRid,
    IEnumerable<string> partitionKeyRangeIds)
{
    try
    {
        return TryCatch<DocumentServiceResponse>
            .FromResult(
                await this.GetServerAddressesViaGatewayAsync(
                    request: request,
                    collectionRid: collectionRid,
                    partitionKeyRangeIds: partitionKeyRangeIds,
                    forceRefresh: false));
    }
    catch (Exception ex)
    {
        DefaultTrace.TraceWarning("Failed to fetch the server addresses for: {0} with exception: {1}. '{2}'",
            collectionRid,
            ex.Message,
            System.Diagnostics.Trace.CorrelationManager.ActivityId);

        return TryCatch<DocumentServiceResponse>.FromException(ex);
    }
}

/// <summary>
/// Validates the unknown or unhealthy-pending replicas by attempting to open the Rntbd connection. This operation
/// will eventually mark the unknown or unhealthy-pending replicas as healthy if the rntbd connection attempt
/// succeeds, or as unhealthy otherwise.
/// </summary>
/// <remarks>
/// The addresses that need validation are selected once, at call time, and the connection attempts then run on a
/// background task that is intentionally not awaited.
/// </remarks>
/// <param name="addresses">A read-only list of <see cref="TransportAddressUri"/> that needs to be validated.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="addresses"/> is null.</exception>
private void ValidateReplicaAddresses(
    IReadOnlyList<TransportAddressUri> addresses)
{
    if (addresses == null)
    {
        throw new ArgumentNullException(nameof(addresses));
    }

    // Materialize the selection now. GetAddressesNeededToValidateStatus returns a lazy query over mutable health
    // states, so enumerating it again later on the background task could pick a different set than the one
    // checked here (and would repeat the filtering work).
    List<TransportAddressUri> addressesNeedToValidateStatus = this.GetAddressesNeededToValidateStatus(
        transportAddresses: addresses)
        .ToList();

    if (addressesNeedToValidateStatus.Count > 0)
    {
        // Fire-and-forget. NOTE(review): faults of this task are not observed here; this assumes
        // TryOpenRntbdChannelsAsync handles and reports its own failures — confirm.
        _ = Task.Run(() => this.openConnectionsHandler.TryOpenRntbdChannelsAsync(
            addresses: addressesNeedToValidateStatus));
    }
}

/// <summary>
/// Merge the new addresses returned from gateway service with that of the cached addresses. If the returned
/// new addresses list contains some of the addresses, which are already cached, then reset the health state
/// of the new address to that of the cached one. If the new addresses don't contain any of the cached
/// addresses, then keep using the health state of the new addresses, which should be `unknown`.
/// </summary>
/// <param name="newAddresses">An instance of <see cref="PartitionAddressInformation"/> containing the latest
/// addresses being returned from gateway.</param>
/// <param name="cachedAddresses">An instance of <see cref="PartitionAddressInformation"/> containing the cached
/// addresses from the async non blocking cache, or <c>null</c> when nothing is cached.</param>
/// <returns><paramref name="newAddresses"/>, with the health state of every TCP replica that is also present in
/// <paramref name="cachedAddresses"/> reset to the cached replica's state.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="newAddresses"/> is null.</exception>
private static PartitionAddressInformation MergeAddresses(
    PartitionAddressInformation newAddresses,
    PartitionAddressInformation cachedAddresses)
{
    if (newAddresses == null)
    {
        throw new ArgumentNullException(nameof(newAddresses));
    }

    if (cachedAddresses == null)
    {
        return newAddresses;
    }

    PerProtocolPartitionAddressInformation currentAddressInfo = newAddresses.Get(Protocol.Tcp);
    PerProtocolPartitionAddressInformation cachedAddressInfo = cachedAddresses.Get(Protocol.Tcp);

    // Get(Protocol.Tcp) can yield null; with no TCP addresses on either side there is nothing to merge.
    if (currentAddressInfo == null || cachedAddressInfo == null)
    {
        return newAddresses;
    }

    Dictionary<string, TransportAddressUri> cachedAddressDict = new ();

    foreach (TransportAddressUri transportAddressUri in cachedAddressInfo.ReplicaTransportAddressUris)
    {
        cachedAddressDict[transportAddressUri.ToString()] = transportAddressUri;
    }

    foreach (TransportAddressUri transportAddressUri in currentAddressInfo.ReplicaTransportAddressUris)
    {
        if (cachedAddressDict.TryGetValue(transportAddressUri.ToString(), out TransportAddressUri cachedTransportAddressUri))
        {
            // Read the cached health state once so the status and all three timestamps come from the same object.
            TransportAddressHealthState cachedHealthState = cachedTransportAddressUri.GetCurrentHealthState();

            transportAddressUri.ResetHealthStatus(
                status: cachedHealthState.GetHealthStatus(),
                lastUnknownTimestamp: cachedHealthState.GetLastKnownTimestampByHealthStatus(
                    healthStatus: TransportAddressHealthState.HealthStatus.Unknown),
                lastUnhealthyPendingTimestamp: cachedHealthState.GetLastKnownTimestampByHealthStatus(
                    healthStatus: TransportAddressHealthState.HealthStatus.UnhealthyPending),
                lastUnhealthyTimestamp: cachedHealthState.GetLastKnownTimestampByHealthStatus(
                    healthStatus: TransportAddressHealthState.HealthStatus.Unhealthy));
        }
    }

    return newAddresses;
}

/// <summary>
/// Returns a list of <see cref="TransportAddressUri"/> needed to validate their health status. Validating
/// a uri is done by opening Rntbd connection to the backend replica, which is a costly operation by nature. Therefore
/// validating both Unhealthy and Unknown replicas at the same time could impose a high CPU utilization. To avoid this
/// situation, the RntbdOpenConnectionHandler has good concurrency control mechanism to open the connections gracefully.
/// By default, this method only returns the UnhealthyPending replicas that require validating their connectivity status.
/// The Unknown replicas are also returned only when <c>validateUnknownReplicas</c> is set, which the original notes
/// describe as the CosmosClient CreateAndInitializeAsync() flow.
/// </summary>
/// <remarks>
/// The result is a lazily evaluated query: the health status of each address is read when the result is
/// enumerated, while the choice of whether to include Unknown replicas is made when this method is called.
/// </remarks>
/// <param name="transportAddresses">A read only list of <see cref="TransportAddressUri"/>s.</param>
/// <returns>A list of <see cref="TransportAddressUri"/> that needs to validate their status.</returns>
private IEnumerable<TransportAddressUri> GetAddressesNeededToValidateStatus(
    IReadOnlyList<TransportAddressUri> transportAddresses)
{
    return this.validateUnknownReplicas
        ? transportAddresses
            .Where(address => address
                .GetCurrentHealthState()
                .GetHealthStatus() is
                    TransportAddressHealthState.HealthStatus.UnhealthyPending or
                    TransportAddressHealthState.HealthStatus.Unknown)
        : transportAddresses
            .Where(address => address
                .GetCurrentHealthState()
                .GetHealthStatus() is
                    TransportAddressHealthState.HealthStatus.UnhealthyPending);
}

/// <summary>
/// The replica health status of the transport address uri will change eventually with the monotonically increasing time.
/// However, the purpose of this method is to capture the health status snapshot at this moment.
/// </summary>
/// <remarks>
/// Does nothing when the partition address information has no TCP entry.
/// </remarks>
/// <param name="partitionAddressInformation">An instance of <see cref="PartitionAddressInformation"/>.</param>
/// <param name="transportAddressUris">A read-only list of <see cref="TransportAddressUri"/>.</param>
private void CaptureTransportAddressUriHealthStates(
    PartitionAddressInformation partitionAddressInformation,
    IReadOnlyList<TransportAddressUri> transportAddressUris)
{
    partitionAddressInformation
        .Get(Protocol.Tcp)?
        .SetTransportAddressUrisHealthState(
            replicaHealthStates: transportAddressUris.Select(x => x.GetCurrentHealthState().GetHealthStatusDiagnosticString()).ToList());
}

/// <summary>
/// Releases the resources held by this cache. When <paramref name="disposing"/> is <c>true</c>, unregisters
/// <c>MarkAddressesToUnhealthyAsync</c> from the connection state listener for every tracked server key and
/// disposes the server partition address cache. Calls after the first only emit a trace message.
/// </summary>
/// <param name="disposing"><c>true</c> when called from <see cref="Dispose()"/>.</param>
protected virtual void Dispose(bool disposing)
{
    if (this.disposedValue)
    {
        DefaultTrace.TraceInformation("GatewayAddressCache is already disposed {0}", this.GetHashCode());
        return;
    }

    if (disposing)
    {
        // Unregister the server-key
        foreach (ServerKey serverKey in this.serverPartitionAddressToPkRangeIdMap.Keys)
        {
            this.connectionStateListener.UnRegister(serverKey, this.MarkAddressesToUnhealthyAsync);
        }

        this.serverPartitionAddressCache?.Dispose();
    }

    this.disposedValue = true;
}

/// <summary>
/// Disposes this cache. See <see cref="Dispose(bool)"/>.
/// </summary>
public void Dispose()
{
    this.Dispose(disposing: true);

    // Standard dispose pattern (CA1816): the type is not sealed and exposes a virtual Dispose(bool), so a derived
    // type with a finalizer must not have it run after explicit disposal.
    GC.SuppressFinalize(this);
}
}
}