private static async Task HighConcurrencyPrefetchInParallelAsync()

in Microsoft.Azure.Cosmos/src/Pagination/ParallelPrefetch.cs [515:749]


        private static async Task HighConcurrencyPrefetchInParallelAsync(
            IEnumerable<IPrefetcher> prefetchers,
            int maxConcurrency,
            ITrace trace,
            ParallelPrefetchTestConfig config,
            CancellationToken cancellationToken)
        {
            IPrefetcher[] currentBatch = null;

            // this ends up holding a sort of linked list where
            // each entry is actually a Task except the very last one,
            // which is an object[] pointing at the next chunk
            //
            // as soon as a null is encountered, either where a Task or
            // an object[] is expected, the linked list is done
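            //
            // conceptually: [Task, Task, ..., Task, object[]] -> [Task, ..., object[]] -> ... -> [Task, ..., null]
            // where the last slot of each chunk either links to the next chunk, or is null in the final one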
            object[] runningTasks = null;

            try
            {
                using (ITrace prefetchTrace = CommonStartTrace(trace))
                {
                    config?.SetInnerTrace(prefetchTrace);

                    using (IEnumerator<IPrefetcher> enumerator = prefetchers.GetEnumerator())
                    {
                        if (!enumerator.MoveNext())
                        {
                            // no prefetchers at all
                            return;
                        }

                        IPrefetcher first = enumerator.Current;

                        if (!enumerator.MoveNext())
                        {
                            // special case: a single prefetcher... just await it, and skip all the heavy work
                            config?.TaskStarted();
                            config?.TaskAwaited();
                            await first.PrefetchAsync(prefetchTrace, cancellationToken);
                            return;
                        }

                        // need to actually do things to start prefetching in parallel
                        // so grab some state and stash the first two prefetchers off

                        currentBatch = RentArray<IPrefetcher>(config, BatchLimit, clear: false);
                        currentBatch[0] = first;
                        currentBatch[1] = enumerator.Current;

                        // we need this all null because we use null as a stopping condition later
                        runningTasks = RentArray<object>(config, BatchLimit, clear: true);

                        CommonPrefetchState commonState = new (config ?? prefetchTrace, enumerator, cancellationToken);

                        // what we do here is buffer up to BatchLimit IPrefetchers to start
                        // and then... start them all
                        //
                        // we stagger this so we quickly get a bunch of tasks started without spending too
                        // much time pre-loading everything

                        // grab our first bunch of prefetchers outside of the lock
                        //
                        // we know that maxConcurrency > BatchLimit, so we can just pass BatchLimit as our cutoff here
                        int bufferedPrefetchers = FillPrefetcherBuffer(commonState, currentBatch, 2, BatchLimit, enumerator);
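                        // (the start index of 2 accounts for the two prefetchers already stashed in currentBatch above)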

                        int nextChunkIndex = 0;
                        object[] currentChunk = runningTasks;

                        int remainingConcurrency = maxConcurrency;

                        // if we encounter any error, we remember it
                        // but as soon as we've started a single task we have
                        // to see most of this code through so that every started task is observed
                        ExceptionDispatchInfo capturedException = null;

                        while (true)
                        {
                            // start and store the last set of Tasks we got from FillPrefetcherBuffer
                            for (int toStartIndex = 0; toStartIndex < bufferedPrefetchers; toStartIndex++)
                            {
                                IPrefetcher prefetcher = currentBatch[toStartIndex];
                                Task startedTask = CommonStartTaskAsync(config, commonState, prefetcher);

                                currentChunk[nextChunkIndex] = startedTask;
                                nextChunkIndex++;

                                // check if we need a new chunk to store tasks
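                                // (the last slot of each chunk is reserved for a link to the next chunk,
                                // so only Length - 1 slots ever hold Tasks)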
                                if (nextChunkIndex == currentChunk.Length - 1)
                                {
                                    // we need this all null because we use null as a stopping condition later
                                    object[] newChunk = RentArray<object>(config, BatchLimit, clear: true);

                                    currentChunk[currentChunk.Length - 1] = newChunk;

                                    currentChunk = newChunk;
                                    nextChunkIndex = 0;
                                }
                            }

                            remainingConcurrency -= bufferedPrefetchers;

                            // check to see if we've started all the concurrent Tasks we can
                            if (remainingConcurrency == 0)
                            {
                                break;
                            }

                            int nextBatchSizeLimit = remainingConcurrency < BatchLimit ? remainingConcurrency : BatchLimit;

                            // if one of the previously started Tasks exhausted the enumerator
                            // we're done, even if we still have space
                            if (commonState.FinishedEnumerating)
                            {
                                break;
                            }

                            // now that Tasks have started, we MUST synchronize access to 
                            // the enumerator
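                            // (the Tasks we've already started also consume this enumerator, which is both
                            // why FinishedEnumerating can change underneath us and why we take this lock)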
                            lock (commonState)
                            {
                                // the answer might have changed, so we double-check
                                // this once we've got the lock
                                if (commonState.FinishedEnumerating)
                                {
                                    break;
                                }

                                // grab the next set of prefetchers to start
                                try
                                {
                                    bufferedPrefetchers = FillPrefetcherBuffer(commonState, currentBatch, 0, nextBatchSizeLimit, enumerator);
                                }
                                catch (Exception exc)
                                {
                                    // this can get raised if the enumerator faults
                                    //
                                    // in this case we might have some tasks started, and so we need to _stop_ starting new tasks but
                                    // still move on to observing everything we've already started

                                    commonState.SetFinishedEnumerating();
                                    capturedException = ExceptionDispatchInfo.Capture(exc);

                                    break;
                                }
                            }

                            // if we got nothing back, we can break right here
                            if (bufferedPrefetchers == 0)
                            {
                                break;
                            }
                        }

                        // hand the prefetch array back, we're done with it
                        ReturnRentedArray(config, currentBatch, BatchLimit);
                        currentBatch = null;

                        // now wait for all the tasks to complete
                        //
                        // we walk through all of them, even if we encounter an error
                        // because we need to walk the whole linked-list and this is
                        // simpler than an explicit error code path
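                        // (awaiting in start order is fine: a Task that has already completed
                        // finishes its await immediately)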

                        int toAwaitIndex = 0;
                        while (runningTasks != null)
                        {
                            Task toAwait = (Task)runningTasks[toAwaitIndex];

                            // if we see a null, we're done
                            if (toAwait == null)
                            {
                                // hand the last of the arrays back
                                ReturnRentedArray(config, runningTasks, toAwaitIndex);
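                                // (only the slots before this null ever held Tasks, so just
                                // toAwaitIndex entries of this final chunk were used)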
                                runningTasks = null;

                                break;
                            }

                            try
                            {
                                config?.TaskAwaited();
                                await toAwait;
                            }
                            catch (Exception ex)
                            {
                                if (capturedException == null)
                                {
                                    // if we encountered some exception, tell the remaining tasks to bail
                                    // the next time they check commonState
                                    commonState.SetFinishedEnumerating();

                                    // save the exception so we can rethrow it later
                                    capturedException = ExceptionDispatchInfo.Capture(ex);
                                }
                            }

                            // advance, moving to the next chunk if we've hit the end of this one's Task slots
                            toAwaitIndex++;

                            if (toAwaitIndex == runningTasks.Length - 1)
                            {
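                                // the last slot must hold the next chunk (never null) here: a chunk only
                                // fills completely after a fresh, all-null chunk has been linked into it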
                                object[] oldChunk = runningTasks;

                                runningTasks = (object[])runningTasks[runningTasks.Length - 1];
                                toAwaitIndex = 0;

                                // we're done with this chunk, let some other caller reuse it immediately
                                ReturnRentedArray(config, oldChunk, oldChunk.Length);
                            }
                        }

                        // rethrow, if any task (or the enumerator) faulted, now that we've finished cleaning up
                        capturedException?.Throw();
                    }
                }
            }
            finally
            {
                // cleanup if something went wrong while these were still rented
                //
                // this can basically only happen if the enumerator itself faults
                // which is unlikely, but far from impossible
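                //
                // chunks already returned on the success path aren't revisited here:
                // runningTasks was advanced past them, and set to null once the walk completed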

                ReturnRentedArray(config, currentBatch, BatchLimit);

                while (runningTasks != null)
                {
                    object[] oldChunk = runningTasks;

                    runningTasks = (object[])runningTasks[runningTasks.Length - 1];

                    ReturnRentedArray(config, oldChunk, oldChunk.Length);
                }
            }
        }
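
For orientation only, here is a minimal, self-contained sketch (not part of the SDK, and not how ParallelPrefetch is actually factored) of the chunk-linked Task list the method above builds: each rented object[] stores started Tasks in its first Length - 1 slots and a link to the next chunk in its last slot, with null marking the end of the list. The names ChunkedTaskListExample and RunAsync are illustrative, ArrayPool<object>.Shared stands in for the RentArray/ReturnRentedArray helpers, and exception aggregation plus the concurrency throttling are omitted.

using System;
using System.Buffers;
using System.Collections.Generic;
using System.Threading.Tasks;

internal static class ChunkedTaskListExample
{
    // stand-in for the BatchLimit-sized chunks used above
    private const int ChunkSize = 32;

    // Illustrative sketch only: start a Task per work item, storing the Tasks in pooled
    // object[] chunks linked through their last slot, then await them all in order.
    public static async Task RunAsync(IEnumerable<Func<Task>> workItems)
    {
        object[] head = ArrayPool<object>.Shared.Rent(ChunkSize);
        Array.Clear(head, 0, head.Length); // null is the stop condition when walking later

        object[] currentChunk = head;
        int nextIndex = 0;

        foreach (Func<Task> work in workItems)
        {
            currentChunk[nextIndex] = work();
            nextIndex++;

            // the last slot of every chunk is reserved for the link to the next chunk
            if (nextIndex == currentChunk.Length - 1)
            {
                object[] newChunk = ArrayPool<object>.Shared.Rent(ChunkSize);
                Array.Clear(newChunk, 0, newChunk.Length);

                currentChunk[currentChunk.Length - 1] = newChunk;
                currentChunk = newChunk;
                nextIndex = 0;
            }
        }

        // walk the chain: await Tasks until a null is found, following the link in each full chunk
        int awaitIndex = 0;
        while (head != null)
        {
            if (head[awaitIndex] is not Task toAwait)
            {
                // a null marks the end of the list
                ArrayPool<object>.Shared.Return(head, clearArray: true);
                head = null;
                break;
            }

            await toAwait;
            awaitIndex++;

            if (awaitIndex == head.Length - 1)
            {
                object[] finished = head;
                head = (object[])head[head.Length - 1];
                awaitIndex = 0;

                ArrayPool<object>.Shared.Return(finished, clearArray: true);
            }
        }
    }
}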