From 9e7e0ed9847ac1b779c26601942342b8a54e4c2a Mon Sep 17 00:00:00 2001 From: Andrew Arnott Date: Tue, 25 Mar 2025 15:54:24 -0700 Subject: [PATCH] Fix `CancellationToken.Combine` with 3+ cancelable tokens --- .../CancellationTokenExtensions.cs | 5 ++++ .../CancellationTokenExtensionsTests.cs | 25 +++++++++++-------- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/src/Microsoft.VisualStudio.Threading/CancellationTokenExtensions.cs b/src/Microsoft.VisualStudio.Threading/CancellationTokenExtensions.cs index e63d09437..af6649e29 100644 --- a/src/Microsoft.VisualStudio.Threading/CancellationTokenExtensions.cs +++ b/src/Microsoft.VisualStudio.Threading/CancellationTokenExtensions.cs @@ -115,6 +115,11 @@ public static CombinedCancellationToken CombineWith(this CancellationToken origi // Before this point we've checked every condition that would allow us to avoid it. var cancelableTokens = new CancellationToken[cancelableTokensCount]; int i = 0; + if (original.CanBeCanceled) + { + cancelableTokens[i++] = original; + } + foreach (CancellationToken other in others) { if (other.CanBeCanceled) diff --git a/test/Microsoft.VisualStudio.Threading.Tests/CancellationTokenExtensionsTests.cs b/test/Microsoft.VisualStudio.Threading.Tests/CancellationTokenExtensionsTests.cs index 9fb91ba53..5a957f05b 100644 --- a/test/Microsoft.VisualStudio.Threading.Tests/CancellationTokenExtensionsTests.cs +++ b/test/Microsoft.VisualStudio.Threading.Tests/CancellationTokenExtensionsTests.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; +using System.Linq; using System.Threading; using Microsoft.VisualStudio.Threading; using Xunit; @@ -226,19 +227,23 @@ public void CombineWith_Array_TwoCancelable_AmidMany(bool cancelFirst) } } - [Fact] - public void CombineWith_Array_ThreeCancelable_AmidMany() + [Theory, CombinatorialData] + public void CombineWith_Array_ThreeCancelable_AmidMany([CombinatorialRange(0, 3)] int canceledIndex) { - var cts1 = new CancellationTokenSource(); - var cts2 = new CancellationTokenSource(); - var cts3 = new CancellationTokenSource(); - using (CancellationTokenExtensions.CombinedCancellationToken combined = CancellationToken.None.CombineWith(cts1.Token, CancellationToken.None, cts2.Token, CancellationToken.None, cts3.Token)) + CancellationTokenSource[] cts = new CancellationTokenSource[3]; + for (int i = 0; i < 3; i++) { - Assert.NotEqual(cts1.Token, combined.Token); - Assert.NotEqual(cts2.Token, combined.Token); - Assert.NotEqual(cts3.Token, combined.Token); + cts[i] = new(); + } - cts2.Cancel(); + using (CancellationTokenExtensions.CombinedCancellationToken combined = cts[0].Token.CombineWith(cts.Skip(1).Select(s => s.Token).ToArray())) + { + for (int i = 0; i < cts.Length; i++) + { + Assert.NotEqual(cts[i].Token, combined.Token); + } + + cts[canceledIndex].Cancel(); Assert.True(combined.Token.IsCancellationRequested); } }