in Microsoft.Azure.Cosmos/src/Routing/AvailabilityStrategy/CrossRegionHedgingAvailabilityStrategy.cs [116:253]
internal override async Task<ResponseMessage> ExecuteAvailabilityStrategyAsync(
    Func<RequestMessage, CancellationToken, Task<ResponseMessage>> sender,
    CosmosClient client,
    RequestMessage request,
    CancellationToken cancellationToken)
{
    // Overview:
    //  1. If hedging does not apply (ShouldHedge is false, or there is a single read endpoint),
    //     the request is sent once through 'sender' and its response is returned as-is.
    //  2. Otherwise the request is sent to the first applicable region, and after Threshold
    //     (then ThresholdStep for each further region) an additional copy is sent to the next region.
    //  3. The first response flagged IsNonTransient wins: the linked token source is cancelled,
    //     the HedgeContext / ResponseRegion diagnostics data are attached, and it is returned.
    //  4. Once every region has been tried, the remaining in-flight requests are drained. If none
    //     is non-transient, the last response to complete is returned; if the last task to
    //     complete faulted, its exception is rethrown.
    if (!this.ShouldHedge(request, client)
        || client.DocumentClient.GlobalEndpointManager.ReadEndpoints.Count == 1)
    {
        return await sender(request, cancellationToken);
    }

    ITrace trace = request.Trace;

    // Linked to the caller's token; cancelled below as soon as a response has been selected.
    // It is handed to the send helpers - presumably so they can stop the losing requests (confirm in helpers).
    using (CancellationTokenSource cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
    {
        // The body is turned into a cloneable stream once and handed to CloneAndSendAsync for every hedged copy.
        // A null content yields a null stream, which 'using' tolerates.
        using (CloneableStream clonedBody = (CloneableStream)(request.Content == null
            ? null
            : await StreamExtension.AsClonableStreamAsync(request.Content)))
        {
            IReadOnlyCollection<string> hedgeRegions = client.DocumentClient.GlobalEndpointManager
                .GetApplicableRegions(
                    request.RequestOptions?.ExcludeRegions,
                    OperationTypeExtensions.IsReadOperation(request.OperationType));

            // Holds every in-flight request plus, transiently, the current hedge timer.
            List<Task> requestTasks = new List<Task>(hedgeRegions.Count + 1);

            Task<HedgingResponse> primaryRequest = null;
            HedgingResponse hedgeResponse = null;

            //Send out hedged requests
            for (int requestNumber = 0; requestNumber < hedgeRegions.Count; requestNumber++)
            {
                // The first wait uses the full threshold; every later wait uses the (step) threshold.
                TimeSpan awaitTime = requestNumber == 0 ? this.Threshold : this.ThresholdStep;

                using (CancellationTokenSource timerTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
                {
                    CancellationToken timerToken = timerTokenSource.Token;
                    using (Task hedgeTimer = Task.Delay(awaitTime, timerToken))
                    {
                        if (requestNumber == 0)
                        {
                            // The primary request uses the original request message (no clone).
                            primaryRequest = this.RequestSenderAndResultCheckAsync(
                                sender,
                                request,
                                hedgeRegions.ElementAt(requestNumber),
                                cancellationToken,
                                cancellationTokenSource,
                                trace);

                            requestTasks.Add(primaryRequest);
                        }
                        else
                        {
                            Task<HedgingResponse> requestTask = this.CloneAndSendAsync(
                                sender: sender,
                                request: request,
                                clonedBody: clonedBody,
                                hedgeRegions: hedgeRegions,
                                requestNumber: requestNumber,
                                trace: trace,
                                cancellationToken: cancellationToken,
                                cancellationTokenSource: cancellationTokenSource);

                            requestTasks.Add(requestTask);
                        }

                        requestTasks.Add(hedgeTimer);

                        Task completedTask = await Task.WhenAny(requestTasks);
                        requestTasks.Remove(completedTask);

                        if (completedTask == hedgeTimer)
                        {
                            // Timer finished before any request did: move on and hedge to the next region.
                            continue;
                        }

                        // A request finished first: stop the timer so the Delay task completes
                        // before its 'using' disposes it, and drop it from the wait list.
                        timerTokenSource.Cancel();
                        requestTasks.Remove(hedgeTimer);

                        // Awaiting a faulted or cancelled request task rethrows its exception here,
                        // which ends the hedging attempt during this phase.
                        hedgeResponse = await (Task<HedgingResponse>)completedTask;
                        if (hedgeResponse.IsNonTransient)
                        {
                            cancellationTokenSource.Cancel();

                            //Take is not inclusive, so we need to add 1 to the request number which starts at 0
                            ((CosmosTraceDiagnostics)hedgeResponse.ResponseMessage.Diagnostics).Value.AddOrUpdateDatum(
                                HedgeContext,
                                hedgeRegions.Take(requestNumber + 1));
                            ((CosmosTraceDiagnostics)hedgeResponse.ResponseMessage.Diagnostics).Value.AddOrUpdateDatum(
                                ResponseRegion,
                                hedgeResponse.ResponseRegion);

                            return hedgeResponse.ResponseMessage;
                        }

                        // Otherwise (response not flagged non-transient) fall through:
                        // the next iteration hedges to the next region right away.
                    }
                }
            }

            //Wait for a good response from the hedged requests/primary request
            Exception lastException = null;
            while (requestTasks.Any())
            {
                Task completedTask = await Task.WhenAny(requestTasks);
                requestTasks.Remove(completedTask);

                if (completedTask.IsFaulted)
                {
                    // Remember the failure but do NOT await the task: awaiting would rethrow immediately
                    // and abandon the requests that are still in flight. The exception is only surfaced
                    // after the loop, if this was the last outstanding request.
                    lastException = completedTask.Exception.Flatten().InnerExceptions.FirstOrDefault()
                        ?? completedTask.Exception;
                    continue;
                }

                // A cancelled task still throws here, same as before.
                hedgeResponse = await (Task<HedgingResponse>)completedTask;

                // Accept a non-transient response, or whatever the last outstanding request returned.
                if (hedgeResponse.IsNonTransient || requestTasks.Count == 0)
                {
                    cancellationTokenSource.Cancel();

                    ((CosmosTraceDiagnostics)hedgeResponse.ResponseMessage.Diagnostics).Value.AddOrUpdateDatum(
                        HedgeContext,
                        hedgeRegions);
                    ((CosmosTraceDiagnostics)hedgeResponse.ResponseMessage.Diagnostics).Value.AddOrUpdateDatum(
                        ResponseRegion,
                        hedgeResponse.ResponseRegion);

                    return hedgeResponse.ResponseMessage;
                }
            }

            if (lastException != null)
            {
                // Rethrow while keeping the original stack trace ('throw lastException;' would reset it).
                System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(lastException).Throw();
            }

            // Reached only when every request already completed during the send phase
            // without a non-transient response; return the last response seen.
            Debug.Assert(hedgeResponse != null);
            return hedgeResponse.ResponseMessage;
        }
    }
}