From a37667c4539019bea2851136218f6b429ae7e257 Mon Sep 17 00:00:00 2001 From: Geoffrey Kizer Date: Wed, 12 Aug 2020 20:34:28 -0700 Subject: [PATCH 1/8] implement ConnectAsync overloads with CancellationToken on Socket and TcpClient --- .../ref/System.Net.Sockets.cs | 7 ++ .../src/System/Net/Sockets/Socket.Tasks.cs | 50 +++++++- .../Net/Sockets/SocketTaskExtensions.cs | 8 ++ .../src/System/Net/Sockets/TCPClient.cs | 24 ++-- .../tests/FunctionalTests/Connect.cs | 117 ++++++++++++++++++ .../tests/FunctionalTests/SocketTestHelper.cs | 28 +++++ .../tests/FunctionalTests/TcpClientTest.cs | 16 +++ 7 files changed, 239 insertions(+), 11 deletions(-) diff --git a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs index 11d3309cf116df..e965b80cef0ea6 100644 --- a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs +++ b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs @@ -555,9 +555,13 @@ public static partial class SocketTaskExtensions public static System.Threading.Tasks.Task AcceptAsync(this System.Net.Sockets.Socket socket) { throw null; } public static System.Threading.Tasks.Task AcceptAsync(this System.Net.Sockets.Socket socket, System.Net.Sockets.Socket? acceptSocket) { throw null; } public static System.Threading.Tasks.Task ConnectAsync(this System.Net.Sockets.Socket socket, System.Net.EndPoint remoteEP) { throw null; } + public static System.Threading.Tasks.ValueTask ConnectAsync(this System.Net.Sockets.Socket socket, System.Net.EndPoint remoteEP, System.Threading.CancellationToken cancellationToken) { throw null; } public static System.Threading.Tasks.Task ConnectAsync(this System.Net.Sockets.Socket socket, System.Net.IPAddress address, int port) { throw null; } + public static System.Threading.Tasks.ValueTask ConnectAsync(this System.Net.Sockets.Socket socket, System.Net.IPAddress address, int port, System.Threading.CancellationToken cancellationToken) { throw null; } public static System.Threading.Tasks.Task ConnectAsync(this System.Net.Sockets.Socket socket, System.Net.IPAddress[] addresses, int port) { throw null; } + public static System.Threading.Tasks.ValueTask ConnectAsync(this System.Net.Sockets.Socket socket, System.Net.IPAddress[] addresses, int port, System.Threading.CancellationToken cancellationToken) { throw null; } public static System.Threading.Tasks.Task ConnectAsync(this System.Net.Sockets.Socket socket, string host, int port) { throw null; } + public static System.Threading.Tasks.ValueTask ConnectAsync(this System.Net.Sockets.Socket socket, string host, int port, System.Threading.CancellationToken cancellationToken) { throw null; } public static System.Threading.Tasks.Task ReceiveAsync(this System.Net.Sockets.Socket socket, System.ArraySegment buffer, System.Net.Sockets.SocketFlags socketFlags) { throw null; } public static System.Threading.Tasks.Task ReceiveAsync(this System.Net.Sockets.Socket socket, System.Collections.Generic.IList> buffers, System.Net.Sockets.SocketFlags socketFlags) { throw null; } public static System.Threading.Tasks.ValueTask ReceiveAsync(this System.Net.Sockets.Socket socket, System.Memory buffer, System.Net.Sockets.SocketFlags socketFlags, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } @@ -606,6 +610,9 @@ public void Connect(string hostname, int port) { } public System.Threading.Tasks.Task ConnectAsync(System.Net.IPAddress address, int port) { throw null; } public System.Threading.Tasks.Task ConnectAsync(System.Net.IPAddress[] addresses, int port) { throw null; } public System.Threading.Tasks.Task ConnectAsync(string host, int port) { throw null; } + public System.Threading.Tasks.ValueTask ConnectAsync(System.Net.IPAddress address, int port, System.Threading.CancellationToken cancellationToken) { throw null; } + public System.Threading.Tasks.ValueTask ConnectAsync(System.Net.IPAddress[] addresses, int port, System.Threading.CancellationToken cancellationToken) { throw null; } + public System.Threading.Tasks.ValueTask ConnectAsync(string host, int port, System.Threading.CancellationToken cancellationToken) { throw null; } public void Dispose() { } protected virtual void Dispose(bool disposing) { } public void EndConnect(System.IAsyncResult asyncResult) { } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs index 75a226949d933e..6abc91637b3dc5 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs @@ -85,9 +85,38 @@ internal Task ConnectAsync(EndPoint remoteEP) return saea.ConnectAsync(this).AsTask(); } + internal async ValueTask ConnectAsync(EndPoint remoteEP, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + // Use _singleBufferReceiveEventArgs so the AwaitableSocketAsyncEventArgs can be re-used later for receives. + AwaitableSocketAsyncEventArgs saea = + Interlocked.Exchange(ref _singleBufferReceiveEventArgs, null) ?? + new AwaitableSocketAsyncEventArgs(this, isReceiveForCaching: true); + + saea.RemoteEndPoint = remoteEP; + + try + { + using (cancellationToken.UnsafeRegister(o => CancelConnectAsync((SocketAsyncEventArgs)o!), saea)) + { + await saea.ConnectAsync(this).ConfigureAwait(false); + } + } + catch (SocketException se) when (se.SocketErrorCode == SocketError.OperationAborted) + { + cancellationToken.ThrowIfCancellationRequested(); + throw; + } + } + internal Task ConnectAsync(IPAddress address, int port) => ConnectAsync(new IPEndPoint(address, port)); - internal Task ConnectAsync(IPAddress[] addresses, int port) + internal ValueTask ConnectAsync(IPAddress address, int port, CancellationToken cancellationToken) => ConnectAsync(new IPEndPoint(address, port), cancellationToken); + + internal Task ConnectAsync(IPAddress[] addresses, int port) => ConnectAsync(addresses, port, CancellationToken.None).AsTask(); + + internal ValueTask ConnectAsync(IPAddress[] addresses, int port, CancellationToken cancellationToken) { if (addresses == null) { @@ -98,17 +127,17 @@ internal Task ConnectAsync(IPAddress[] addresses, int port) throw new ArgumentException(SR.net_invalidAddressList, nameof(addresses)); } - return DoConnectAsync(addresses, port); + return DoConnectAsync(addresses, port, cancellationToken); } - private async Task DoConnectAsync(IPAddress[] addresses, int port) + private async ValueTask DoConnectAsync(IPAddress[] addresses, int port, CancellationToken cancellationToken) { Exception? lastException = null; foreach (IPAddress address in addresses) { try { - await ConnectAsync(address, port).ConfigureAwait(false); + await ConnectAsync(address, port, cancellationToken).ConfigureAwait(false); return; } catch (Exception ex) @@ -134,6 +163,19 @@ internal Task ConnectAsync(string host, int port) return ConnectAsync(ep); } + internal ValueTask ConnectAsync(string host, int port, CancellationToken cancellationToken) + { + if (host == null) + { + throw new ArgumentNullException(nameof(host)); + } + + EndPoint ep = IPAddress.TryParse(host, out IPAddress? parsedAddress) ? (EndPoint) + new IPEndPoint(parsedAddress, port) : + new DnsEndPoint(host, port); + return ConnectAsync(ep, cancellationToken); + } + internal Task ReceiveAsync(ArraySegment buffer, SocketFlags socketFlags, bool fromNetworkStream) { ValidateBuffer(buffer); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketTaskExtensions.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketTaskExtensions.cs index abaf3fcc52b8cf..f6a2243e626df3 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketTaskExtensions.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketTaskExtensions.cs @@ -16,12 +16,20 @@ public static Task AcceptAsync(this Socket socket, Socket? acceptSocket) public static Task ConnectAsync(this Socket socket, EndPoint remoteEP) => socket.ConnectAsync(remoteEP); + public static ValueTask ConnectAsync(this Socket socket, EndPoint remoteEP, CancellationToken cancellationToken) => + socket.ConnectAsync(remoteEP, cancellationToken); public static Task ConnectAsync(this Socket socket, IPAddress address, int port) => socket.ConnectAsync(address, port); + public static ValueTask ConnectAsync(this Socket socket, IPAddress address, int port, CancellationToken cancellationToken) => + socket.ConnectAsync(address, port, cancellationToken); public static Task ConnectAsync(this Socket socket, IPAddress[] addresses, int port) => socket.ConnectAsync(addresses, port); + public static ValueTask ConnectAsync(this Socket socket, IPAddress[] addresses, int port, CancellationToken cancellationToken) => + socket.ConnectAsync(addresses, port, cancellationToken); public static Task ConnectAsync(this Socket socket, string host, int port) => socket.ConnectAsync(host, port); + public static ValueTask ConnectAsync(this Socket socket, string host, int port, CancellationToken cancellationToken) => + socket.ConnectAsync(host, port, cancellationToken); public static Task ReceiveAsync(this Socket socket, ArraySegment buffer, SocketFlags socketFlags) => socket.ReceiveAsync(buffer, socketFlags, fromNetworkStream: false); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/TCPClient.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/TCPClient.cs index e70b1625a6c5a9..937d0d008b41e0 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/TCPClient.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/TCPClient.cs @@ -259,13 +259,8 @@ public void Connect(IPAddress[] ipAddresses, int port) _active = true; } - public Task ConnectAsync(IPAddress address, int port) - { - - Task result = CompleteConnectAsync(Client.ConnectAsync(address, port)); - - return result; - } + public Task ConnectAsync(IPAddress address, int port) => + CompleteConnectAsync(Client.ConnectAsync(address, port)); public Task ConnectAsync(string host, int port) => CompleteConnectAsync(Client.ConnectAsync(host, port)); @@ -279,6 +274,21 @@ private async Task CompleteConnectAsync(Task task) _active = true; } + public ValueTask ConnectAsync(IPAddress address, int port, CancellationToken cancellationToken) => + CompleteConnectAsync(Client.ConnectAsync(address, port, cancellationToken)); + + public ValueTask ConnectAsync(string host, int port, CancellationToken cancellationToken) => + CompleteConnectAsync(Client.ConnectAsync(host, port, cancellationToken)); + + public ValueTask ConnectAsync(IPAddress[] addresses, int port, CancellationToken cancellationToken) => + CompleteConnectAsync(Client.ConnectAsync(addresses, port, cancellationToken)); + + private async ValueTask CompleteConnectAsync(ValueTask task) + { + await task.ConfigureAwait(false); + _active = true; + } + public IAsyncResult BeginConnect(IPAddress address, int port, AsyncCallback? requestCallback, object? state) => Client.BeginConnect(address, port, requestCallback, state); diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs index 48fe236978c34e..64c86f097e4734 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs @@ -204,4 +204,121 @@ public sealed class ConnectEap : Connect { public ConnectEap(ITestOutputHelper output) : base(output) {} } + + public sealed class ConnectCancellableTask : Connect + { + public ConnectCancellableTask(ITestOutputHelper output) : base(output) { } + + [Fact] + public async Task ConnectEndPoint_Precanceled_Throws() + { + EndPoint ep = new IPEndPoint(IPAddress.Parse("1.2.3.4"), 1); + + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + var cts = new CancellationTokenSource(); + cts.Cancel(); + + await Assert.ThrowsAnyAsync(async () => await client.ConnectAsync(ep, cts.Token)); + } + } + + [Fact] + public async Task ConnectAddressAndPort_Precanceled_Throws() + { + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + var cts = new CancellationTokenSource(); + cts.Cancel(); + + await Assert.ThrowsAnyAsync(async () => await client.ConnectAsync(IPAddress.Parse("1.2.3.4"), 1, cts.Token)); + } + } + + [Fact] + public async Task ConnectMultiAddressAndPort_Precanceled_Throws() + { + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + var cts = new CancellationTokenSource(); + cts.Cancel(); + + await Assert.ThrowsAnyAsync(async () => await client.ConnectAsync(new IPAddress[] { IPAddress.Parse("1.2.3.4"), IPAddress.Parse("1.2.3.5") }, 1, cts.Token)); + } + } + + [Fact] + public async Task ConnectHostNameAndPort_Precanceled_Throws() + { + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + var cts = new CancellationTokenSource(); + cts.Cancel(); + + await Assert.ThrowsAnyAsync(async () => await client.ConnectAsync("1.2.3.4", 1, cts.Token)); + } + } + + [Fact] + public async Task ConnectEndPoint_CancelDuringConnect_Throws() + { + EndPoint ep = new IPEndPoint(IPAddress.Parse("1.2.3.4"), 1); + + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + var cts = new CancellationTokenSource(); + + ValueTask t = client.ConnectAsync(ep, cts.Token); + + cts.Cancel(); + + await Assert.ThrowsAnyAsync(async () => await t); + } + } + + [Fact] + public async Task ConnectAddressAndPort_CancelDuringConnect_Throws() + { + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + var cts = new CancellationTokenSource(); + + ValueTask t = client.ConnectAsync(IPAddress.Parse("1.2.3.4"), 1, cts.Token); + + cts.Cancel(); + + await Assert.ThrowsAnyAsync(async () => await t); + } + } + + [Fact] + public async Task ConnectMultiAddressAndPort_CancelDuringConnect_Throws() + { + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + var cts = new CancellationTokenSource(); + + ValueTask t = client.ConnectAsync(new IPAddress[] { IPAddress.Parse("1.2.3.4"), IPAddress.Parse("1.2.3.5") }, 1, cts.Token); + + cts.Cancel(); + + await Assert.ThrowsAnyAsync(async () => await t); + } + } + + [Fact] + public async Task ConnectHostNameAndPort_CancelDuringConnect_Throws() + { + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + var cts = new CancellationTokenSource(); + + ValueTask t = client.ConnectAsync("1.2.3.4", 1, cts.Token); + + cts.Cancel(); + + await Assert.ThrowsAnyAsync(async () => await t); + } + } + } } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs index 307c7c505f60b0..9a885e87e41f70 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs @@ -4,6 +4,7 @@ using System.Buffers; using System.Collections.Generic; using System.Runtime.InteropServices; +using System.Threading; using System.Threading.Tasks; using Xunit; using Xunit.Abstractions; @@ -179,6 +180,33 @@ public override Task SendToAsync(Socket s, ArraySegment buffer, EndPo s.SendToAsync(buffer, SocketFlags.None, endPoint); } + // Same as above, but call the CancellationToken overloads where possible + public class SocketHelperCancellableTask : SocketHelperBase + { + public override Task AcceptAsync(Socket s) => + s.AcceptAsync(); + public override Task<(Socket socket, byte[] buffer)> AcceptAsync(Socket s, int receiveSize) + => throw new NotSupportedException(); + public override Task AcceptAsync(Socket s, Socket acceptSocket) => + s.AcceptAsync(acceptSocket); + public override Task ConnectAsync(Socket s, EndPoint endPoint) => + s.ConnectAsync(endPoint, CancellationToken.None).AsTask(); + public override Task MultiConnectAsync(Socket s, IPAddress[] addresses, int port) => + s.ConnectAsync(addresses, port, CancellationToken.None).AsTask(); + public override Task ReceiveAsync(Socket s, ArraySegment buffer) => + s.ReceiveAsync(buffer, SocketFlags.None, CancellationToken.None).AsTask(); + public override Task ReceiveAsync(Socket s, IList> bufferList) => + s.ReceiveAsync(bufferList, SocketFlags.None); + public override Task ReceiveFromAsync(Socket s, ArraySegment buffer, EndPoint endPoint) => + s.ReceiveFromAsync(buffer, SocketFlags.None, endPoint); + public override Task SendAsync(Socket s, ArraySegment buffer) => + s.SendAsync(buffer, SocketFlags.None, CancellationToken.None).AsTask(); + public override Task SendAsync(Socket s, IList> bufferList) => + s.SendAsync(bufferList, SocketFlags.None); + public override Task SendToAsync(Socket s, ArraySegment buffer, EndPoint endPoint) => + s.SendToAsync(buffer, SocketFlags.None, endPoint); + } + public sealed class SocketHelperEap : SocketHelperBase { public override bool ValidatesArrayArguments => false; diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/TcpClientTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/TcpClientTest.cs index c30fa492476726..a2ce1a28a66138 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/TcpClientTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/TcpClientTest.cs @@ -4,6 +4,7 @@ using Xunit; using Xunit.Abstractions; +using System.Threading; using System.Threading.Tasks; using System.Text; using System.Diagnostics; @@ -119,6 +120,9 @@ public void Ctor_StringInt_ConnectsSuccessfully() [InlineData(3)] [InlineData(4)] [InlineData(5)] + [InlineData(6)] + [InlineData(7)] + [InlineData(8)] public async Task ConnectAsync_DnsEndPoint_Success(int mode) { using (var client = new DerivedTcpClient()) @@ -155,6 +159,18 @@ public async Task ConnectAsync_DnsEndPoint_Success(int mode) addresses = await Dns.GetHostAddressesAsync(host); await Task.Factory.FromAsync(client.BeginConnect, client.EndConnect, addresses, port, null); break; + + case 6: + await client.ConnectAsync(host, port, CancellationToken.None); + break; + case 7: + addresses = await Dns.GetHostAddressesAsync(host); + await client.ConnectAsync(addresses[0], port, CancellationToken.None); + break; + case 8: + addresses = await Dns.GetHostAddressesAsync(host); + await client.ConnectAsync(addresses, port, CancellationToken.None); + break; } Assert.True(client.Active); From 33e18da205acfa547877cafe8ae1231562220791 Mon Sep 17 00:00:00 2001 From: Geoffrey Kizer Date: Thu, 13 Aug 2020 13:00:23 -0700 Subject: [PATCH 2/8] review feedback --- .../src/System/Net/Sockets/Socket.Tasks.cs | 2 +- .../tests/FunctionalTests/Connect.cs | 20 +++++++++++++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs index 6abc91637b3dc5..b62f1b657a9ce7 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs @@ -140,7 +140,7 @@ private async ValueTask DoConnectAsync(IPAddress[] addresses, int port, Cancella await ConnectAsync(address, port, cancellationToken).ConfigureAwait(false); return; } - catch (Exception ex) + catch (Exception ex) when (ex is not OperationCanceledException) { lastException = ex; } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs index 64c86f097e4734..08ecf15f3b925b 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs @@ -260,6 +260,8 @@ public async Task ConnectHostNameAndPort_Precanceled_Throws() } [Fact] + [OuterLoop("Uses Task.Delay")] + [PlatformSpecific(TestPlatforms.Windows)] // Linux will not even attempt to connect to the invalid IP address public async Task ConnectEndPoint_CancelDuringConnect_Throws() { EndPoint ep = new IPEndPoint(IPAddress.Parse("1.2.3.4"), 1); @@ -270,13 +272,16 @@ public async Task ConnectEndPoint_CancelDuringConnect_Throws() ValueTask t = client.ConnectAsync(ep, cts.Token); - cts.Cancel(); + // Delay cancellation a bit to try to ensure the OS actually attempts to connect + cts.CancelAfter(100); await Assert.ThrowsAnyAsync(async () => await t); } } [Fact] + [OuterLoop("Uses Task.Delay")] + [PlatformSpecific(TestPlatforms.Windows)] // Linux will not even attempt to connect to the invalid IP address public async Task ConnectAddressAndPort_CancelDuringConnect_Throws() { using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) @@ -285,13 +290,16 @@ public async Task ConnectAddressAndPort_CancelDuringConnect_Throws() ValueTask t = client.ConnectAsync(IPAddress.Parse("1.2.3.4"), 1, cts.Token); - cts.Cancel(); + // Delay cancellation a bit to try to ensure the OS actually attempts to connect + cts.CancelAfter(100); await Assert.ThrowsAnyAsync(async () => await t); } } [Fact] + [OuterLoop("Uses Task.Delay")] + [PlatformSpecific(TestPlatforms.Windows)] // Linux will not even attempt to connect to the invalid IP address public async Task ConnectMultiAddressAndPort_CancelDuringConnect_Throws() { using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) @@ -300,13 +308,16 @@ public async Task ConnectMultiAddressAndPort_CancelDuringConnect_Throws() ValueTask t = client.ConnectAsync(new IPAddress[] { IPAddress.Parse("1.2.3.4"), IPAddress.Parse("1.2.3.5") }, 1, cts.Token); - cts.Cancel(); + // Delay cancellation a bit to try to ensure the OS actually attempts to connect + cts.CancelAfter(100); await Assert.ThrowsAnyAsync(async () => await t); } } [Fact] + [OuterLoop("Uses Task.Delay")] + [PlatformSpecific(TestPlatforms.Windows)] // Linux will not even attempt to connect to the invalid IP address public async Task ConnectHostNameAndPort_CancelDuringConnect_Throws() { using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) @@ -315,7 +326,8 @@ public async Task ConnectHostNameAndPort_CancelDuringConnect_Throws() ValueTask t = client.ConnectAsync("1.2.3.4", 1, cts.Token); - cts.Cancel(); + // Delay cancellation a bit to try to ensure the OS actually attempts to connect + cts.CancelAfter(100); await Assert.ThrowsAnyAsync(async () => await t); } From 51dad023200770ada4b2c7e67fc97f96e2188eff Mon Sep 17 00:00:00 2001 From: Geoffrey Kizer Date: Thu, 13 Aug 2020 13:24:31 -0700 Subject: [PATCH 3/8] rework to avoid code duplication and handle sync completion better --- .../src/System/Net/Sockets/Socket.Tasks.cs | 63 +++++++++---------- 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs index b62f1b657a9ce7..83b8e3e4bb737e 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs @@ -74,20 +74,14 @@ internal Task AcceptAsync(Socket? acceptSocket) return t; } - internal Task ConnectAsync(EndPoint remoteEP) - { - // Use _singleBufferReceiveEventArgs so the AwaitableSocketAsyncEventArgs can be re-used later for receives. - AwaitableSocketAsyncEventArgs saea = - Interlocked.Exchange(ref _singleBufferReceiveEventArgs, null) ?? - new AwaitableSocketAsyncEventArgs(this, isReceiveForCaching: true); + internal Task ConnectAsync(EndPoint remoteEP) => ConnectAsync(remoteEP, default).AsTask(); - saea.RemoteEndPoint = remoteEP; - return saea.ConnectAsync(this).AsTask(); - } - - internal async ValueTask ConnectAsync(EndPoint remoteEP, CancellationToken cancellationToken) + internal ValueTask ConnectAsync(EndPoint remoteEP, CancellationToken cancellationToken) { - cancellationToken.ThrowIfCancellationRequested(); + if (cancellationToken.IsCancellationRequested) + { + return ValueTask.FromCanceled(cancellationToken); + } // Use _singleBufferReceiveEventArgs so the AwaitableSocketAsyncEventArgs can be re-used later for receives. AwaitableSocketAsyncEventArgs saea = @@ -96,18 +90,34 @@ internal async ValueTask ConnectAsync(EndPoint remoteEP, CancellationToken cance saea.RemoteEndPoint = remoteEP; - try + ValueTask connectTask = saea.ConnectAsync(this); + if (connectTask.IsCompleted || !cancellationToken.CanBeCanceled) { - using (cancellationToken.UnsafeRegister(o => CancelConnectAsync((SocketAsyncEventArgs)o!), saea)) - { - await saea.ConnectAsync(this).ConfigureAwait(false); - } + // Avoid async invocation overhead + return connectTask; } - catch (SocketException se) when (se.SocketErrorCode == SocketError.OperationAborted) + else { - cancellationToken.ThrowIfCancellationRequested(); - throw; + return WaitForConnectWithCancellation(saea, connectTask, cancellationToken); + } + + async ValueTask WaitForConnectWithCancellation(AwaitableSocketAsyncEventArgs saea, ValueTask connectTask, CancellationToken cancellationToken) + { + Debug.Assert(cancellationToken.CanBeCanceled); + try + { + using (cancellationToken.UnsafeRegister(o => CancelConnectAsync((SocketAsyncEventArgs)o!), saea)) + { + await connectTask.ConfigureAwait(false); + } + } + catch (SocketException se) when (se.SocketErrorCode == SocketError.OperationAborted) + { + cancellationToken.ThrowIfCancellationRequested(); + throw; + } } + } internal Task ConnectAsync(IPAddress address, int port) => ConnectAsync(new IPEndPoint(address, port)); @@ -150,18 +160,7 @@ private async ValueTask DoConnectAsync(IPAddress[] addresses, int port, Cancella ExceptionDispatchInfo.Throw(lastException); } - internal Task ConnectAsync(string host, int port) - { - if (host == null) - { - throw new ArgumentNullException(nameof(host)); - } - - EndPoint ep = IPAddress.TryParse(host, out IPAddress? parsedAddress) ? (EndPoint) - new IPEndPoint(parsedAddress, port) : - new DnsEndPoint(host, port); - return ConnectAsync(ep); - } + internal Task ConnectAsync(string host, int port) => ConnectAsync(host, port, default).AsTask(); internal ValueTask ConnectAsync(string host, int port, CancellationToken cancellationToken) { From ed69135ee9f9b70ebd35fa59886fea6659f47d15 Mon Sep 17 00:00:00 2001 From: Geoffrey Kizer Date: Thu, 13 Aug 2020 15:17:17 -0700 Subject: [PATCH 4/8] fix warning introduced by new APIs --- .../Http/aspnetcore/Quic/Implementations/Mock/MockConnection.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/Common/src/System/Net/Http/aspnetcore/Quic/Implementations/Mock/MockConnection.cs b/src/libraries/Common/src/System/Net/Http/aspnetcore/Quic/Implementations/Mock/MockConnection.cs index cba2f936ef8dd2..b97f4876bfa7f2 100644 --- a/src/libraries/Common/src/System/Net/Http/aspnetcore/Quic/Implementations/Mock/MockConnection.cs +++ b/src/libraries/Common/src/System/Net/Http/aspnetcore/Quic/Implementations/Mock/MockConnection.cs @@ -74,7 +74,7 @@ internal override async ValueTask ConnectAsync(CancellationToken cancellationTok } Socket socket = new Socket(_remoteEndPoint!.AddressFamily, SocketType.Stream, ProtocolType.Tcp); - await socket.ConnectAsync(_remoteEndPoint).ConfigureAwait(false); + await socket.ConnectAsync(_remoteEndPoint, cancellationToken).ConfigureAwait(false); socket.NoDelay = true; _localEndPoint = (IPEndPoint?)socket.LocalEndPoint; From f93dcb1a8af83a6cbe9022e8d361b118bd1a351a Mon Sep 17 00:00:00 2001 From: Geoffrey Kizer Date: Fri, 14 Aug 2020 19:33:14 -0700 Subject: [PATCH 5/8] improve tests per PR feedback --- .../tests/FunctionalTests/Connect.cs | 147 ++++++++++-------- .../tests/FunctionalTests/SocketTestHelper.cs | 11 +- 2 files changed, 93 insertions(+), 65 deletions(-) diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs index 08ecf15f3b925b..bc41320949ee79 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs @@ -212,107 +212,128 @@ public ConnectCancellableTask(ITestOutputHelper output) : base(output) { } [Fact] public async Task ConnectEndPoint_Precanceled_Throws() { - EndPoint ep = new IPEndPoint(IPAddress.Parse("1.2.3.4"), 1); + using Socket listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + listen.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + IPEndPoint remoteEndPoint = (IPEndPoint)listen.LocalEndPoint; - using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) - { - var cts = new CancellationTokenSource(); - cts.Cancel(); + using Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - await Assert.ThrowsAnyAsync(async () => await client.ConnectAsync(ep, cts.Token)); - } + using CancellationTokenSource cts = new CancellationTokenSource(); + cts.Cancel(); + + OperationCanceledException e = await Assert.ThrowsAnyAsync(async () => await client.ConnectAsync(remoteEndPoint, cts.Token)); + Assert.Equal(cts.Token, e.CancellationToken); } [Fact] public async Task ConnectAddressAndPort_Precanceled_Throws() { - using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) - { - var cts = new CancellationTokenSource(); - cts.Cancel(); + using Socket listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + listen.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + IPEndPoint remoteEndPoint = (IPEndPoint)listen.LocalEndPoint; - await Assert.ThrowsAnyAsync(async () => await client.ConnectAsync(IPAddress.Parse("1.2.3.4"), 1, cts.Token)); - } + using Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + + using CancellationTokenSource cts = new CancellationTokenSource(); + cts.Cancel(); + + OperationCanceledException e = await Assert.ThrowsAnyAsync(async () => await client.ConnectAsync(remoteEndPoint.Address, remoteEndPoint.Port, cts.Token)); + Assert.Equal(cts.Token, e.CancellationToken); } [Fact] public async Task ConnectMultiAddressAndPort_Precanceled_Throws() { - using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) - { - var cts = new CancellationTokenSource(); - cts.Cancel(); + using Socket listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + listen.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + IPEndPoint remoteEndPoint = (IPEndPoint)listen.LocalEndPoint; - await Assert.ThrowsAnyAsync(async () => await client.ConnectAsync(new IPAddress[] { IPAddress.Parse("1.2.3.4"), IPAddress.Parse("1.2.3.5") }, 1, cts.Token)); - } + using Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + + using CancellationTokenSource cts = new CancellationTokenSource(); + cts.Cancel(); + + OperationCanceledException e = await Assert.ThrowsAnyAsync(async () => await client.ConnectAsync(new IPAddress[] { remoteEndPoint.Address, remoteEndPoint.Address }, remoteEndPoint.Port, cts.Token)); + Assert.Equal(cts.Token, e.CancellationToken); } [Fact] public async Task ConnectHostNameAndPort_Precanceled_Throws() { - using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) - { - var cts = new CancellationTokenSource(); - cts.Cancel(); + using Socket listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + listen.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + IPEndPoint remoteEndPoint = (IPEndPoint)listen.LocalEndPoint; - await Assert.ThrowsAnyAsync(async () => await client.ConnectAsync("1.2.3.4", 1, cts.Token)); - } + using Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + + using CancellationTokenSource cts = new CancellationTokenSource(); + cts.Cancel(); + + OperationCanceledException e = await Assert.ThrowsAnyAsync(async () => await client.ConnectAsync("127.0.0.1", remoteEndPoint.Port, cts.Token)); + Assert.Equal(cts.Token, e.CancellationToken); } [Fact] [OuterLoop("Uses Task.Delay")] - [PlatformSpecific(TestPlatforms.Windows)] // Linux will not even attempt to connect to the invalid IP address public async Task ConnectEndPoint_CancelDuringConnect_Throws() { - EndPoint ep = new IPEndPoint(IPAddress.Parse("1.2.3.4"), 1); + using Socket listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + listen.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + IPEndPoint remoteEndPoint = (IPEndPoint)listen.LocalEndPoint; - using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) - { - var cts = new CancellationTokenSource(); + using Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - ValueTask t = client.ConnectAsync(ep, cts.Token); + using CancellationTokenSource cts = new CancellationTokenSource(); - // Delay cancellation a bit to try to ensure the OS actually attempts to connect - cts.CancelAfter(100); + ValueTask t = client.ConnectAsync(remoteEndPoint, cts.Token); - await Assert.ThrowsAnyAsync(async () => await t); - } + // Delay cancellation a bit to try to ensure the OS actually attempts to connect + cts.CancelAfter(100); + + OperationCanceledException e = await Assert.ThrowsAnyAsync(async () => await t); + Assert.Equal(cts.Token, e.CancellationToken); } [Fact] [OuterLoop("Uses Task.Delay")] - [PlatformSpecific(TestPlatforms.Windows)] // Linux will not even attempt to connect to the invalid IP address public async Task ConnectAddressAndPort_CancelDuringConnect_Throws() { - using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) - { - var cts = new CancellationTokenSource(); + using Socket listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + listen.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + IPEndPoint remoteEndPoint = (IPEndPoint)listen.LocalEndPoint; - ValueTask t = client.ConnectAsync(IPAddress.Parse("1.2.3.4"), 1, cts.Token); + using Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - // Delay cancellation a bit to try to ensure the OS actually attempts to connect - cts.CancelAfter(100); + using CancellationTokenSource cts = new CancellationTokenSource(); - await Assert.ThrowsAnyAsync(async () => await t); - } + ValueTask t = client.ConnectAsync(remoteEndPoint.Address, remoteEndPoint.Port, cts.Token); + + // Delay cancellation a bit to try to ensure the OS actually attempts to connect + cts.CancelAfter(100); + + OperationCanceledException e = await Assert.ThrowsAnyAsync(async () => await t); + Assert.Equal(cts.Token, e.CancellationToken); } [Fact] [OuterLoop("Uses Task.Delay")] - [PlatformSpecific(TestPlatforms.Windows)] // Linux will not even attempt to connect to the invalid IP address public async Task ConnectMultiAddressAndPort_CancelDuringConnect_Throws() { - using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) - { - var cts = new CancellationTokenSource(); + using Socket listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + listen.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + IPEndPoint remoteEndPoint = (IPEndPoint)listen.LocalEndPoint; - ValueTask t = client.ConnectAsync(new IPAddress[] { IPAddress.Parse("1.2.3.4"), IPAddress.Parse("1.2.3.5") }, 1, cts.Token); + using Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - // Delay cancellation a bit to try to ensure the OS actually attempts to connect - cts.CancelAfter(100); + using CancellationTokenSource cts = new CancellationTokenSource(); - await Assert.ThrowsAnyAsync(async () => await t); - } + ValueTask t = client.ConnectAsync(new IPAddress[] { remoteEndPoint.Address, remoteEndPoint.Address}, remoteEndPoint.Port, cts.Token); + + // Delay cancellation a bit to try to ensure the OS actually attempts to connect + cts.CancelAfter(100); + + OperationCanceledException e = await Assert.ThrowsAnyAsync(async () => await t); + Assert.Equal(cts.Token, e.CancellationToken); } [Fact] @@ -320,17 +341,21 @@ public async Task ConnectMultiAddressAndPort_CancelDuringConnect_Throws() [PlatformSpecific(TestPlatforms.Windows)] // Linux will not even attempt to connect to the invalid IP address public async Task ConnectHostNameAndPort_CancelDuringConnect_Throws() { - using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) - { - var cts = new CancellationTokenSource(); + using Socket listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + listen.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + IPEndPoint remoteEndPoint = (IPEndPoint)listen.LocalEndPoint; - ValueTask t = client.ConnectAsync("1.2.3.4", 1, cts.Token); + using Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - // Delay cancellation a bit to try to ensure the OS actually attempts to connect - cts.CancelAfter(100); + using CancellationTokenSource cts = new CancellationTokenSource(); - await Assert.ThrowsAnyAsync(async () => await t); - } + ValueTask t = client.ConnectAsync("127.0.0.1", remoteEndPoint.Port, cts.Token); + + // Delay cancellation a bit to try to ensure the OS actually attempts to connect + cts.CancelAfter(100); + + OperationCanceledException e = await Assert.ThrowsAnyAsync(async () => await t); + Assert.Equal(cts.Token, e.CancellationToken); } } } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs index 9a885e87e41f70..41fa116b0c40c9 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs @@ -183,6 +183,9 @@ public override Task SendToAsync(Socket s, ArraySegment buffer, EndPo // Same as above, but call the CancellationToken overloads where possible public class SocketHelperCancellableTask : SocketHelperBase { + // Use a cancellable CancellationToken that we never cancel so that implementations can't just elide handling the CancellationToken. + private readonly CancellationTokenSource _cts = new CancellationTokenSource(); + public override Task AcceptAsync(Socket s) => s.AcceptAsync(); public override Task<(Socket socket, byte[] buffer)> AcceptAsync(Socket s, int receiveSize) @@ -190,17 +193,17 @@ public override Task AcceptAsync(Socket s) => public override Task AcceptAsync(Socket s, Socket acceptSocket) => s.AcceptAsync(acceptSocket); public override Task ConnectAsync(Socket s, EndPoint endPoint) => - s.ConnectAsync(endPoint, CancellationToken.None).AsTask(); + s.ConnectAsync(endPoint, _cts.Token).AsTask(); public override Task MultiConnectAsync(Socket s, IPAddress[] addresses, int port) => - s.ConnectAsync(addresses, port, CancellationToken.None).AsTask(); + s.ConnectAsync(addresses, port, _cts.Token).AsTask(); public override Task ReceiveAsync(Socket s, ArraySegment buffer) => - s.ReceiveAsync(buffer, SocketFlags.None, CancellationToken.None).AsTask(); + s.ReceiveAsync(buffer, SocketFlags.None, _cts.Token).AsTask(); public override Task ReceiveAsync(Socket s, IList> bufferList) => s.ReceiveAsync(bufferList, SocketFlags.None); public override Task ReceiveFromAsync(Socket s, ArraySegment buffer, EndPoint endPoint) => s.ReceiveFromAsync(buffer, SocketFlags.None, endPoint); public override Task SendAsync(Socket s, ArraySegment buffer) => - s.SendAsync(buffer, SocketFlags.None, CancellationToken.None).AsTask(); + s.SendAsync(buffer, SocketFlags.None, _cts.Token).AsTask(); public override Task SendAsync(Socket s, IList> bufferList) => s.SendAsync(bufferList, SocketFlags.None); public override Task SendToAsync(Socket s, ArraySegment buffer, EndPoint endPoint) => From 0aa6f163efc9c0b91fa6f831a25b6e741afdeb15 Mon Sep 17 00:00:00 2001 From: Geoffrey Kizer Date: Fri, 14 Aug 2020 19:42:53 -0700 Subject: [PATCH 6/8] modify SocketConnectionFactory to use the new socket ConnectAsync method --- .../src/System.Net.Connections.csproj | 1 - .../Sockets/SocketsConnectionFactory.cs | 22 +------------ .../Sockets/TaskSocketAsyncEventArgs.cs | 31 ------------------- 3 files changed, 1 insertion(+), 53 deletions(-) delete mode 100644 src/libraries/System.Net.Connections/src/System/Net/Connections/Sockets/TaskSocketAsyncEventArgs.cs diff --git a/src/libraries/System.Net.Connections/src/System.Net.Connections.csproj b/src/libraries/System.Net.Connections/src/System.Net.Connections.csproj index 7ccd9e6f600ab5..cf6682f2cd821e 100644 --- a/src/libraries/System.Net.Connections/src/System.Net.Connections.csproj +++ b/src/libraries/System.Net.Connections/src/System.Net.Connections.csproj @@ -18,7 +18,6 @@ - diff --git a/src/libraries/System.Net.Connections/src/System/Net/Connections/Sockets/SocketsConnectionFactory.cs b/src/libraries/System.Net.Connections/src/System/Net/Connections/Sockets/SocketsConnectionFactory.cs index 2936a04e0484cc..732ad9447828a5 100644 --- a/src/libraries/System.Net.Connections/src/System/Net/Connections/Sockets/SocketsConnectionFactory.cs +++ b/src/libraries/System.Net.Connections/src/System/Net/Connections/Sockets/SocketsConnectionFactory.cs @@ -64,27 +64,7 @@ public override async ValueTask ConnectAsync( try { - using var args = new TaskSocketAsyncEventArgs(); - args.RemoteEndPoint = endPoint; - - if (socket.ConnectAsync(args)) - { - using (cancellationToken.UnsafeRegister(static o => Socket.CancelConnectAsync((SocketAsyncEventArgs)o!), args)) - { - await args.Task.ConfigureAwait(false); - } - } - - if (args.SocketError != SocketError.Success) - { - if (args.SocketError == SocketError.OperationAborted) - { - cancellationToken.ThrowIfCancellationRequested(); - } - - throw NetworkErrorHelper.MapSocketException(new SocketException((int)args.SocketError)); - } - + await socket.ConnectAsync(endPoint, cancellationToken).ConfigureAwait(false); return new SocketConnection(socket); } catch (SocketException socketException) diff --git a/src/libraries/System.Net.Connections/src/System/Net/Connections/Sockets/TaskSocketAsyncEventArgs.cs b/src/libraries/System.Net.Connections/src/System/Net/Connections/Sockets/TaskSocketAsyncEventArgs.cs deleted file mode 100644 index a1bb69d3f0501e..00000000000000 --- a/src/libraries/System.Net.Connections/src/System/Net/Connections/Sockets/TaskSocketAsyncEventArgs.cs +++ /dev/null @@ -1,31 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Net.Sockets; -using System.Threading.Tasks; -using System.Threading.Tasks.Sources; - -namespace System.Net.Connections -{ - internal sealed class TaskSocketAsyncEventArgs : SocketAsyncEventArgs, IValueTaskSource - { - private ManualResetValueTaskSourceCore _valueTaskSource; - - public void ResetTask() => _valueTaskSource.Reset(); - public ValueTask Task => new ValueTask(this, _valueTaskSource.Version); - - public void GetResult(short token) => _valueTaskSource.GetResult(token); - public ValueTaskSourceStatus GetStatus(short token) => _valueTaskSource.GetStatus(token); - public void OnCompleted(Action continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) => _valueTaskSource.OnCompleted(continuation, state, token, flags); - - public TaskSocketAsyncEventArgs() - : base(unsafeSuppressExecutionContextFlow: true) - { - } - - protected override void OnCompleted(SocketAsyncEventArgs e) - { - _valueTaskSource.SetResult(0); - } - } -} From 75dabcae3ba3e64626a8fb15066e0a8b263e81cf Mon Sep 17 00:00:00 2001 From: Geoffrey Kizer Date: Fri, 14 Aug 2020 22:31:06 -0700 Subject: [PATCH 7/8] more test improvements --- .../tests/FunctionalTests/Connect.cs | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs index bc41320949ee79..ab5ca97a8eb5fe 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs @@ -287,8 +287,7 @@ public async Task ConnectEndPoint_CancelDuringConnect_Throws() ValueTask t = client.ConnectAsync(remoteEndPoint, cts.Token); - // Delay cancellation a bit to try to ensure the OS actually attempts to connect - cts.CancelAfter(100); + cts.Cancel(); OperationCanceledException e = await Assert.ThrowsAnyAsync(async () => await t); Assert.Equal(cts.Token, e.CancellationToken); @@ -308,8 +307,7 @@ public async Task ConnectAddressAndPort_CancelDuringConnect_Throws() ValueTask t = client.ConnectAsync(remoteEndPoint.Address, remoteEndPoint.Port, cts.Token); - // Delay cancellation a bit to try to ensure the OS actually attempts to connect - cts.CancelAfter(100); + cts.Cancel(); OperationCanceledException e = await Assert.ThrowsAnyAsync(async () => await t); Assert.Equal(cts.Token, e.CancellationToken); @@ -329,8 +327,7 @@ public async Task ConnectMultiAddressAndPort_CancelDuringConnect_Throws() ValueTask t = client.ConnectAsync(new IPAddress[] { remoteEndPoint.Address, remoteEndPoint.Address}, remoteEndPoint.Port, cts.Token); - // Delay cancellation a bit to try to ensure the OS actually attempts to connect - cts.CancelAfter(100); + cts.Cancel(); OperationCanceledException e = await Assert.ThrowsAnyAsync(async () => await t); Assert.Equal(cts.Token, e.CancellationToken); @@ -338,7 +335,6 @@ public async Task ConnectMultiAddressAndPort_CancelDuringConnect_Throws() [Fact] [OuterLoop("Uses Task.Delay")] - [PlatformSpecific(TestPlatforms.Windows)] // Linux will not even attempt to connect to the invalid IP address public async Task ConnectHostNameAndPort_CancelDuringConnect_Throws() { using Socket listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); @@ -351,8 +347,7 @@ public async Task ConnectHostNameAndPort_CancelDuringConnect_Throws() ValueTask t = client.ConnectAsync("127.0.0.1", remoteEndPoint.Port, cts.Token); - // Delay cancellation a bit to try to ensure the OS actually attempts to connect - cts.CancelAfter(100); + cts.Cancel(); OperationCanceledException e = await Assert.ThrowsAnyAsync(async () => await t); Assert.Equal(cts.Token, e.CancellationToken); From 1ead0cb8930d9ed0bc92a39144951b73ad26a0df Mon Sep 17 00:00:00 2001 From: Geoffrey Kizer Date: Sun, 16 Aug 2020 01:35:13 -0700 Subject: [PATCH 8/8] revert last test changes --- .../tests/FunctionalTests/Connect.cs | 144 ++++++++---------- 1 file changed, 62 insertions(+), 82 deletions(-) diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs index ab5ca97a8eb5fe..08ecf15f3b925b 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs @@ -212,145 +212,125 @@ public ConnectCancellableTask(ITestOutputHelper output) : base(output) { } [Fact] public async Task ConnectEndPoint_Precanceled_Throws() { - using Socket listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - listen.Bind(new IPEndPoint(IPAddress.Loopback, 0)); - IPEndPoint remoteEndPoint = (IPEndPoint)listen.LocalEndPoint; + EndPoint ep = new IPEndPoint(IPAddress.Parse("1.2.3.4"), 1); - using Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - - using CancellationTokenSource cts = new CancellationTokenSource(); - cts.Cancel(); + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + var cts = new CancellationTokenSource(); + cts.Cancel(); - OperationCanceledException e = await Assert.ThrowsAnyAsync(async () => await client.ConnectAsync(remoteEndPoint, cts.Token)); - Assert.Equal(cts.Token, e.CancellationToken); + await Assert.ThrowsAnyAsync(async () => await client.ConnectAsync(ep, cts.Token)); + } } [Fact] public async Task ConnectAddressAndPort_Precanceled_Throws() { - using Socket listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - listen.Bind(new IPEndPoint(IPAddress.Loopback, 0)); - IPEndPoint remoteEndPoint = (IPEndPoint)listen.LocalEndPoint; - - using Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - - using CancellationTokenSource cts = new CancellationTokenSource(); - cts.Cancel(); + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + var cts = new CancellationTokenSource(); + cts.Cancel(); - OperationCanceledException e = await Assert.ThrowsAnyAsync(async () => await client.ConnectAsync(remoteEndPoint.Address, remoteEndPoint.Port, cts.Token)); - Assert.Equal(cts.Token, e.CancellationToken); + await Assert.ThrowsAnyAsync(async () => await client.ConnectAsync(IPAddress.Parse("1.2.3.4"), 1, cts.Token)); + } } [Fact] public async Task ConnectMultiAddressAndPort_Precanceled_Throws() { - using Socket listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - listen.Bind(new IPEndPoint(IPAddress.Loopback, 0)); - IPEndPoint remoteEndPoint = (IPEndPoint)listen.LocalEndPoint; - - using Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - - using CancellationTokenSource cts = new CancellationTokenSource(); - cts.Cancel(); + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + var cts = new CancellationTokenSource(); + cts.Cancel(); - OperationCanceledException e = await Assert.ThrowsAnyAsync(async () => await client.ConnectAsync(new IPAddress[] { remoteEndPoint.Address, remoteEndPoint.Address }, remoteEndPoint.Port, cts.Token)); - Assert.Equal(cts.Token, e.CancellationToken); + await Assert.ThrowsAnyAsync(async () => await client.ConnectAsync(new IPAddress[] { IPAddress.Parse("1.2.3.4"), IPAddress.Parse("1.2.3.5") }, 1, cts.Token)); + } } [Fact] public async Task ConnectHostNameAndPort_Precanceled_Throws() { - using Socket listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - listen.Bind(new IPEndPoint(IPAddress.Loopback, 0)); - IPEndPoint remoteEndPoint = (IPEndPoint)listen.LocalEndPoint; - - using Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - - using CancellationTokenSource cts = new CancellationTokenSource(); - cts.Cancel(); + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + var cts = new CancellationTokenSource(); + cts.Cancel(); - OperationCanceledException e = await Assert.ThrowsAnyAsync(async () => await client.ConnectAsync("127.0.0.1", remoteEndPoint.Port, cts.Token)); - Assert.Equal(cts.Token, e.CancellationToken); + await Assert.ThrowsAnyAsync(async () => await client.ConnectAsync("1.2.3.4", 1, cts.Token)); + } } [Fact] [OuterLoop("Uses Task.Delay")] + [PlatformSpecific(TestPlatforms.Windows)] // Linux will not even attempt to connect to the invalid IP address public async Task ConnectEndPoint_CancelDuringConnect_Throws() { - using Socket listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - listen.Bind(new IPEndPoint(IPAddress.Loopback, 0)); - IPEndPoint remoteEndPoint = (IPEndPoint)listen.LocalEndPoint; - - using Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + EndPoint ep = new IPEndPoint(IPAddress.Parse("1.2.3.4"), 1); - using CancellationTokenSource cts = new CancellationTokenSource(); + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + var cts = new CancellationTokenSource(); - ValueTask t = client.ConnectAsync(remoteEndPoint, cts.Token); + ValueTask t = client.ConnectAsync(ep, cts.Token); - cts.Cancel(); + // Delay cancellation a bit to try to ensure the OS actually attempts to connect + cts.CancelAfter(100); - OperationCanceledException e = await Assert.ThrowsAnyAsync(async () => await t); - Assert.Equal(cts.Token, e.CancellationToken); + await Assert.ThrowsAnyAsync(async () => await t); + } } [Fact] [OuterLoop("Uses Task.Delay")] + [PlatformSpecific(TestPlatforms.Windows)] // Linux will not even attempt to connect to the invalid IP address public async Task ConnectAddressAndPort_CancelDuringConnect_Throws() { - using Socket listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - listen.Bind(new IPEndPoint(IPAddress.Loopback, 0)); - IPEndPoint remoteEndPoint = (IPEndPoint)listen.LocalEndPoint; - - using Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - - using CancellationTokenSource cts = new CancellationTokenSource(); + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + var cts = new CancellationTokenSource(); - ValueTask t = client.ConnectAsync(remoteEndPoint.Address, remoteEndPoint.Port, cts.Token); + ValueTask t = client.ConnectAsync(IPAddress.Parse("1.2.3.4"), 1, cts.Token); - cts.Cancel(); + // Delay cancellation a bit to try to ensure the OS actually attempts to connect + cts.CancelAfter(100); - OperationCanceledException e = await Assert.ThrowsAnyAsync(async () => await t); - Assert.Equal(cts.Token, e.CancellationToken); + await Assert.ThrowsAnyAsync(async () => await t); + } } [Fact] [OuterLoop("Uses Task.Delay")] + [PlatformSpecific(TestPlatforms.Windows)] // Linux will not even attempt to connect to the invalid IP address public async Task ConnectMultiAddressAndPort_CancelDuringConnect_Throws() { - using Socket listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - listen.Bind(new IPEndPoint(IPAddress.Loopback, 0)); - IPEndPoint remoteEndPoint = (IPEndPoint)listen.LocalEndPoint; - - using Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - - using CancellationTokenSource cts = new CancellationTokenSource(); + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + var cts = new CancellationTokenSource(); - ValueTask t = client.ConnectAsync(new IPAddress[] { remoteEndPoint.Address, remoteEndPoint.Address}, remoteEndPoint.Port, cts.Token); + ValueTask t = client.ConnectAsync(new IPAddress[] { IPAddress.Parse("1.2.3.4"), IPAddress.Parse("1.2.3.5") }, 1, cts.Token); - cts.Cancel(); + // Delay cancellation a bit to try to ensure the OS actually attempts to connect + cts.CancelAfter(100); - OperationCanceledException e = await Assert.ThrowsAnyAsync(async () => await t); - Assert.Equal(cts.Token, e.CancellationToken); + await Assert.ThrowsAnyAsync(async () => await t); + } } [Fact] [OuterLoop("Uses Task.Delay")] + [PlatformSpecific(TestPlatforms.Windows)] // Linux will not even attempt to connect to the invalid IP address public async Task ConnectHostNameAndPort_CancelDuringConnect_Throws() { - using Socket listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - listen.Bind(new IPEndPoint(IPAddress.Loopback, 0)); - IPEndPoint remoteEndPoint = (IPEndPoint)listen.LocalEndPoint; - - using Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - - using CancellationTokenSource cts = new CancellationTokenSource(); + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + var cts = new CancellationTokenSource(); - ValueTask t = client.ConnectAsync("127.0.0.1", remoteEndPoint.Port, cts.Token); + ValueTask t = client.ConnectAsync("1.2.3.4", 1, cts.Token); - cts.Cancel(); + // Delay cancellation a bit to try to ensure the OS actually attempts to connect + cts.CancelAfter(100); - OperationCanceledException e = await Assert.ThrowsAnyAsync(async () => await t); - Assert.Equal(cts.Token, e.CancellationToken); + await Assert.ThrowsAnyAsync(async () => await t); + } } } }