in Microsoft.Azure.Cosmos/src/Pagination/ParallelPrefetch.cs [515:749]
private static async Task HighConcurrencyPrefetchInParallelAsync(
IEnumerable<IPrefetcher> prefetchers,
int maxConcurrency,
ITrace trace,
ParallelPrefetchTestConfig config,
CancellationToken cancellationToken)
{
IPrefetcher[] currentBatch = null;
// runningTasks ends up holding a sort of linked list, where
// every entry in a chunk is a Task except the very last one,
// which points to the next object[] chunk
//
// as soon as a null is encountered, either where a Task or
// an object[] is expected, the linked list is done
object[] runningTasks = null;
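// for illustration, with a chunk size of 4 (the real size is BatchLimit):
//
//   runningTasks -> [ Task, Task, Task, link ] -> [ Task, null, _, _ ]
//
// the last slot of each chunk is the link to the next chunk, and the
// null in the second chunk marks the end of the list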
try
{
using (ITrace prefetchTrace = CommonStartTrace(trace))
{
config?.SetInnerTrace(prefetchTrace);
using (IEnumerator<IPrefetcher> enumerator = prefetchers.GetEnumerator())
{
if (!enumerator.MoveNext())
{
// no prefetchers at all
return;
}
IPrefetcher first = enumerator.Current;
if (!enumerator.MoveNext())
{
// special case: a single prefetcher... just await it, and skip all the heavy work
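// the config hook calls below keep the test accounting (one start,
// one await) consistent with the parallel path, which presumably
// records the same events per task via CommonStartTaskAsync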
config?.TaskStarted();
config?.TaskAwaited();
await first.PrefetchAsync(prefetchTrace, cancellationToken);
return;
}
// we have at least two prefetchers, so we actually need to run in parallel;
// grab some state and stash the first two prefetchers away
currentBatch = RentArray<IPrefetcher>(config, BatchLimit, clear: false);
currentBatch[0] = first;
currentBatch[1] = enumerator.Current;
// we need this all null because we use null as a stopping condition later
runningTasks = RentArray<object>(config, BatchLimit, clear: true);
CommonPrefetchState commonState = new (config ?? prefetchTrace, enumerator, cancellationToken);
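// (config can stand in for the trace here; for this coalesce to compile,
// ParallelPrefetchTestConfig presumably implements ITrace, wrapping the
// inner trace set via SetInnerTrace above)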
// what we do here is buffer up to BatchLimit IPrefetchers
// and then... start them all
//
// we stagger this so we quickly get a bunch of tasks started without spending too
// much time pre-loading everything
// grab our first batch of prefetchers outside of the lock
//
// we know maxConcurrency > BatchLimit (that's the precondition for this
// high-concurrency path), so BatchLimit is a safe cutoff to pass here
int bufferedPrefetchers = FillPrefetcherBuffer(commonState, currentBatch, 2, BatchLimit, enumerator);
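// FillPrefetcherBuffer returns the count of occupied slots measured
// from index 0 (we started filling at index 2, after the two stashed
// prefetchers), which is why the start loop below begins at 0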
int nextChunkIndex = 0;
object[] currentChunk = runningTasks;
int remainingConcurrency = maxConcurrency;
// if we encounter any error we remember it, but as soon as
// we've started a single task we have to see most of this
// code through so that every started task gets observed
ExceptionDispatchInfo capturedException = null;
while (true)
{
// start and store the last set of Tasks we got from FillPrefetcherBuffer
for (int toStartIndex = 0; toStartIndex < bufferedPrefetchers; toStartIndex++)
{
IPrefetcher prefetcher = currentBatch[toStartIndex];
Task startedTask = CommonStartTaskAsync(config, commonState, prefetcher);
currentChunk[nextChunkIndex] = startedTask;
nextChunkIndex++;
// check if we need a new chunk to store tasks; the last slot
// of each chunk is reserved for the link to the next one
if (nextChunkIndex == currentChunk.Length - 1)
{
// we need this all null because we use null as a stopping condition later
object[] newChunk = RentArray<object>(config, BatchLimit, clear: true);
currentChunk[currentChunk.Length - 1] = newChunk;
currentChunk = newChunk;
nextChunkIndex = 0;
}
}
remainingConcurrency -= bufferedPrefetchers;
// check to see if we've started all the concurrent Tasks we can
if (remainingConcurrency == 0)
{
break;
}
int nextBatchSizeLimit = remainingConcurrency < BatchLimit ? remainingConcurrency : BatchLimit;
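// cap the next fill so the total number of started tasks can
// never exceed maxConcurrency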
// if one of the previously started Tasks exhausted the enumerator
// we're done, even if we still have space
if (commonState.FinishedEnumerating)
{
break;
}
// now that Tasks have started, we MUST synchronize access to
// the enumerator
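// (the started tasks pull their next prefetchers from this same
// enumerator as they finish work, so they and this loop must share
// a lock; commonState is the shared object both sides can see)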
lock (commonState)
{
// the answer might have changed, so we double-check
// this once we've got the lock
if (commonState.FinishedEnumerating)
{
break;
}
// grab the next set of prefetchers to start
try
{
bufferedPrefetchers = FillPrefetcherBuffer(commonState, currentBatch, 0, nextBatchSizeLimit, enumerator);
}
catch (Exception exc)
{
// this can get raised if the enumerator faults
//
// in this case we might have some tasks started, and so we need to _stop_ starting new tasks but
// still move on to observing everything we've already started
commonState.SetFinishedEnumerating();
capturedException = ExceptionDispatchInfo.Capture(exc);
break;
}
}
// if we got nothing back, we can break right here
if (bufferedPrefetchers == 0)
{
break;
}
}
// hand the prefetch array back, we're done with it
ReturnRentedArray(config, currentBatch, BatchLimit);
currentBatch = null;
// now wait for all the tasks to complete
//
// we walk through all of them, even if we encounter an error
// because we need to walk the whole linked-list and this is
// simpler than an explicit error code path
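// awaiting every task, even after a failure, also guarantees each
// Task gets observed rather than surfacing later as an unobserved
// exception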
int toAwaitIndex = 0;
while (runningTasks != null)
{
Task toAwait = (Task)runningTasks[toAwaitIndex];
// if we see a null, we're done
if (toAwait == null)
{
// hand the last of the arrays back
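// (the third argument is presumably how many leading slots to clear
// before the array goes back to the pool; only the first toAwaitIndex
// slots of this chunk ever held tasks)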
ReturnRentedArray(config, runningTasks, toAwaitIndex);
runningTasks = null;
break;
}
try
{
config?.TaskAwaited();
await toAwait;
}
catch (Exception ex)
{
if (capturedException == null)
{
// if we encountered some exception, tell the remaining tasks to bail
// the next time they check commonState
commonState.SetFinishedEnumerating();
// save the exception so we can rethrow it later
capturedException = ExceptionDispatchInfo.Capture(ex);
}
}
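// note: only the first exception wins; later failures are still
// awaited (and thus observed) but otherwise ignored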
// advance, moving to the next chunk if we've hit that limit
toAwaitIndex++;
if (toAwaitIndex == runningTasks.Length - 1)
{
object[] oldChunk = runningTasks;
runningTasks = (object[])runningTasks[runningTasks.Length - 1];
toAwaitIndex = 0;
// we're done with this, let some other caller reuse it immediately
ReturnRentedArray(config, oldChunk, oldChunk.Length);
}
}
// rethrow the first failure, if any, now that we've finished cleaning up
capturedException?.Throw();
}
}
}
finally
{
// cleanup if something went wrong while these were still rented
//
// this can basically only happen if the enumerator itself faults,
// which is unlikely but far from impossible
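// (ReturnRentedArray must tolerate a null array here: on the happy
// path currentBatch was already returned and set to null above)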
ReturnRentedArray(config, currentBatch, BatchLimit);
while (runningTasks != null)
{
object[] oldChunk = runningTasks;
runningTasks = (object[])runningTasks[runningTasks.Length - 1];
ReturnRentedArray(config, oldChunk, oldChunk.Length);
}
}
}