diff --git a/src/libraries/System.Private.CoreLib/src/System/Threading/SemaphoreSlim.cs b/src/libraries/System.Private.CoreLib/src/System/Threading/SemaphoreSlim.cs index a530c22b051d20..6238320c2d03b1 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Threading/SemaphoreSlim.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Threading/SemaphoreSlim.cs @@ -40,7 +40,9 @@ public class SemaphoreSlim : IDisposable // The number of synchronously waiting threads, it is set to zero in the constructor and increments before blocking the // threading and decrements it back after that. It is used as flag for the release call to know if there are // waiting threads in the monitor or not. - private int m_waitCount; + // Volatile so the lock-free WaitAsync fast path observes the increment via release/acquire pairing rather than + // depending on the lock release of the writer (which the fast path bypasses). + private volatile int m_waitCount; /// /// This is used to help prevent waking more waiters than necessary. It's not perfect and sometimes more waiters than @@ -57,7 +59,9 @@ public class SemaphoreSlim : IDisposable private volatile ManualResetEvent? m_waitHandle; // Head of list representing asynchronous waits on the semaphore. - private TaskNode? m_asyncHead; + // Volatile for the same reason as m_waitCount: the lock-free WaitAsync fast path reads it without the lock + // and must see writes published by the lock-holding enqueue/dequeue paths. + private volatile TaskNode? m_asyncHead; // Tail of list representing asynchronous waits on the semaphore. private TaskNode? 
m_asyncTail; @@ -106,9 +110,23 @@ public WaitHandle AvailableWaitHandle // lock the count to avoid multiple threads initializing the handle if it is null lock (m_lockObjAndDisposed) { - // The initial state for the wait handle is true if the count is greater than zero - // false otherwise - m_waitHandle ??= new ManualResetEvent(m_currentCount != 0); + if (m_waitHandle is null) + { + // Publish the handle in the unsignaled state first, then reflect the current count. + // Once m_waitHandle is non-null, the lock-free WaitAsync fast path is excluded (it gates + // on m_waitHandle being null). The barrier prevents the m_currentCount read from being + // reordered before the publish on weakly-ordered architectures: without it, a concurrent + // fast-path CAS that already happened could be missed here, leaving the handle Set when + // count == 0. Any fast path that completes between the publish and the count read is + // covered by its own post-CAS recovery branch. + var handle = new ManualResetEvent(false); + m_waitHandle = handle; + Interlocked.MemoryBarrier(); + if (m_currentCount > 0) + { + handle.Set(); + } + } } } @@ -373,42 +391,49 @@ private bool WaitCore(long millisecondsTimeout, CancellationToken cancellationTo // There are no async waiters, so we can proceed with normal synchronous waiting. else { - // If the count > 0 we are good to move on. - // If not, then wait if we were given allowed some wait duration - - OperationCanceledException? oce = null; - - if (m_currentCount == 0) + // Loop to handle the case where the lock-free WaitAsync fast path raced and decremented the + // count between our wait/check and TryDecrementCount. With m_waitCount visibly > 0 the fast + // path defers, so the loop typically runs once; the residual race during m_waitCount's + // publication makes the retry necessary for correctness. + while (true) { - if (millisecondsTimeout == 0) + OperationCanceledException? 
oce = null; + bool timedOut = false; + if (m_currentCount == 0) { - return false; + if (millisecondsTimeout == 0) + { + return false; + } + + // Prepare for the main wait... + // wait until the count becomes greater than zero or the timeout is expired + try + { + timedOut = !WaitUntilCountOrTimeout(millisecondsTimeout, startTime, cancellationToken); + } + catch (OperationCanceledException e) { oce = e; } } - // Prepare for the main wait... - // wait until the count become greater than zero or the timeout is expired - try + // Now try to acquire. We prioritize acquisition over cancellation/timeout so that we don't + // lose any counts when there are asynchronous waiters in the mix. Asynchronous waiters + // defer to synchronous waiters in priority, which means that if it's possible an asynchronous + // waiter didn't get released because a synchronous waiter was present, we need to ensure + // that synchronous waiter succeeds so that they have a chance to release. + if (TryDecrementCount() > 0) { - waitSuccessful = WaitUntilCountOrTimeout(millisecondsTimeout, startTime, cancellationToken); + waitSuccessful = true; + break; } - catch (OperationCanceledException e) { oce = e; } - } - // Now try to acquire. We prioritize acquisition over cancellation/timeout so that we don't - // lose any counts when there are asynchronous waiters in the mix. Asynchronous waiters - // defer to synchronous waiters in priority, which means that if it's possible an asynchronous - // waiter didn't get released because a synchronous waiter was present, we need to ensure - // that synchronous waiter succeeds so that they have a chance to release. 
- Debug.Assert(!waitSuccessful || m_currentCount > 0, - "If the wait was successful, there should be count available."); - if (m_currentCount > 0) - { - waitSuccessful = true; - m_currentCount--; - } - else if (oce is not null) - { - throw oce; + if (oce is not null) + { + throw oce; + } + if (timedOut) + { + break; + } } // Exposing wait handle which is lazily initialized if needed @@ -678,12 +703,37 @@ private Task WaitAsyncCore(long millisecondsTimeout, CancellationToken can if (cancellationToken.IsCancellationRequested) return Task.FromCanceled(cancellationToken); + // Fast path: try a lock-free acquire; falls through to the lock if it fails. + // Skipped when m_waitHandle is non-null to keep its state consistent under the lock. + if (m_waitHandle is null) + { + int current = m_currentCount; + // Best-effort waiter checks (m_asyncHead and m_waitCount are volatile, so plain reads + // are acquire-ordered): they may be updated after this read, but the CAS will fail if + // m_currentCount was concurrently decremented. + if (current > 0 + && m_asyncHead is null + && m_waitCount == 0 + && Interlocked.CompareExchange(ref m_currentCount, current - 1, current) == current) + { + // Handle the rare race where AvailableWaitHandle was initialized concurrently. + if (current == 1 && m_waitHandle is not null) + { + lock (m_lockObjAndDisposed) + { + if (m_waitHandle is not null && m_currentCount == 0) + m_waitHandle.Reset(); + } + } + return Task.FromResult(true); + } + } + lock (m_lockObjAndDisposed) { // If there are counts available, allow this waiter to succeed. 
- if (m_currentCount > 0) + if (TryDecrementCount() > 0) { - --m_currentCount; if (m_waitHandle is not null && m_currentCount == 0) m_waitHandle.Reset(); return Task.FromResult(true); } @@ -759,6 +809,21 @@ private bool RemoveAsyncWaiter(TaskNode task) return wasInList; } + /// <summary> + /// Atomically decrements <see cref="m_currentCount"/> if it is positive, using a CAS loop + /// rather than a plain decrement because the lock-free fast path in + /// <see cref="WaitAsyncCore"/> can decrement concurrently without holding the lock. + /// </summary> + /// <returns>The pre-decrement value. A return value of 0 means no count was available.</returns> + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private int TryDecrementCount() + { + int count = m_currentCount; + while (count > 0 && Interlocked.CompareExchange(ref m_currentCount, count - 1, count) != count) + count = m_currentCount; + return count; + } + /// Performs the asynchronous wait. /// The asynchronous waiter. /// The timeout. @@ -842,21 +907,36 @@ public int Release(int releaseCount) lock (m_lockObjAndDisposed) { - // Read the m_currentCount into a local variable to avoid unnecessary volatile accesses inside the lock. - int currentCount = m_currentCount; - returnCount = currentCount; - - // If the release count would result exceeding the maximum count, throw SemaphoreFullException. - if (m_maxCount - currentCount < releaseCount) + // Snapshot the live count. A lock-free WaitAsync fast path can decrement m_currentCount + // concurrently (it bypasses this lock); nothing increments it concurrently (every increment + // path holds this lock). So the real count can only be <= this snapshot until we update it. + int observed = m_currentCount; + + // Validate against m_maxCount. Re-read on a mismatch so a racing fast-path decrement (which + // only lowers the real count) can't make us throw SemaphoreFullException spuriously off a + // stale, too-high snapshot. 
Because only decrements race, once observed + releaseCount <= + // m_maxCount holds for a snapshot >= the real count, the real count + releaseCount can't + // exceed m_maxCount either, so the bound we enforce on the atomic add below still holds. + while (m_maxCount - observed < releaseCount) { - throw new SemaphoreFullException(); + int reread = m_currentCount; + if (reread == observed) + { + throw new SemaphoreFullException(); + } + observed = reread; } + returnCount = observed; - // Increment the count by the actual release count - currentCount += releaseCount; + // Compute the post-release count in a LOCAL only. We must never store this inflated value into + // m_currentCount: it includes permits earmarked for the waiters released below, and the + // lock-free fast path (which reads m_currentCount without the lock) would observe and steal + // them in the window before we corrected the count. Instead we apply only the net delta once, + // atomically, at the end. Whenever waiters are present the count is 0 and no fast path can be + // racing (it requires count > 0 and no waiters), so this snapshot is stable here. + int currentCount = observed + releaseCount; - // Signal to any synchronous waiters, taking into account how many waiters have previously been pulsed to wake - // but have not yet woken + // Signal synchronous waiters, accounting for those already pulsed but not yet woken. int waitCount = m_waitCount; Debug.Assert(m_countOfWaitersPulsedToWake <= waitCount); int waitersToNotify = Math.Min(currentCount, waitCount) - m_countOfWaitersPulsedToWake; @@ -884,31 +964,38 @@ public int Release(int releaseCount) // asynchronous waiters, we assume that all synchronous waiters will eventually // acquire the semaphore. That could be a faulty assumption if those synchronous // waits are canceled, but the wait code path will handle that. 
+ // Permits handed to async waiters go straight to their tasks rather than into + // m_currentCount, so they're excluded from the net delta applied below. + int asyncReleased = 0; if (m_asyncHead is not null) { Debug.Assert(m_asyncTail is not null, "tail should not be null if head isn't null"); int maxAsyncToRelease = currentCount - waitCount; - while (maxAsyncToRelease > 0 && m_asyncHead is not null) + while (asyncReleased < maxAsyncToRelease && m_asyncHead is not null) { - --currentCount; - --maxAsyncToRelease; + ++asyncReleased; - // Get the next async waiter to release and queue it to be completed TaskNode waiterTask = m_asyncHead; RemoveAsyncWaiter(waiterTask); // ensures waiterTask.Next/Prev are null waiterTask.TrySetResult(result: true); } + currentCount -= asyncReleased; } - m_currentCount = currentCount; - // Exposing wait handle if it is not null - if (m_waitHandle is not null && returnCount == 0 && currentCount > 0) + // Apply the net change (permits released minus those handed straight to async waiters) in a + // single atomic add. A relative add (not an absolute store) folds in any fast-path decrements + // that raced since we snapshotted, and we never publish a count above the number of genuinely + // free permits, so the fast path can never observe a permit reserved for a waiter. The + // pre-validated snapshot bounds the result at or below m_maxCount. + int delta = releaseCount - asyncReleased; + int newCount = delta != 0 ? 
Interlocked.Add(ref m_currentCount, delta) : observed; + + if (m_waitHandle is not null && observed == 0 && newCount > 0) { m_waitHandle.Set(); } } - // And return the count return returnCount; } diff --git a/src/libraries/System.Threading/tests/SemaphoreSlimTests.cs b/src/libraries/System.Threading/tests/SemaphoreSlimTests.cs index 194dbbebd8df8a..d0d6150b748539 100644 --- a/src/libraries/System.Threading/tests/SemaphoreSlimTests.cs +++ b/src/libraries/System.Threading/tests/SemaphoreSlimTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics; +using System.Linq; using System.Runtime.CompilerServices; using System.Threading.Tasks; using Microsoft.DotNet.RemoteExecutor; @@ -618,6 +619,194 @@ public static void TestConcurrentWaitAndWaitAsync(int syncWaiters, int asyncWait Task.WaitAll(tasks); } + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsMultithreadingSupported))] + public static async Task WaitAsync_AvailableWaitHandle_ConcurrentInit_StaysConsistent() + { + // The race only fires during the first AvailableWaitHandle access on a given semaphore + // (after init, m_waitHandle is non-null and the WaitAsync fast path is excluded). Use a + // fresh semaphore per iteration so each iteration is a real attempt at the race. + const int Iterations = 1_000; + + for (int i = 0; i < Iterations; i++) + { + var sem = new SemaphoreSlim(1, 1); + Task accessor = Task.Run(() => sem.AvailableWaitHandle); + + await sem.WaitAsync(); + await accessor; + + // Count is 0; the handle must not be signaled. 
+ Assert.False(sem.AvailableWaitHandle.WaitOne(0)); + + sem.Release(); + sem.Dispose(); + } + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsMultithreadingSupported))] + public static async Task WaitAsync_ConcurrentFastPath_NeverUnderflows() + { + const int Threads = 16, Iterations = 1_000; + using var sem = new SemaphoreSlim(1, 1); + + await Task.WhenAll(Enumerable.Range(0, Threads).Select(_ => Task.Run(async () => + { + for (int i = 0; i < Iterations; i++) + { + if (await sem.WaitAsync(0)) + sem.Release(); + } + }))); + + // Every successful WaitAsync(0) is paired with a Release; final count must be exactly 1. + Assert.Equal(1, sem.CurrentCount); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsMultithreadingSupported))] + public static async Task WaitCore_ConcurrentWithWaitAsyncFastPath_CountNeverCorrupted() + { + // Stresses the CAS loop in WaitCore, which exists because the lock-free WaitAsync + // fast path can decrement m_currentCount without holding the lock. 
+ const int Workers = 8, Iterations = 500; + using var sem = new SemaphoreSlim(1, 1); + var tasks = new Task[Workers * 2]; + + for (int i = 0; i < Workers; i++) + { + tasks[i] = Task.Run(() => + { + for (int j = 0; j < Iterations; j++) + { + sem.Wait(); + sem.Release(); + } + }); + } + for (int i = Workers; i < Workers * 2; i++) + { + tasks[i] = Task.Run(async () => + { + for (int j = 0; j < Iterations; j++) + { + await sem.WaitAsync(); + sem.Release(); + } + }); + } + + await Task.WhenAll(tasks); + Assert.Equal(1, sem.CurrentCount); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsMultithreadingSupported))] + public static async Task Release_BulkRelease_ConcurrentWithFastPath_CountStaysCorrect() + { + const int Workers = 8, Iterations = 500; + const int TotalPermits = Workers * Iterations; + using var sem = new SemaphoreSlim(0); + + Task[] consumers = Enumerable.Range(0, Workers).Select(_ => Task.Run(async () => + { + for (int i = 0; i < Iterations; i++) + { + await sem.WaitAsync(); + } + })).ToArray(); + + Task producer = Task.Run(() => + { + int remaining = TotalPermits; + while (remaining >= 2) + { + sem.Release(2); + remaining -= 2; + } + if (remaining > 0) + { + sem.Release(); + } + }); + + await Task.WhenAll(consumers); + await producer; + Assert.Equal(0, sem.CurrentCount); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsMultithreadingSupported))] + public static async Task Release_AsyncWaiterHandoff_RacingFastPath_NeverExceedsMaxCount() + { + // Regression test for the lock-free fast path racing Release's async-waiter handoff. + // Release hands permits to async waiters directly, removing each from the waiter list. 
If the + // released count is briefly published into m_currentCount before being reconciled, the lock-free + // WaitAsync fast path (which reads m_currentCount without the lock) can steal a permit already + // earmarked for a waiter in that window, double-acquiring and corrupting the count until a later + // Release overshoots and throws SemaphoreFullException. Every acquire here is paired with a + // Release, so the single permit must round-trip cleanly with no exception and a final count of 1. + const int Workers = 8, Iterations = 2_000; + using var sem = new SemaphoreSlim(1, 1); + + // Blocking acquires: when several land while the permit is held they queue as async waiters, + // exercising Release's direct async-waiter handoff path. + Task[] waiters = Enumerable.Range(0, Workers).Select(_ => Task.Run(async () => + { + for (int i = 0; i < Iterations; i++) + { + await sem.WaitAsync(); + sem.Release(); + } + })).ToArray(); + + // Non-blocking acquires hammer the fast path, racing the handoff window above. + Task[] pollers = Enumerable.Range(0, Workers).Select(_ => Task.Run(async () => + { + for (int i = 0; i < Iterations; i++) + { + if (await sem.WaitAsync(0)) + sem.Release(); + } + })).ToArray(); + + // A SemaphoreFullException from any worker propagates through Task.WhenAll and fails the test. + await Task.WhenAll(waiters.Concat(pollers)); + Assert.Equal(1, sem.CurrentCount); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsMultithreadingSupported))] + public static async Task WaitAsync_CancellationRacingAcquire_NoCountCorruption() + { + // Race cancellation against concurrent WaitAsync acquisitions on a shared semaphore. Multiple + // workers contend for a single permit while each token is cancelled from a separate thread, so + // the cancellation genuinely races the wait instead of always preceding it (which, uncontended, + // would let the fast path complete synchronously before the cancel and never exercise the race). 
+ // Every wait either acquires the permit (and releases it) or is canceled, so the count must + // never be corrupted regardless of which outcome wins the race. + const int Workers = 8, Iterations = 1_000; + using var sem = new SemaphoreSlim(1, 1); + + await Task.WhenAll(Enumerable.Range(0, Workers).Select(_ => Task.Run(async () => + { + for (int i = 0; i < Iterations; i++) + { + var cts = new CancellationTokenSource(); + + // Cancel from the thread pool (no per-iteration Task allocation) so it races the + // WaitAsync below rather than always preceding it. + ThreadPool.UnsafeQueueUserWorkItem(static s => ((CancellationTokenSource)s).Cancel(), cts); + + try + { + // Completes (acquires the permit) on success, throws on cancellation. + await sem.WaitAsync(cts.Token); + sem.Release(); + } + catch (OperationCanceledException) { } + } + }))); + + // No acquire was leaked or double-counted, so the single permit is back. + Assert.Equal(1, sem.CurrentCount); + } + [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] public void WaitAsync_Timeout_NoUnhandledException() {