diff --git a/src/Microsoft.VisualStudio.Threading/AsyncLazy`1.cs b/src/Microsoft.VisualStudio.Threading/AsyncLazy`1.cs index 630bc2b6f..50eedf93e 100644 --- a/src/Microsoft.VisualStudio.Threading/AsyncLazy`1.cs +++ b/src/Microsoft.VisualStudio.Threading/AsyncLazy`1.cs @@ -38,7 +38,7 @@ public class AsyncLazy /// /// The unique instance identifier. /// - private readonly AsyncLocal recursiveFactoryCheck = new AsyncLocal(); + private AsyncLocal? recursiveFactoryCheck; /// /// The function to invoke to produce the task. @@ -72,6 +72,31 @@ public AsyncLazy(Func> valueFactory, JoinableTaskFactory? joinableTaskFa this.jobFactory = joinableTaskFactory; } + /// + /// Gets a value indicating whether to suppress detection of a value factory depending on itself. + /// + /// The default value is . + /// + /// + /// A value factory that truly depends on itself (e.g. by calling on the same instance) + /// would deadlock, and by default this class will throw an exception if it detects such a condition. + /// However this detection relies on the .NET ExecutionContext, which can flow to "spin off" contexts that are not awaited + /// by the factory, and thus could legally await the result of the value factory without deadlocking. + /// + /// + /// When this flows improperly, it can cause to be thrown, but only when the value factory + /// has not already been completed, leading to a difficult to reproduce race condition. + /// Such a case can be resolved by calling around the non-awaited fork in , + /// or the entire instance can be configured to suppress this check by setting this property to . + /// + /// + /// When this property is set to , the recursive factory check will not be performed, + /// but will still call into + /// if a was provided to the constructor. + /// + /// + public bool SuppressRecursiveFactoryDetection { get; init; } + /// /// Gets a value indicating whether the value factory has been invoked. /// @@ -137,7 +162,7 @@ public bool IsValueFactoryCompleted /// Thrown after is called. public Task GetValueAsync(CancellationToken cancellationToken) { - if (!((this.value is object && this.value.IsCompleted) || this.recursiveFactoryCheck.Value is null)) + if (this.value is not { IsCompleted: true } && this.recursiveFactoryCheck is { Value: not null }) { // PERF: we check the condition and *then* retrieve the string resource only on failure // because the string retrieval has shown up as significant on ETL traces. @@ -183,7 +208,12 @@ public Task GetValueAsync(CancellationToken cancellationToken) } }; - this.recursiveFactoryCheck.Value = RecursiveCheckSentinel; + if (!this.SuppressRecursiveFactoryDetection) + { + Assumes.Null(this.recursiveFactoryCheck); + this.recursiveFactoryCheck = new AsyncLocal() { Value = RecursiveCheckSentinel }; + } + try { if (this.jobFactory is object) @@ -201,7 +231,10 @@ public Task GetValueAsync(CancellationToken cancellationToken) } finally { - this.recursiveFactoryCheck.Value = null; + if (this.recursiveFactoryCheck is not null) + { + this.recursiveFactoryCheck.Value = null; + } } } } @@ -451,7 +484,11 @@ internal RevertRelevance(AsyncLazy owner) Requires.NotNull(owner, nameof(owner)); this.owner = owner; - (this.oldCheckValue, owner.recursiveFactoryCheck.Value) = (owner.recursiveFactoryCheck.Value, null); + if (owner.recursiveFactoryCheck is not null) + { + (this.oldCheckValue, owner.recursiveFactoryCheck.Value) = (owner.recursiveFactoryCheck.Value, null); + } + this.joinableRelevance = owner.jobFactory?.Context.SuppressRelevance(); } @@ -460,9 +497,9 @@ internal RevertRelevance(AsyncLazy owner) /// public void Dispose() { - if (this.owner is object) + if (this.owner?.recursiveFactoryCheck is { } check) { - this.owner.recursiveFactoryCheck.Value = this.oldCheckValue; + check.Value = this.oldCheckValue; } this.joinableRelevance?.Dispose(); diff --git a/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt index 1a9888196..e099d0429 100644 --- a/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt @@ -13,4 +13,6 @@ Microsoft.VisualStudio.Threading.AsyncLazy.DisposeValueAsync() -> System.Thre Microsoft.VisualStudio.Threading.AsyncLazy.IsValueDisposed.get -> bool Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance.Dispose() -> void +Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.get -> bool +Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.init -> void Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance \ No newline at end of file diff --git a/src/Microsoft.VisualStudio.Threading/net6.0-windows/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/net6.0-windows/PublicAPI.Unshipped.txt index 1a9888196..e099d0429 100644 --- a/src/Microsoft.VisualStudio.Threading/net6.0-windows/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/net6.0-windows/PublicAPI.Unshipped.txt @@ -13,4 +13,6 @@ Microsoft.VisualStudio.Threading.AsyncLazy.DisposeValueAsync() -> System.Thre Microsoft.VisualStudio.Threading.AsyncLazy.IsValueDisposed.get -> bool Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance.Dispose() -> void +Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.get -> bool +Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.init -> void Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance \ No newline at end of file diff --git a/src/Microsoft.VisualStudio.Threading/net6.0/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/net6.0/PublicAPI.Unshipped.txt index 1a9888196..e099d0429 100644 --- a/src/Microsoft.VisualStudio.Threading/net6.0/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/net6.0/PublicAPI.Unshipped.txt @@ -13,4 +13,6 @@ Microsoft.VisualStudio.Threading.AsyncLazy.DisposeValueAsync() -> System.Thre Microsoft.VisualStudio.Threading.AsyncLazy.IsValueDisposed.get -> bool Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance.Dispose() -> void +Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.get -> bool +Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.init -> void Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance \ No newline at end of file diff --git a/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt index 1a9888196..e099d0429 100644 --- a/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt @@ -13,4 +13,6 @@ Microsoft.VisualStudio.Threading.AsyncLazy.DisposeValueAsync() -> System.Thre Microsoft.VisualStudio.Threading.AsyncLazy.IsValueDisposed.get -> bool Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance.Dispose() -> void +Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.get -> bool +Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.init -> void Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance \ No newline at end of file diff --git a/test/Microsoft.VisualStudio.Threading.Tests/AsyncLazyTests.cs b/test/Microsoft.VisualStudio.Threading.Tests/AsyncLazyTests.cs index 2077d81c3..5fac9f95c 100644 --- a/test/Microsoft.VisualStudio.Threading.Tests/AsyncLazyTests.cs +++ b/test/Microsoft.VisualStudio.Threading.Tests/AsyncLazyTests.cs @@ -6,10 +6,13 @@ using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; + using Microsoft; using Microsoft.VisualStudio.Threading; + using Xunit; using Xunit.Abstractions; + using NamedSyncContext = AwaitExtensionsTests.NamedSyncContext; public class AsyncLazyTests : TestBase @@ -748,6 +751,105 @@ async Task FireAndForgetCodeAsync() } } + [Fact] + public async Task SuppressRecursiveFactoryDetection_WithoutJTF() + { + AsyncManualResetEvent allowValueFactoryToFinish = new(); + Task? fireAndForgetTask = null; + AsyncLazy asyncLazy = null!; + asyncLazy = new AsyncLazy( + async delegate + { + fireAndForgetTask = FireAndForgetCodeAsync(); + await allowValueFactoryToFinish; + return 1; + }, + null) + { + SuppressRecursiveFactoryDetection = true, + }; + + bool fireAndForgetCodeAsyncEntered = false; + Task lazyValue = asyncLazy.GetValueAsync(); + Assert.True(fireAndForgetCodeAsyncEntered); + allowValueFactoryToFinish.Set(); + + // Assert that the value factory was allowed to finish. + Assert.Equal(1, await lazyValue.WithCancellation(this.TimeoutToken)); + + // Assert that the fire-and-forget task was allowed to finish and did so without throwing. + Assert.Equal(1, await fireAndForgetTask!.WithCancellation(this.TimeoutToken)); + + async Task FireAndForgetCodeAsync() + { + fireAndForgetCodeAsyncEntered = true; + return await asyncLazy.GetValueAsync(); + } + } + + [Theory, PairwiseData] + public async Task SuppressRecursiveFactoryDetection_WithJTF(bool suppressWithJTF) + { + JoinableTaskContext? context = this.InitializeJTCAndSC(); + SingleThreadedTestSynchronizationContext.IFrame frame = SingleThreadedTestSynchronizationContext.NewFrame(); + + JoinableTaskFactory? jtf = context.Factory; + AsyncManualResetEvent allowValueFactoryToFinish = new(); + Task? fireAndForgetTask = null; + AsyncLazy asyncLazy = null!; + asyncLazy = new AsyncLazy( + async delegate + { + using (suppressWithJTF ? jtf.Context.SuppressRelevance() : default) + using (suppressWithJTF ? default : asyncLazy.SuppressRelevance()) + { + fireAndForgetTask = FireAndForgetCodeAsync(); + } + + await allowValueFactoryToFinish; + return 1; + }, + jtf) + { + SuppressRecursiveFactoryDetection = true, + }; + + bool fireAndForgetCodeAsyncEntered = false; + bool fireAndForgetCodeAsyncReachedUIThread = false; + jtf.Run(async delegate + { + Task lazyValue = asyncLazy.GetValueAsync(); + Assert.True(fireAndForgetCodeAsyncEntered); + await Task.Delay(AsyncDelay); + Assert.False(fireAndForgetCodeAsyncReachedUIThread); + allowValueFactoryToFinish.Set(); + + // Assert that the value factory was allowed to finish. + Assert.Equal(1, await lazyValue.WithCancellation(this.TimeoutToken)); + }); + + // Run a main thread pump so the fire-and-forget task can finish. + SingleThreadedTestSynchronizationContext.PushFrame(SynchronizationContext.Current!, frame); + + // Assert that the fire-and-forget task was allowed to finish and did so without throwing. + Assert.Equal(1, await fireAndForgetTask!.WithCancellation(this.TimeoutToken)); + + async Task FireAndForgetCodeAsync() + { + fireAndForgetCodeAsyncEntered = true; + + // Yield the caller's thread. + // Resuming will require the main thread, since the caller was on the main thread. + await Task.Yield(); + + fireAndForgetCodeAsyncReachedUIThread = true; + + int result = await asyncLazy.GetValueAsync(); + frame.Continue = false; + return result; + } + } + [Fact] public async Task Dispose_ValueType_Completed() {