rd-net/Lifetimes/Collections/Viewable/ViewableConcurrentSet.cs (172 lines of code) (raw):

using System; using System.Collections; using System.Collections.Generic; using JetBrains.Diagnostics; using JetBrains.Lifetimes; using JetBrains.Util.Internal; namespace JetBrains.Collections.Viewable; public class ViewableConcurrentSet<T> : IViewableConcurrentSet<T> where T : notnull { private readonly Signal<VersionedData> mySignal = new(); private Dictionary<T, LifetimeDefinition> myMap; private readonly object myLocker = new(); private int myCount; private int myAddVersion; private int myIsUnderReadingCount; public int Count => Memory.VolatileRead(ref myCount); public ViewableConcurrentSet(IEqualityComparer<T>? comparer = null) { myMap = new Dictionary<T, LifetimeDefinition>(comparer); } public bool Add(T value) { LifetimeDefinition? definition; int version; lock (myLocker) { var map = GetOrCloneMapNoLock(); if (map.TryGetValue(value, out definition) && definition.Lifetime.IsAlive) return false; definition = new LifetimeDefinition(); map[value] = definition; version = ++myAddVersion; myCount++; } mySignal.Fire(new VersionedData(definition.Lifetime, value, version)); return true; } public bool Remove(T value) { LifetimeDefinition? definitionToRemove; lock (myLocker) { var map = GetOrCloneMapNoLock(); if (!map.TryGetValue(value, out definitionToRemove)) return false; } definitionToRemove.Terminate(); lock (myLocker) { var map = GetOrCloneMapNoLock(); if (!map.TryGetValue(value, out var definition) || definition != definitionToRemove) return false; map.Remove(value); myCount--; } return true; } public bool Contains(T value) { return TryGetLifetime(value, out var lifetime) && lifetime.IsAlive; } public bool TryGetLifetime(T value, out Lifetime lifetime) { lock (myLocker) { if (myMap.TryGetValue(value, out var definition)) { lifetime = definition.Lifetime; return true; } lifetime = Lifetime.Terminated; return false; } } private Dictionary<T, LifetimeDefinition> GetOrCloneMapNoLock() { var map = myMap; if (myIsUnderReadingCount > 0) { map = new(map); myIsUnderReadingCount = 0; myMap = map; return map; } return map; } public void View(Lifetime lifetime, Action<Lifetime, T> action) { Dictionary<T, LifetimeDefinition> map; lock (myLocker) { map = myMap; var version = myAddVersion; mySignal.Advise(lifetime, versionedData => { if (versionedData.Version <= version) return; var value = versionedData.Value; var newLifetime = versionedData.Lifetime.Intersect(lifetime); if (newLifetime.IsNotAlive) return; action(newLifetime, value); }); if (map.Count == 0) return; myIsUnderReadingCount++; } foreach (var (value, definition) in map) { var newLifetime = definition.Lifetime.Intersect(lifetime); if (newLifetime.IsNotAlive) continue; try { action(newLifetime, value); } catch (Exception e) { Log.Root.Error(e); } } lock (myLocker) { if (myMap == map) { var count = myIsUnderReadingCount--; Assertion.Assert(count >= 0); } } } private readonly record struct VersionedData(Lifetime Lifetime, T Value, int Version) { public readonly Lifetime Lifetime = Lifetime; public readonly T Value = Value; public readonly int Version = Version; } private readonly struct ReadCookie : IDisposable { private readonly ViewableConcurrentSet<T> mySet; public Dictionary<T, LifetimeDefinition> Map { get; } public ReadCookie(ViewableConcurrentSet<T> set) { mySet = set; lock (set.myLocker) { Map = mySet.myMap; mySet.myIsUnderReadingCount++; } } public void Dispose() { lock (mySet.myLocker) { if (Map == mySet.myMap) { var count = mySet.myIsUnderReadingCount--; Assertion.Assert(count >= 0); } } } } public IEnumerator<T> GetEnumerator() { using var cookie = new ReadCookie(this); foreach (var (key, definition) in cookie.Map) { if (definition.Lifetime.IsAlive) yield return key; } } IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); }