Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions src/Microsoft.VisualStudio.Threading/NoMessagePumpSyncContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,33 @@ public class NoMessagePumpSyncContext : SynchronizationContext
/// </summary>
private static readonly SynchronizationContext DefaultInstance = new NoMessagePumpSyncContext();

private readonly SynchronizationContext? underlyingSyncContext;

/// <summary>
/// Initializes a new instance of the <see cref="NoMessagePumpSyncContext"/> class.
/// </summary>
/// <remarks>
/// When using this constructor, <see cref="Post"/> uses the default <see cref="SynchronizationContext"/>
/// behavior and schedules work on the thread pool, while <see cref="Send"/> uses the default
/// <see cref="SynchronizationContext"/> behavior and invokes the callback synchronously on the calling thread.
/// </remarks>
public NoMessagePumpSyncContext()
{
// This is required so that our override of Wait is invoked.
this.SetWaitNotificationRequired();
}

/// <summary>
/// Initializes a new instance of the <see cref="NoMessagePumpSyncContext"/> class.
/// </summary>
/// <param name="underlyingSyncContext">The <see cref="SynchronizationContext"/> that should handle calls to <see cref="Post"/> and <see cref="Send"/>.</param>
public NoMessagePumpSyncContext(SynchronizationContext underlyingSyncContext)
: this()
{
Requires.NotNull(underlyingSyncContext, nameof(underlyingSyncContext));
this.underlyingSyncContext = underlyingSyncContext;
}

/// <summary>
/// Gets a shared instance of this class.
/// </summary>
Expand All @@ -37,6 +55,36 @@ public static SynchronizationContext Default
get { return DefaultInstance; }
}

/// <inheritdoc/>
public override void Send(SendOrPostCallback d, object? state)
{
Requires.NotNull(d, nameof(d));

if (this.underlyingSyncContext is { } underlying)
{
underlying.Send(d, state);
Comment thread
AArnott marked this conversation as resolved.
}
Comment thread
AArnott marked this conversation as resolved.
else
{
base.Send(d, state);
}
}

/// <inheritdoc/>
public override void Post(SendOrPostCallback d, object? state)
{
Requires.NotNull(d, nameof(d));

if (this.underlyingSyncContext is { } underlying)
{
underlying.Post(d, state);
}
Comment thread
AArnott marked this conversation as resolved.
else
{
base.Post(d, state);
}
}

/// <summary>
/// Synchronously blocks without a message pump.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Threading;
using System.Threading.Tasks;

/// <summary>
/// Tests for <see cref="NoMessagePumpSyncContext"/>.
Expand Down Expand Up @@ -36,6 +37,109 @@ public void Default_IsNoMessagePumpSyncContext()
Assert.IsType<NoMessagePumpSyncContext>(NoMessagePumpSyncContext.Default);
}

/// <summary>
/// Verifies that <see cref="NoMessagePumpSyncContext.Post"/> schedules work on the thread pool
/// when no underlying sync context is provided.
/// </summary>
[Fact]
public async Task Post_DefaultConstructor_ExecutesOnThreadPool()
{
NoMessagePumpSyncContext sc = new();
TaskCompletionSource<bool> tcs = new();
sc.Post(_ => tcs.SetResult(Thread.CurrentThread.IsThreadPoolThread), null);
Assert.True(await tcs.Task.WithCancellation(this.TimeoutToken));
}

/// <summary>
/// Verifies that <see cref="NoMessagePumpSyncContext.Send"/> executes work synchronously
/// on the calling thread when no underlying sync context is provided.
/// </summary>
[Fact]
public void Send_DefaultConstructor_ExecutesInlineOnCallingThread()
{
NoMessagePumpSyncContext sc = new();
int callingThreadId = Thread.CurrentThread.ManagedThreadId;
int? callbackThreadId = null;
bool callbackInvoked = false;

sc.Send(
_ =>
{
callbackInvoked = true;
callbackThreadId = Thread.CurrentThread.ManagedThreadId;
},
null);

Assert.True(callbackInvoked);
Assert.Equal(callingThreadId, callbackThreadId);
}

/// <summary>
/// Verifies that <see cref="NoMessagePumpSyncContext(SynchronizationContext)"/> throws
/// <see cref="ArgumentNullException"/> when a null underlying context is passed.
/// </summary>
[Fact]
public void Constructor_WithNullUnderlyingContext_Throws()
{
Assert.Throws<ArgumentNullException>(() => new NoMessagePumpSyncContext(null!));
}

/// <summary>
/// Verifies that <see cref="NoMessagePumpSyncContext.Post"/> rejects a null callback before
/// delegating to the underlying sync context.
/// </summary>
[Fact]
public void Post_WithNullCallback_Throws()
{
ThrowingSyncContext underlying = new();
NoMessagePumpSyncContext sc = new(underlying);

Assert.Throws<ArgumentNullException>(() => sc.Post(null!, null));
Assert.False(underlying.PostInvoked);
}

/// <summary>
/// Verifies that <see cref="NoMessagePumpSyncContext.Send"/> rejects a null callback before
/// delegating to the underlying sync context.
/// </summary>
[Fact]
public void Send_WithNullCallback_Throws()
{
ThrowingSyncContext underlying = new();
NoMessagePumpSyncContext sc = new(underlying);

Assert.Throws<ArgumentNullException>(() => sc.Send(null!, null));
Assert.False(underlying.SendInvoked);
}

/// <summary>
/// Verifies that <see cref="NoMessagePumpSyncContext.Post"/> delegates to the underlying
/// sync context when one is provided.
/// </summary>
[Fact]
public async Task Post_WithUnderlyingContext_DelegatesToUnderlying()
{
TaskCompletionSource<bool> tcs = new();
RecordingPostSyncContext underlying = new(posted: _ => tcs.SetResult(true));
NoMessagePumpSyncContext sc = new(underlying);
sc.Post(_ => { }, null);
Assert.True(await tcs.Task.WithCancellation(this.TimeoutToken));
}

/// <summary>
/// Verifies that <see cref="NoMessagePumpSyncContext.Send"/> delegates to the underlying
/// sync context when one is provided.
/// </summary>
[Fact]
public void Send_WithUnderlyingContext_DelegatesToUnderlying()
{
bool sendInvoked = false;
RecordingSendSyncContext underlying = new(sent: _ => sendInvoked = true);
NoMessagePumpSyncContext sc = new(underlying);
sc.Send(_ => { }, null);
Assert.True(sendInvoked);
}

#if NETFRAMEWORK
/// <summary>
/// Establishes the baseline: on a plain STA thread without a special synchronization context,
Expand Down Expand Up @@ -77,4 +181,50 @@ public void Wait_BlocksComRpcCalls()
}
}
#endif

/// <summary>
/// A <see cref="SynchronizationContext"/> that invokes a callback when <see cref="Post"/> is called.
/// </summary>
private class RecordingPostSyncContext(Action<SendOrPostCallback> posted) : SynchronizationContext
{
public override void Post(SendOrPostCallback d, object? state)
{
posted(d);
base.Post(d, state);
}
}

/// <summary>
/// A <see cref="SynchronizationContext"/> that invokes a callback when <see cref="Send"/> is called.
/// </summary>
private class RecordingSendSyncContext(Action<SendOrPostCallback> sent) : SynchronizationContext
{
public override void Send(SendOrPostCallback d, object? state)
{
sent(d);
base.Send(d, state);
}
}

/// <summary>
/// A <see cref="SynchronizationContext"/> that records whether <see cref="Post"/> or <see cref="Send"/> were invoked.
/// </summary>
private class ThrowingSyncContext : SynchronizationContext
{
public bool PostInvoked { get; private set; }

public bool SendInvoked { get; private set; }

public override void Post(SendOrPostCallback d, object? state)
{
this.PostInvoked = true;
throw new InvalidOperationException();
}

public override void Send(SendOrPostCallback d, object? state)
{
this.SendInvoked = true;
throw new InvalidOperationException();
}
}
}
Loading