rd-net/Lifetimes/Collections/JetPriorityQueue.cs (298 lines of code) (raw):

using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using JetBrains.Annotations;
using JetBrains.Diagnostics;
using JetBrains.Lifetimes;
using JetBrains.Util;

namespace JetBrains.Collections
{
  /// <summary>
  /// JetBrains interface of priority queue data structure.
  /// </summary>
  /// <typeparam name="T">Type of the queued elements.</typeparam>
  public interface IPriorityQueue<T> : ICollection<T> , IReadOnlyCollection<T>
  {
    /// <summary>
    /// Number of elements in the queue. Re-declared to resolve the ambiguity between
    /// <see cref="ICollection{T}.Count"/> and <see cref="IReadOnlyCollection{T}.Count"/>.
    /// </summary>
    new int Count { get; }

    /// <summary>
    /// Tries to remove the first element of the queue and return it through <paramref name="res"/>.
    /// </summary>
    /// <returns><c>true</c> if an element was extracted, <c>false</c> otherwise.</returns>
    bool TryExtract(out T? res);

    /// <summary>
    /// Tries to return the first element of the queue through <paramref name="res"/> without removing it.
    /// </summary>
    /// <returns><c>true</c> if an element was available, <c>false</c> otherwise.</returns>
    bool TryPeek(out T? res);
  }

  /// <summary>
  /// JetBrains implementation of priority queue data structure.
  /// </summary>
  /// <remarks>
  /// Elements are kept in a binary heap laid out in a list with 1-based indexing: the children of index
  /// <c>i</c> are <c>2*i</c> and <c>2*i + 1</c>, and index 0 holds an unused placeholder. The smallest element
  /// according to the comparer is at the root. Elements that compare equal are ordered by insertion order
  /// (each added element gets an increasing version number). This class performs no synchronization.
  /// </remarks>
  /// <typeparam name="T">Type of the queued elements.</typeparam>
  public class JetPriorityQueue<T> : IPriorityQueue<T>
  {
    public const int DefaultCapacity = 10;

    // Heap storage; slot 0 is a placeholder, real elements start at index 1.
    private readonly List<T?> myStorage;
    // myVersions[i] is the insertion version of myStorage[i]; used to break ties in Compare.
    private readonly List<long> myVersions;
    private readonly IComparer<T?> myComparer;
    // Monotonically increasing counter; source of insertion versions.
    private long myVersionAcc;

    /// <summary>
    /// Creates an empty queue.
    /// </summary>
    /// <param name="initialCapacity">Initial capacity of the backing lists; non-positive values fall back to <see cref="DefaultCapacity"/>.</param>
    /// <param name="comparer">Comparer defining the priority; <see cref="Comparer{T}.Default"/> is used when <c>null</c>.</param>
    public JetPriorityQueue(int initialCapacity = DefaultCapacity, IComparer<T?>? comparer = null)
    {
      if (initialCapacity <= 0) initialCapacity = DefaultCapacity;

      // The first element is a dummy placeholder: with 1-based indexing `left` and `right` are simply 2*i and 2*i + 1.
      myStorage = new List<T?>(initialCapacity + 1) { default(T) };
      myVersions = new List<long>(initialCapacity + 1) {0};

      myComparer = comparer ?? Comparer<T?>.Default;
    }

    #region ICollection implementation

    /// <summary>
    /// Enumerates the elements in storage (heap) order, which is not the priority order.
    /// </summary>
    public IEnumerator<T> GetEnumerator()
    {
      var enumerator = myStorage.GetEnumerator();
      enumerator.MoveNext(); // skip the placeholder at index 0
      return enumerator;
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
      return GetEnumerator();
    }

    /// <summary>
    /// Adds <paramref name="item"/> to the queue.
    /// </summary>
    public void Add(T item)
    {
      var idx = myStorage.Count;
      myStorage.Add(item);
      myVersions.Add(++myVersionAcc);
      HeapUp(idx);
    }

    /// <summary>
    /// Removes all elements. The version counter is not reset.
    /// </summary>
    public void Clear()
    {
      myStorage.Clear();
      myVersions.Clear();

      // Restore the placeholder at index 0.
      myStorage.Add(default(T));
      myVersions.Add(0);
    }

    /// <summary>
    /// Checks whether <paramref name="item"/> is present in the queue (linear search).
    /// </summary>
    public bool Contains(T item)
    {
      // Search from index 1: slot 0 holds the default(T) placeholder, which would otherwise be found first
      // and hide a real element equal to default(T) (e.g. 0 or null).
      return myStorage.IndexOf(item, 1) >= 0;
    }

    /// <summary>
    /// Copies the elements, in storage order and without the placeholder, to <paramref name="array"/>
    /// starting at <paramref name="arrayIndex"/>.
    /// </summary>
    public void CopyTo(T[] array, int arrayIndex)
    {
      myStorage.CopyTo(1, array, arrayIndex, Count);
    }

    /// <summary>
    /// Removal of an arbitrary element is not supported.
    /// </summary>
    /// <exception cref="InvalidOperationException">Always thrown.</exception>
    public bool Remove(T item)
    {
      throw new InvalidOperationException();
    }

    /// <summary>
    /// Number of elements in the queue (the placeholder at index 0 is not counted).
    /// </summary>
    public int Count => myStorage.Count - 1;

    public bool IsReadOnly => false;

    #endregion

    #region Priority related methods

    /// <summary>
    /// Removes the first element of the queue and returns it through <paramref name="res"/>.
    /// </summary>
    /// <returns><c>false</c> (with <paramref name="res"/> set to default) if the queue is empty, <c>true</c> otherwise.</returns>
    public bool TryExtract(out T? res)
    {
      if (!TryPeek(out res)) return false;

      // Move the last element to the root, drop the last slot, then restore the heap property.
      var last = myStorage.Count - 1;
      myStorage[1] = myStorage[last];
      myVersions[1] = myVersions[last];

      //todo default list implementation calls Array.Copy with zero size that is not optimal behaviour (e.g. use LocalList here)
      myStorage.RemoveAt(last);
      myVersions.RemoveAt(last);

      // When last == 1 the extracted element was the only one and there is nothing to sift.
      if (last > 1) HeapDown(1);
      return true;
    }

    /// <summary>
    /// Returns the first element of the queue through <paramref name="res"/> without removing it.
    /// </summary>
    /// <returns><c>false</c> (with <paramref name="res"/> set to default) if the queue is empty, <c>true</c> otherwise.</returns>
    public bool TryPeek(out T? res)
    {
      if (myStorage.Count <= 1)
      {
        res = default (T);
        return false;
      }

      res = myStorage[1];
      return true;
    }

    #endregion

    #region private Helpers

    /// <summary>
    /// Swaps elements (and their versions) at indices <paramref name="i"/> and <paramref name="j"/>,
    /// then sets <paramref name="i"/> to <paramref name="j"/> so the caller keeps following the moved element.
    /// </summary>
    private void Swap(ref int i, int j)
    {
      var s = myStorage[i];
      myStorage[i] = myStorage[j];
      myStorage[j] = s;

      var v = myVersions[i];
      myVersions[i] = myVersions[j];
      myVersions[j] = v;

      i = j;
    }

    /// <summary>
    /// Compares elements at two storage indices: first with the comparer, then by insertion version
    /// (the element added earlier is considered smaller).
    /// </summary>
    private int Compare(int left, int right)
    {
      if (left == right) return 0;

      var cmp1 = myComparer.Compare(myStorage[left], myStorage[right]);
      if (cmp1 != 0) return cmp1;

      var cmp2 = myVersions[left] - myVersions[right];
      if (Mode.IsAssertion) Assertion.Assert(cmp2 != 0, "Equal versions for indices {0}, {1}, version = {2}", left, right, myVersions[left]);
      return cmp2 > 0 ? 1 : -1;
    }

    /// <summary>
    /// Sifts the element at <paramref name="idx"/> down until it is not greater than both of its children.
    /// </summary>
    private void HeapDown(int idx)
    {
      if (Mode.IsAssertion) Assertion.Assert(idx >= 1 && idx < myStorage.Count, "Index {0} is not in range [1, {1})", idx, myStorage.Count);

      int n = myStorage.Count;
      int left = (idx << 1) | 0;
      while (left < n)
      {
        int nxt;
        if (left == n-1)
          nxt = left; // only the left child exists
        else
        {
          int right = (idx << 1) | 1;
          nxt = Compare(left, right) < 0 ? left : right; // the smaller child
        }

        if (Compare(idx, nxt) <= 0) break;

        Swap(ref idx, nxt);
        left = (idx << 1) | 0;
      }
    }

    /// <summary>
    /// Sifts the element at <paramref name="idx"/> up while it is smaller than its parent.
    /// </summary>
    private void HeapUp(int idx)
    {
      if (Mode.IsAssertion) Assertion.Assert(idx >= 1 && idx < myStorage.Count, "Index {0} is not in range [1, {1})", idx, myStorage.Count);

      while (idx > 1 && Compare(idx, idx >> 1) < 0)
      {
        Swap(ref idx, idx >> 1);
      }
    }

    #endregion
  }

  /// <summary>
  /// Thread-safe implementation of priority queue data structure.
  /// </summary>
  /// <remarks>
  /// Wraps a <see cref="JetPriorityQueue{T}"/> and guards every operation with a private monitor.
  /// When the lifetime passed to the constructor terminates, the queue is cleared and all waiting threads are woken up.
  /// </remarks>
  /// <typeparam name="T">Type of the queued elements.</typeparam>
  public class BlockingPriorityQueue<T> : IPriorityQueue<T>
  {
    private readonly Lifetime myLifetime;
    private readonly JetPriorityQueue<T> myQueue;
    private readonly object mySentry = new object();

    /// <summary>
    /// Creates an empty queue bound to <paramref name="lifetime"/>.
    /// </summary>
    /// <param name="lifetime">Lifetime of the queue; on its termination the queue is cleared and waiters are released.</param>
    /// <param name="initialCapacity">Initial capacity passed to the underlying <see cref="JetPriorityQueue{T}"/>.</param>
    /// <param name="comparer">Comparer passed to the underlying <see cref="JetPriorityQueue{T}"/>.</param>
    public BlockingPriorityQueue(Lifetime lifetime, int initialCapacity = JetPriorityQueue<T>.DefaultCapacity, IComparer<T?>? comparer = null)
    {
      myLifetime = lifetime;
      myQueue = new JetPriorityQueue<T>(initialCapacity, comparer);

      lifetime.OnTermination(() =>
      {
        lock (mySentry)
        {
          Clear();
          Monitor.PulseAll(mySentry); //to wake up all waiters
        }
      });
    }

    /// <summary>
    /// Enumerates a snapshot of the queue taken under the lock.
    /// </summary>
    public IEnumerator<T> GetEnumerator()
    {
      return ((IEnumerable<T>)ToArray(/* required for thread safety */)).GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
      return GetEnumerator();
    }

    /// <summary>
    /// Same as <see cref="Enqueue"/>, ignoring the returned count.
    /// </summary>
    public void Add(T item)
    {
      Enqueue(item);
    }

    public void Clear()
    {
      lock (mySentry)
        myQueue.Clear();
    }

    public bool Contains(T item)
    {
      lock (mySentry)
        return myQueue.Contains(item);
    }

    public void CopyTo(T[] array, int arrayIndex)
    {
      lock (mySentry)
        myQueue.CopyTo(array, arrayIndex);
    }

    /// <summary>
    /// Delegates to <see cref="JetPriorityQueue{T}.Remove"/>, which always throws <see cref="InvalidOperationException"/>.
    /// </summary>
    public bool Remove(T item)
    {
      lock (mySentry)
        return myQueue.Remove(item);
    }

    public int Count
    {
      get
      {
        lock (mySentry)
          return myQueue.Count;
      }
    }

    public bool IsReadOnly
    {
      get
      {
        lock (mySentry)
          return myQueue.IsReadOnly;
      }
    }

    /// <summary>
    /// Non-blocking extraction of the first element.
    /// </summary>
    public bool TryExtract(out T? res)
    {
      lock (mySentry)
        return myQueue.TryExtract(out res);
    }

    /// <summary>
    /// Non-blocking peek of the first element.
    /// </summary>
    public bool TryPeek(out T? res)
    {
      lock (mySentry)
        return myQueue.TryPeek(out res);
    }

    /// <summary>
    /// Extracts the first element, waiting up to <paramref name="intervalMs"/> milliseconds for one to appear.
    /// </summary>
    /// <param name="res">Extracted element, or default if nothing was extracted.</param>
    /// <param name="intervalMs">Maximum time to wait, in milliseconds; <see cref="Timeout.Infinite"/> waits
    /// until an element appears or the lifetime terminates.</param>
    /// <returns><c>true</c> if an element was extracted; <c>false</c> on timeout or when the queue is empty and the lifetime is terminating.</returns>
    [PublicAPI]
    public bool TryExtract(out T? res, int intervalMs)
    {
      lock (mySentry)
      {
        var remainingMs = intervalMs;
        var stopwatch = LocalStopwatch.StartNew();
        while (true)
        {
          if (myQueue.TryExtract(out res)) return true;
          if (myLifetime.Status >= LifetimeStatus.Terminating) return false;

          if (!Monitor.Wait(mySentry, remainingMs)) break; // timed out

          // An infinite interval has no deadline: after a wake-up just re-check the queue.
          if (intervalMs == Timeout.Infinite) continue;

          // Woken up before the deadline: another thread may have taken the item, so wait for the remaining time.
          var elapsed = stopwatch.ElapsedMilliseconds;
          if (elapsed >= intervalMs) break;
          remainingMs = intervalMs - (int)elapsed;
        }

        // Last attempt after the deadline has passed.
        return myQueue.TryExtract(out res);
      }
    }

    /// <summary>
    /// Peeks the first element, waiting up to <paramref name="intervalMs"/> milliseconds for one to appear.
    /// </summary>
    /// <param name="res">First element, or default if the queue stayed empty.</param>
    /// <param name="intervalMs">Maximum time to wait, in milliseconds; <see cref="Timeout.Infinite"/> waits
    /// until an element appears or the lifetime terminates.</param>
    /// <returns><c>true</c> if an element was available; <c>false</c> on timeout or when the queue is empty and the lifetime is terminating.</returns>
    [PublicAPI]
    public bool TryPeek(out T? res, int intervalMs)
    {
      lock (mySentry)
      {
        var remainingMs = intervalMs;
        var stopwatch = LocalStopwatch.StartNew();
        while (true)
        {
          if (myQueue.TryPeek(out res)) return true;
          if (myLifetime.Status >= LifetimeStatus.Terminating) return false;

          if (!Monitor.Wait(mySentry, remainingMs)) break; // timed out

          // An infinite interval has no deadline: after a wake-up just re-check the queue.
          if (intervalMs == Timeout.Infinite) continue;

          // Woken up before the deadline: another thread may have taken the item, so wait for the remaining time.
          var elapsed = stopwatch.ElapsedMilliseconds;
          if (elapsed >= intervalMs) break;
          remainingMs = intervalMs - (int)elapsed;
        }

        // Last attempt after the deadline has passed.
        return myQueue.TryPeek(out res);
      }
    }

    /// <summary>
    /// Returns first element from queue or waits until it appears.
    /// In case of lifetime termination throws <see cref="OperationCanceledException"/>.
    /// </summary>
    /// <returns>First element in queue</returns>
    /// <exception cref="OperationCanceledException">The lifetime of the queue is terminating or terminated.</exception>
    [PublicAPI]
    public T? ExtractOrBlock()
    {
      lock (mySentry)
      {
        while (true)
        {
          if (myLifetime.Status >= LifetimeStatus.Terminating)
            throw new OperationCanceledException();

          if (myQueue.TryExtract(out var res))
            return res;

          //no luck, wait for value
          Monitor.Wait(mySentry);
        }
      }
    }

    /// <summary>
    /// Enqueues an item and returns the total number of items in the queue right after enqueueing, in a thread-safe-consistent manner.
    /// If the lifetime of the queue is terminating or terminated, the item is not added and 0 is returned.
    /// </summary>
    [PublicAPI]
    public int Enqueue(T item)
    {
      lock (mySentry)
      {
        if (myLifetime.Status >= LifetimeStatus.Terminating)
          return 0;

        myQueue.Add(item);
        int count = myQueue.Count;

        // Wake every waiter, not just one: a single pulse may be consumed by a thread waiting in TryPeek,
        // which does not remove the item, leaving a thread waiting to extract asleep although the queue is not empty.
        // All waiters re-check the queue in a loop, so extra wake-ups are harmless.
        Monitor.PulseAll(mySentry);
        return count;
      }
    }

    /// <summary>
    /// Copies data to an array, thread-safely.
    /// </summary>
    [PublicAPI]
    public T[] ToArray()
    {
      lock(mySentry)
        return myQueue.ToArray();
    }
  }

  /// <summary>
  /// Convenience extensions for <see cref="IPriorityQueue{T}"/>.
  /// </summary>
  public static class PriorityQueueEx
  {
    /// <summary>
    /// Same as <see cref="ICollection{T}.Add"/>
    /// </summary>
    /// <param name="queue">Queue to add to.</param>
    /// <param name="val">Value to add.</param>
    /// <typeparam name="T">Type of the queued elements.</typeparam>
    [PublicAPI]
    public static void Enqueue<T>(this IPriorityQueue<T> queue, T val)
    {
      queue.Add(val);
    }

    /// <summary>
    /// Extracts the first element of <paramref name="queue"/>, or returns default if nothing could be extracted.
    /// </summary>
    [PublicAPI]
    public static T? ExtractOrDefault<T>(this IPriorityQueue<T> queue)
    {
      return !queue.TryExtract(out var res) ? default(T) : res;
    }

    /// <summary>
    /// Extracts the first element of <paramref name="queue"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">Nothing could be extracted (<see cref="IPriorityQueue{T}.TryExtract"/> returned <c>false</c>).</exception>
    [PublicAPI]
    public static T? Extract<T>(this IPriorityQueue<T> queue)
    {
      if (!queue.TryExtract(out var res))
      {
        throw new InvalidOperationException("Can't extract min, queue is empty");
      }

      return res;
    }

    /// <summary>
    /// Returns the first element of <paramref name="queue"/> without removing it.
    /// </summary>
    /// <exception cref="InvalidOperationException">Nothing could be peeked (<see cref="IPriorityQueue{T}.TryPeek"/> returned <c>false</c>).</exception>
    [PublicAPI]
    public static T? Peek<T>(this IPriorityQueue<T> queue)
    {
      if (!queue.TryPeek(out var res))
      {
        throw new InvalidOperationException("Can't peek min, queue is empty");
      }

      return res;
    }
  }
}