src/assets/Azure.Core.Shared/TaskExtensions.cs (245 lines of code) (raw):

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#nullable disable

using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

namespace Azure.Core.Pipeline
{
    /// <summary>
    /// Extension methods for awaiting tasks together with a <see cref="CancellationToken"/>, and for synchronously
    /// consuming tasks and async sequences that the caller expects to have completed already.
    /// </summary>
    internal static class TaskExtensions
    {
        /// <summary>
        /// Returns an awaitable that completes when either <paramref name="task"/> completes or
        /// <paramref name="cancellationToken"/> is signaled, whichever happens first.
        /// </summary>
        /// <remarks>
        /// If the token is signaled first, awaiting throws <see cref="OperationCanceledException"/>.
        /// <paramref name="task"/> itself is not cancelled and keeps running.
        /// The task is configured with <c>ConfigureAwait(false)</c>, so the continuation does not capture the
        /// current synchronization context.
        /// </remarks>
        /// <param name="task">The task to await.</param>
        /// <param name="cancellationToken">Token that abandons the wait when signaled.</param>
        public static WithCancellationTaskAwaitable AwaitWithCancellation(this Task task, CancellationToken cancellationToken)
            => new WithCancellationTaskAwaitable(task, cancellationToken);

        /// <inheritdoc cref="AwaitWithCancellation(Task, CancellationToken)"/>
        public static WithCancellationTaskAwaitable<T> AwaitWithCancellation<T>(this Task<T> task, CancellationToken cancellationToken)
            => new WithCancellationTaskAwaitable<T>(task, cancellationToken);

        /// <inheritdoc cref="AwaitWithCancellation(Task, CancellationToken)"/>
        public static WithCancellationValueTaskAwaitable<T> AwaitWithCancellation<T>(this ValueTask<T> task, CancellationToken cancellationToken)
            => new WithCancellationValueTaskAwaitable<T>(task, cancellationToken);

        /// <summary>
        /// Synchronously returns the result of a task that the caller expects to have already completed.
        /// </summary>
        /// <remarks>
        /// In DEBUG builds an incomplete task breaks into an attached debugger and throws
        /// <see cref="InvalidOperationException"/>. Other builds perform no check and call
        /// <c>GetAwaiter().GetResult()</c> directly. A faulted task's exception is rethrown without an
        /// <see cref="AggregateException"/> wrapper.
        /// </remarks>
        public static T EnsureCompleted<T>(this Task<T> task)
        {
#if DEBUG
            VerifyTaskCompleted(task.IsCompleted);
#endif
#pragma warning disable AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead.
            return task.GetAwaiter().GetResult();
#pragma warning restore AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead.
        }

        /// <inheritdoc cref="EnsureCompleted{T}(Task{T})"/>
        public static void EnsureCompleted(this Task task)
        {
#if DEBUG
            VerifyTaskCompleted(task.IsCompleted);
#endif
#pragma warning disable AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead.
            task.GetAwaiter().GetResult();
#pragma warning restore AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead.
        }

        /// <summary>
        /// Synchronously returns the result of a value task that the caller expects to have already completed.
        /// Only DEBUG builds verify completion; callers must not use this on a value task that may still be pending.
        /// </summary>
        public static T EnsureCompleted<T>(this ValueTask<T> task)
        {
#if DEBUG
            VerifyTaskCompleted(task.IsCompleted);
#endif
#pragma warning disable AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead.
            return task.GetAwaiter().GetResult();
#pragma warning restore AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead.
        }

        /// <inheritdoc cref="EnsureCompleted{T}(ValueTask{T})"/>
        public static void EnsureCompleted(this ValueTask task)
        {
#if DEBUG
            VerifyTaskCompleted(task.IsCompleted);
#endif
#pragma warning disable AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead.
            task.GetAwaiter().GetResult();
#pragma warning restore AZC0102 // Do not use GetAwaiter().GetResult(). Use the TaskExtensions.EnsureCompleted() extension method instead.
        }

        /// <summary>
        /// Wraps <paramref name="asyncEnumerable"/> in a synchronous, allocation-free (struct-based) enumerable.
        /// Each <c>MoveNextAsync</c>/<c>DisposeAsync</c> call is unwrapped with <c>EnsureCompleted</c>.
        /// </summary>
        public static Enumerable<T> EnsureSyncEnumerable<T>(this IAsyncEnumerable<T> asyncEnumerable) => new Enumerable<T>(asyncEnumerable);

        /// <summary>
        /// Returns <paramref name="awaitable"/> unchanged.
        /// In DEBUG builds, when <paramref name="async"/> is <c>false</c>, it first verifies that the awaitable has
        /// already completed.
        /// </summary>
        public static ConfiguredValueTaskAwaitable<T> EnsureCompleted<T>(this ConfiguredValueTaskAwaitable<T> awaitable, bool async)
        {
#if DEBUG
            if (!async)
            {
                VerifyTaskCompleted(awaitable.GetAwaiter().IsCompleted);
            }
#endif
            return awaitable;
        }

        /// <inheritdoc cref="EnsureCompleted{T}(ConfiguredValueTaskAwaitable{T}, bool)"/>
        public static ConfiguredValueTaskAwaitable EnsureCompleted(this ConfiguredValueTaskAwaitable awaitable, bool async)
        {
#if DEBUG
            if (!async)
            {
                VerifyTaskCompleted(awaitable.GetAwaiter().IsCompleted);
            }
#endif
            return awaitable;
        }

        [Conditional("DEBUG")]
        private static void VerifyTaskCompleted(bool isCompleted)
        {
            if (!isCompleted)
            {
                if (Debugger.IsAttached)
                {
                    Debugger.Break();
                }
                // Throw an InvalidOperationException instead of using
                // Debug.Assert because that brings down nUnit immediately
                throw new InvalidOperationException("Task is not completed");
            }
        }

        /// <summary>
        /// Both <see cref="Enumerable{T}"/> and <see cref="Enumerator{T}"/> are defined as public structs so that foreach can use duck typing
        /// to call <see cref="Enumerable{T}.GetEnumerator"/> and avoid heap memory allocation.
        /// Please don't delete the public <see cref="Enumerable{T}.GetEnumerator"/> method and don't make these types private.
        /// </summary>
        /// <typeparam name="T">The element type of the wrapped async sequence.</typeparam>
        public readonly struct Enumerable<T> : IEnumerable<T>
        {
            private readonly IAsyncEnumerable<T> _asyncEnumerable;

            public Enumerable(IAsyncEnumerable<T> asyncEnumerable) => _asyncEnumerable = asyncEnumerable;

            public Enumerator<T> GetEnumerator() => new Enumerator<T>(_asyncEnumerable.GetAsyncEnumerator());

            // Interface paths box the struct enumerator; only the pattern-based GetEnumerator above is allocation-free.
            IEnumerator<T> IEnumerable<T>.GetEnumerator() => GetEnumerator();

            IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
        }

        /// <summary>
        /// Synchronous enumerator over an <see cref="IAsyncEnumerator{T}"/>.
        /// Every async call is expected to complete synchronously; this is checked only in DEBUG builds.
        /// </summary>
        /// <typeparam name="T">The element type.</typeparam>
        public readonly struct Enumerator<T> : IEnumerator<T>
        {
            private readonly IAsyncEnumerator<T> _asyncEnumerator;

            public Enumerator(IAsyncEnumerator<T> asyncEnumerator) => _asyncEnumerator = asyncEnumerator;

#pragma warning disable AZC0107 // Do not call public asynchronous method in synchronous scope.
            public bool MoveNext() => _asyncEnumerator.MoveNextAsync().EnsureCompleted();
#pragma warning restore AZC0107 // Do not call public asynchronous method in synchronous scope.

            /// <exception cref="NotSupportedException">Always; async enumerators cannot be reset.</exception>
            public void Reset() => throw new NotSupportedException($"{GetType()} is a synchronous wrapper for {_asyncEnumerator.GetType()} async enumerator, which can't be reset, so IEnumerator.Reset() calls aren't supported.");

            public T Current => _asyncEnumerator.Current;

            object IEnumerator.Current => Current;

#pragma warning disable AZC0107 // Do not call public asynchronous method in synchronous scope.
            public void Dispose() => _asyncEnumerator.DisposeAsync().EnsureCompleted();
#pragma warning restore AZC0107 // Do not call public asynchronous method in synchronous scope.
        }

        /// <summary>
        /// Awaitable returned by <see cref="AwaitWithCancellation(Task, CancellationToken)"/>.
        /// </summary>
        public readonly struct WithCancellationTaskAwaitable
        {
            private readonly CancellationToken _cancellationToken;
            private readonly ConfiguredTaskAwaitable _awaitable;

            public WithCancellationTaskAwaitable(Task task, CancellationToken cancellationToken)
            {
                _awaitable = task.ConfigureAwait(false);
                _cancellationToken = cancellationToken;
            }

            public WithCancellationTaskAwaiter GetAwaiter() => new WithCancellationTaskAwaiter(_awaitable.GetAwaiter(), _cancellationToken);
        }

        /// <summary>
        /// Awaitable returned by <see cref="AwaitWithCancellation{T}(Task{T}, CancellationToken)"/>.
        /// </summary>
        public readonly struct WithCancellationTaskAwaitable<T>
        {
            private readonly CancellationToken _cancellationToken;
            private readonly ConfiguredTaskAwaitable<T> _awaitable;

            public WithCancellationTaskAwaitable(Task<T> task, CancellationToken cancellationToken)
            {
                _awaitable = task.ConfigureAwait(false);
                _cancellationToken = cancellationToken;
            }

            public WithCancellationTaskAwaiter<T> GetAwaiter() => new WithCancellationTaskAwaiter<T>(_awaitable.GetAwaiter(), _cancellationToken);
        }

        /// <summary>
        /// Awaitable returned by <see cref="AwaitWithCancellation{T}(ValueTask{T}, CancellationToken)"/>.
        /// </summary>
        public readonly struct WithCancellationValueTaskAwaitable<T>
        {
            private readonly CancellationToken _cancellationToken;
            private readonly ConfiguredValueTaskAwaitable<T> _awaitable;

            public WithCancellationValueTaskAwaitable(ValueTask<T> task, CancellationToken cancellationToken)
            {
                _awaitable = task.ConfigureAwait(false);
                _cancellationToken = cancellationToken;
            }

            public WithCancellationValueTaskAwaiter<T> GetAwaiter() => new WithCancellationValueTaskAwaiter<T>(_awaitable.GetAwaiter(), _cancellationToken);
        }

        /// <summary>
        /// Awaiter that reports completion once the wrapped task completes or the cancellation token is signaled.
        /// </summary>
        public readonly struct WithCancellationTaskAwaiter : ICriticalNotifyCompletion
        {
            private readonly CancellationToken _cancellationToken;
            private readonly ConfiguredTaskAwaitable.ConfiguredTaskAwaiter _taskAwaiter;

            public WithCancellationTaskAwaiter(ConfiguredTaskAwaitable.ConfiguredTaskAwaiter awaiter, CancellationToken cancellationToken)
            {
                _taskAwaiter = awaiter;
                _cancellationToken = cancellationToken;
            }

            public bool IsCompleted => _taskAwaiter.IsCompleted || _cancellationToken.IsCancellationRequested;

            public void OnCompleted(Action continuation)
                => _taskAwaiter.OnCompleted(WithCancellationContinuationWrapper.Wrap(continuation, _cancellationToken));

            public void UnsafeOnCompleted(Action continuation)
                => _taskAwaiter.UnsafeOnCompleted(WithCancellationContinuationWrapper.Wrap(continuation, _cancellationToken));

            /// <summary>
            /// Ends the await.
            /// A completed task takes precedence over cancellation: if the task has completed, its exception (if any)
            /// is surfaced even when the token is also signaled.
            /// </summary>
            /// <exception cref="OperationCanceledException">The token was signaled before the task completed.</exception>
            public void GetResult()
            {
                Debug.Assert(IsCompleted);
                if (!_taskAwaiter.IsCompleted)
                {
                    _cancellationToken.ThrowIfCancellationRequested();
                }

                _taskAwaiter.GetResult();
            }
        }

        /// <summary>
        /// Awaiter that reports completion once the wrapped task completes or the cancellation token is signaled.
        /// </summary>
        public readonly struct WithCancellationTaskAwaiter<T> : ICriticalNotifyCompletion
        {
            private readonly CancellationToken _cancellationToken;
            private readonly ConfiguredTaskAwaitable<T>.ConfiguredTaskAwaiter _taskAwaiter;

            public WithCancellationTaskAwaiter(ConfiguredTaskAwaitable<T>.ConfiguredTaskAwaiter awaiter, CancellationToken cancellationToken)
            {
                _taskAwaiter = awaiter;
                _cancellationToken = cancellationToken;
            }

            public bool IsCompleted => _taskAwaiter.IsCompleted || _cancellationToken.IsCancellationRequested;

            public void OnCompleted(Action continuation)
                => _taskAwaiter.OnCompleted(WithCancellationContinuationWrapper.Wrap(continuation, _cancellationToken));

            public void UnsafeOnCompleted(Action continuation)
                => _taskAwaiter.UnsafeOnCompleted(WithCancellationContinuationWrapper.Wrap(continuation, _cancellationToken));

            /// <summary>
            /// Ends the await and returns the task's result.
            /// A completed task takes precedence over cancellation.
            /// </summary>
            /// <exception cref="OperationCanceledException">The token was signaled before the task completed.</exception>
            public T GetResult()
            {
                Debug.Assert(IsCompleted);
                if (!_taskAwaiter.IsCompleted)
                {
                    _cancellationToken.ThrowIfCancellationRequested();
                }

                return _taskAwaiter.GetResult();
            }
        }

        /// <summary>
        /// Awaiter that reports completion once the wrapped value task completes or the cancellation token is signaled.
        /// </summary>
        public readonly struct WithCancellationValueTaskAwaiter<T> : ICriticalNotifyCompletion
        {
            private readonly CancellationToken _cancellationToken;
            private readonly ConfiguredValueTaskAwaitable<T>.ConfiguredValueTaskAwaiter _taskAwaiter;

            public WithCancellationValueTaskAwaiter(ConfiguredValueTaskAwaitable<T>.ConfiguredValueTaskAwaiter awaiter, CancellationToken cancellationToken)
            {
                _taskAwaiter = awaiter;
                _cancellationToken = cancellationToken;
            }

            public bool IsCompleted => _taskAwaiter.IsCompleted || _cancellationToken.IsCancellationRequested;

            public void OnCompleted(Action continuation)
                => _taskAwaiter.OnCompleted(WithCancellationContinuationWrapper.Wrap(continuation, _cancellationToken));

            public void UnsafeOnCompleted(Action continuation)
                => _taskAwaiter.UnsafeOnCompleted(WithCancellationContinuationWrapper.Wrap(continuation, _cancellationToken));

            /// <summary>
            /// Ends the await and returns the value task's result.
            /// A completed value task takes precedence over cancellation.
            /// </summary>
            /// <exception cref="OperationCanceledException">The token was signaled before the value task completed.</exception>
            public T GetResult()
            {
                Debug.Assert(IsCompleted);
                if (!_taskAwaiter.IsCompleted)
                {
                    _cancellationToken.ThrowIfCancellationRequested();
                }

                return _taskAwaiter.GetResult();
            }
        }

        /// <summary>
        /// Runs a continuation exactly once, triggered by whichever comes first: the awaited task completing or
        /// the cancellation token being signaled.
        /// </summary>
        private sealed class WithCancellationContinuationWrapper
        {
            private Action _originalContinuation;
            private readonly CancellationTokenRegistration _registration;

            public WithCancellationContinuationWrapper(Action originalContinuation, CancellationToken cancellationToken)
            {
                Action continuation = ContinuationImplementation;
                _originalContinuation = originalContinuation;
                // If the token is already canceled, Register invokes the callback synchronously, before _registration
                // is assigned. Disposing the still-default registration inside the callback is then a no-op.
                _registration = cancellationToken.Register(continuation);
                Continuation = continuation;
            }

            /// <summary>The delegate to hand to the underlying awaiter; it is also registered on the token.</summary>
            public Action Continuation { get; }

            /// <summary>
            /// Returns <paramref name="continuation"/> unchanged when <paramref name="cancellationToken"/> can never be
            /// canceled. Otherwise returns a delegate that also fires when the token is signaled.
            /// </summary>
            public static Action Wrap(Action continuation, CancellationToken cancellationToken)
                => cancellationToken.CanBeCanceled
                    ? new WithCancellationContinuationWrapper(continuation, cancellationToken).Continuation
                    : continuation;

            private void ContinuationImplementation()
            {
                // Task completion and cancellation both route here. The first caller claims the continuation;
                // the second sees null and does nothing.
                Action originalContinuation = Interlocked.Exchange(ref _originalContinuation, null);
                if (originalContinuation != null)
                {
                    _registration.Dispose();
                    originalContinuation();
                }
            }
        }
    }
}