diff --git a/src/Microsoft.VisualStudio.Threading/NoMessagePumpSyncContext.cs b/src/Microsoft.VisualStudio.Threading/NoMessagePumpSyncContext.cs index fa92b754c..8bd72aab3 100644 --- a/src/Microsoft.VisualStudio.Threading/NoMessagePumpSyncContext.cs +++ b/src/Microsoft.VisualStudio.Threading/NoMessagePumpSyncContext.cs @@ -20,15 +20,33 @@ public class NoMessagePumpSyncContext : SynchronizationContext /// private static readonly SynchronizationContext DefaultInstance = new NoMessagePumpSyncContext(); + private readonly SynchronizationContext? underlyingSyncContext; + /// /// Initializes a new instance of the class. /// + /// + /// When using this constructor, uses the default + /// behavior and schedules work on the thread pool, while uses the default + /// behavior and invokes the callback synchronously on the calling thread. + /// public NoMessagePumpSyncContext() { // This is required so that our override of Wait is invoked. this.SetWaitNotificationRequired(); } + /// + /// Initializes a new instance of the class. + /// + /// The that should handle calls to and . + public NoMessagePumpSyncContext(SynchronizationContext underlyingSyncContext) + : this() + { + Requires.NotNull(underlyingSyncContext, nameof(underlyingSyncContext)); + this.underlyingSyncContext = underlyingSyncContext; + } + /// /// Gets a shared instance of this class. /// @@ -37,6 +55,36 @@ public static SynchronizationContext Default get { return DefaultInstance; } } + /// + public override void Send(SendOrPostCallback d, object? state) + { + Requires.NotNull(d, nameof(d)); + + if (this.underlyingSyncContext is { } underlying) + { + underlying.Send(d, state); + } + else + { + base.Send(d, state); + } + } + + /// + public override void Post(SendOrPostCallback d, object? state) + { + Requires.NotNull(d, nameof(d)); + + if (this.underlyingSyncContext is { } underlying) + { + underlying.Post(d, state); + } + else + { + base.Post(d, state); + } + } + /// /// Synchronously blocks without a message pump. /// diff --git a/test/Microsoft.VisualStudio.Threading.Tests/NoMessagePumpSyncContextTests.cs b/test/Microsoft.VisualStudio.Threading.Tests/NoMessagePumpSyncContextTests.cs index 1c95be9d3..245006e50 100644 --- a/test/Microsoft.VisualStudio.Threading.Tests/NoMessagePumpSyncContextTests.cs +++ b/test/Microsoft.VisualStudio.Threading.Tests/NoMessagePumpSyncContextTests.cs @@ -3,6 +3,7 @@ using System; using System.Threading; +using System.Threading.Tasks; /// /// Tests for . @@ -36,6 +37,109 @@ public void Default_IsNoMessagePumpSyncContext() Assert.IsType(NoMessagePumpSyncContext.Default); } + /// + /// Verifies that schedules work on the thread pool + /// when no underlying sync context is provided. + /// + [Fact] + public async Task Post_DefaultConstructor_ExecutesOnThreadPool() + { + NoMessagePumpSyncContext sc = new(); + TaskCompletionSource tcs = new(); + sc.Post(_ => tcs.SetResult(Thread.CurrentThread.IsThreadPoolThread), null); + Assert.True(await tcs.Task.WithCancellation(this.TimeoutToken)); + } + + /// + /// Verifies that executes work synchronously + /// on the calling thread when no underlying sync context is provided. + /// + [Fact] + public void Send_DefaultConstructor_ExecutesInlineOnCallingThread() + { + NoMessagePumpSyncContext sc = new(); + int callingThreadId = Thread.CurrentThread.ManagedThreadId; + int? callbackThreadId = null; + bool callbackInvoked = false; + + sc.Send( + _ => + { + callbackInvoked = true; + callbackThreadId = Thread.CurrentThread.ManagedThreadId; + }, + null); + + Assert.True(callbackInvoked); + Assert.Equal(callingThreadId, callbackThreadId); + } + + /// + /// Verifies that throws + /// when a null underlying context is passed. + /// + [Fact] + public void Constructor_WithNullUnderlyingContext_Throws() + { + Assert.Throws(() => new NoMessagePumpSyncContext(null!)); + } + + /// + /// Verifies that rejects a null callback before + /// delegating to the underlying sync context. + /// + [Fact] + public void Post_WithNullCallback_Throws() + { + ThrowingSyncContext underlying = new(); + NoMessagePumpSyncContext sc = new(underlying); + + Assert.Throws(() => sc.Post(null!, null)); + Assert.False(underlying.PostInvoked); + } + + /// + /// Verifies that rejects a null callback before + /// delegating to the underlying sync context. + /// + [Fact] + public void Send_WithNullCallback_Throws() + { + ThrowingSyncContext underlying = new(); + NoMessagePumpSyncContext sc = new(underlying); + + Assert.Throws(() => sc.Send(null!, null)); + Assert.False(underlying.SendInvoked); + } + + /// + /// Verifies that delegates to the underlying + /// sync context when one is provided. + /// + [Fact] + public async Task Post_WithUnderlyingContext_DelegatesToUnderlying() + { + TaskCompletionSource tcs = new(); + RecordingPostSyncContext underlying = new(posted: _ => tcs.SetResult(true)); + NoMessagePumpSyncContext sc = new(underlying); + sc.Post(_ => { }, null); + Assert.True(await tcs.Task.WithCancellation(this.TimeoutToken)); + } + + /// + /// Verifies that delegates to the underlying + /// sync context when one is provided. + /// + [Fact] + public void Send_WithUnderlyingContext_DelegatesToUnderlying() + { + bool sendInvoked = false; + RecordingSendSyncContext underlying = new(sent: _ => sendInvoked = true); + NoMessagePumpSyncContext sc = new(underlying); + sc.Send(_ => { }, null); + Assert.True(sendInvoked); + } + #if NETFRAMEWORK /// /// Establishes the baseline: on a plain STA thread without a special synchronization context, @@ -77,4 +181,50 @@ public void Wait_BlocksComRpcCalls() } } #endif + + /// + /// A that invokes a callback when is called. + /// + private class RecordingPostSyncContext(Action posted) : SynchronizationContext + { + public override void Post(SendOrPostCallback d, object? state) + { + posted(d); + base.Post(d, state); + } + } + + /// + /// A that invokes a callback when is called. + /// + private class RecordingSendSyncContext(Action sent) : SynchronizationContext + { + public override void Send(SendOrPostCallback d, object? state) + { + sent(d); + base.Send(d, state); + } + } + + /// + /// A that records whether or were invoked. + /// + private class ThrowingSyncContext : SynchronizationContext + { + public bool PostInvoked { get; private set; } + + public bool SendInvoked { get; private set; } + + public override void Post(SendOrPostCallback d, object? state) + { + this.PostInvoked = true; + throw new InvalidOperationException(); + } + + public override void Send(SendOrPostCallback d, object? state) + { + this.SendInvoked = true; + throw new InvalidOperationException(); + } + } }