diff --git a/src/Microsoft.VisualStudio.Threading/NoMessagePumpSyncContext.cs b/src/Microsoft.VisualStudio.Threading/NoMessagePumpSyncContext.cs
index fa92b754c..8bd72aab3 100644
--- a/src/Microsoft.VisualStudio.Threading/NoMessagePumpSyncContext.cs
+++ b/src/Microsoft.VisualStudio.Threading/NoMessagePumpSyncContext.cs
@@ -20,15 +20,33 @@ public class NoMessagePumpSyncContext : SynchronizationContext
///
private static readonly SynchronizationContext DefaultInstance = new NoMessagePumpSyncContext();
+ private readonly SynchronizationContext? underlyingSyncContext;
+
///
/// Initializes a new instance of the class.
///
+ ///
+ /// When using this constructor, uses the default
+ /// behavior and schedules work on the thread pool, while uses the default
+ /// behavior and invokes the callback synchronously on the calling thread.
+ ///
public NoMessagePumpSyncContext()
{
// This is required so that our override of Wait is invoked.
this.SetWaitNotificationRequired();
}
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The that should handle calls to and .
+ public NoMessagePumpSyncContext(SynchronizationContext underlyingSyncContext)
+ : this()
+ {
+ Requires.NotNull(underlyingSyncContext, nameof(underlyingSyncContext));
+ this.underlyingSyncContext = underlyingSyncContext;
+ }
+
///
/// Gets a shared instance of this class.
///
@@ -37,6 +55,36 @@ public static SynchronizationContext Default
get { return DefaultInstance; }
}
+ ///
+ public override void Send(SendOrPostCallback d, object? state)
+ {
+ Requires.NotNull(d, nameof(d));
+
+ if (this.underlyingSyncContext is { } underlying)
+ {
+ underlying.Send(d, state);
+ }
+ else
+ {
+ base.Send(d, state);
+ }
+ }
+
+ ///
+ public override void Post(SendOrPostCallback d, object? state)
+ {
+ Requires.NotNull(d, nameof(d));
+
+ if (this.underlyingSyncContext is { } underlying)
+ {
+ underlying.Post(d, state);
+ }
+ else
+ {
+ base.Post(d, state);
+ }
+ }
+
///
/// Synchronously blocks without a message pump.
///
diff --git a/test/Microsoft.VisualStudio.Threading.Tests/NoMessagePumpSyncContextTests.cs b/test/Microsoft.VisualStudio.Threading.Tests/NoMessagePumpSyncContextTests.cs
index 1c95be9d3..245006e50 100644
--- a/test/Microsoft.VisualStudio.Threading.Tests/NoMessagePumpSyncContextTests.cs
+++ b/test/Microsoft.VisualStudio.Threading.Tests/NoMessagePumpSyncContextTests.cs
@@ -3,6 +3,7 @@
using System;
using System.Threading;
+using System.Threading.Tasks;
///
/// Tests for .
@@ -36,6 +37,109 @@ public void Default_IsNoMessagePumpSyncContext()
Assert.IsType(NoMessagePumpSyncContext.Default);
}
+ ///
+ /// Verifies that schedules work on the thread pool
+ /// when no underlying sync context is provided.
+ ///
+ [Fact]
+ public async Task Post_DefaultConstructor_ExecutesOnThreadPool()
+ {
+ NoMessagePumpSyncContext sc = new();
+ TaskCompletionSource tcs = new();
+ sc.Post(_ => tcs.SetResult(Thread.CurrentThread.IsThreadPoolThread), null);
+ Assert.True(await tcs.Task.WithCancellation(this.TimeoutToken));
+ }
+
+ ///
+ /// Verifies that executes work synchronously
+ /// on the calling thread when no underlying sync context is provided.
+ ///
+ [Fact]
+ public void Send_DefaultConstructor_ExecutesInlineOnCallingThread()
+ {
+ NoMessagePumpSyncContext sc = new();
+ int callingThreadId = Thread.CurrentThread.ManagedThreadId;
+ int? callbackThreadId = null;
+ bool callbackInvoked = false;
+
+ sc.Send(
+ _ =>
+ {
+ callbackInvoked = true;
+ callbackThreadId = Thread.CurrentThread.ManagedThreadId;
+ },
+ null);
+
+ Assert.True(callbackInvoked);
+ Assert.Equal(callingThreadId, callbackThreadId);
+ }
+
+ ///
+ /// Verifies that throws
+ /// when a null underlying context is passed.
+ ///
+ [Fact]
+ public void Constructor_WithNullUnderlyingContext_Throws()
+ {
+ Assert.Throws(() => new NoMessagePumpSyncContext(null!));
+ }
+
+ ///
+ /// Verifies that rejects a null callback before
+ /// delegating to the underlying sync context.
+ ///
+ [Fact]
+ public void Post_WithNullCallback_Throws()
+ {
+ ThrowingSyncContext underlying = new();
+ NoMessagePumpSyncContext sc = new(underlying);
+
+ Assert.Throws(() => sc.Post(null!, null));
+ Assert.False(underlying.PostInvoked);
+ }
+
+ ///
+ /// Verifies that rejects a null callback before
+ /// delegating to the underlying sync context.
+ ///
+ [Fact]
+ public void Send_WithNullCallback_Throws()
+ {
+ ThrowingSyncContext underlying = new();
+ NoMessagePumpSyncContext sc = new(underlying);
+
+ Assert.Throws(() => sc.Send(null!, null));
+ Assert.False(underlying.SendInvoked);
+ }
+
+ ///
+ /// Verifies that delegates to the underlying
+ /// sync context when one is provided.
+ ///
+ [Fact]
+ public async Task Post_WithUnderlyingContext_DelegatesToUnderlying()
+ {
+ TaskCompletionSource tcs = new();
+ RecordingPostSyncContext underlying = new(posted: _ => tcs.SetResult(true));
+ NoMessagePumpSyncContext sc = new(underlying);
+ sc.Post(_ => { }, null);
+ Assert.True(await tcs.Task.WithCancellation(this.TimeoutToken));
+ }
+
+ ///
+ /// Verifies that delegates to the underlying
+ /// sync context when one is provided.
+ ///
+ [Fact]
+ public void Send_WithUnderlyingContext_DelegatesToUnderlying()
+ {
+ bool sendInvoked = false;
+ RecordingSendSyncContext underlying = new(sent: _ => sendInvoked = true);
+ NoMessagePumpSyncContext sc = new(underlying);
+ sc.Send(_ => { }, null);
+ Assert.True(sendInvoked);
+ }
+
#if NETFRAMEWORK
///
/// Establishes the baseline: on a plain STA thread without a special synchronization context,
@@ -77,4 +181,50 @@ public void Wait_BlocksComRpcCalls()
}
}
#endif
+
+ ///
+ /// A that invokes a callback when is called.
+ ///
+ private class RecordingPostSyncContext(Action posted) : SynchronizationContext
+ {
+ public override void Post(SendOrPostCallback d, object? state)
+ {
+ posted(d);
+ base.Post(d, state);
+ }
+ }
+
+ ///
+ /// A that invokes a callback when is called.
+ ///
+ private class RecordingSendSyncContext(Action sent) : SynchronizationContext
+ {
+ public override void Send(SendOrPostCallback d, object? state)
+ {
+ sent(d);
+ base.Send(d, state);
+ }
+ }
+
+ ///
+ /// A that records whether or were invoked.
+ ///
+ private class ThrowingSyncContext : SynchronizationContext
+ {
+ public bool PostInvoked { get; private set; }
+
+ public bool SendInvoked { get; private set; }
+
+ public override void Post(SendOrPostCallback d, object? state)
+ {
+ this.PostInvoked = true;
+ throw new InvalidOperationException();
+ }
+
+ public override void Send(SendOrPostCallback d, object? state)
+ {
+ this.SendInvoked = true;
+ throw new InvalidOperationException();
+ }
+ }
}