diff --git a/src/GuardClauses/GuardAgainstExpressionExtensions.cs b/src/GuardClauses/GuardAgainstExpressionExtensions.cs index 3e13ec79..b145e0b0 100644 --- a/src/GuardClauses/GuardAgainstExpressionExtensions.cs +++ b/src/GuardClauses/GuardAgainstExpressionExtensions.cs @@ -1,4 +1,5 @@ using System; +using System.Threading.Tasks; namespace Ardalis.GuardClauses { @@ -23,5 +24,25 @@ public static T AgainstExpression(this IGuardClause guardClause, Func + /// Throws an if evaluates to false for given + /// + /// + /// + /// + /// + /// + /// if the evaluates to true + /// + public static async Task AgainstExpressionAsync([JetBrainsNotNull] this IGuardClause guardClause, [JetBrainsNotNull] Func> func, T input, string message) where T : struct + { + if (!await func(input)) + { + throw new ArgumentException(message); + } + + return input; + } } } diff --git a/src/GuardClauses/GuardAgainstInvalidFormatExtensions.cs b/src/GuardClauses/GuardAgainstInvalidFormatExtensions.cs index 6da0a291..78c1dd26 100644 --- a/src/GuardClauses/GuardAgainstInvalidFormatExtensions.cs +++ b/src/GuardClauses/GuardAgainstInvalidFormatExtensions.cs @@ -1,5 +1,6 @@ using System; using System.Text.RegularExpressions; +using System.Threading.Tasks; namespace Ardalis.GuardClauses { @@ -50,5 +51,26 @@ public static T InvalidInput(this IGuardClause guardClause, T input, string p return input; } + + /// + /// Throws an if doesn't satisfy the function. + /// + /// + /// + /// + /// + /// Optional. Custom error message + /// + /// + /// + public static async Task InvalidInputAsync([JetBrainsNotNull] this IGuardClause guardClause, [JetBrainsNotNull] T input, [JetBrainsNotNull][JetBrainsInvokerParameterName] string parameterName, Func> predicate, string? message = null) + { + if (!await predicate(input)) + { + throw new ArgumentException(message ?? $"Input {parameterName} did not satisfy the options", parameterName); + } + + return input; + } } } diff --git a/test/GuardClauses.UnitTests/GuardAgainstOutOfRangeForInvalidInput.cs b/test/GuardClauses.UnitTests/GuardAgainstOutOfRangeForInvalidInput.cs index 7448eeeb..f8fc8b36 100644 --- a/test/GuardClauses.UnitTests/GuardAgainstOutOfRangeForInvalidInput.cs +++ b/test/GuardClauses.UnitTests/GuardAgainstOutOfRangeForInvalidInput.cs @@ -1,6 +1,7 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Threading.Tasks; using Ardalis.GuardClauses; using Microsoft.VisualBasic; using Xunit; @@ -16,43 +17,88 @@ public void DoesNothingGivenInRangeValue(T input, Func func) Guard.Against.InvalidInput(input, nameof(input), func); } - [Theory] - [ClassData(typeof(IncorrectClassData))] - public void ThrowsGivenOutOfRangeValue(T input, Func func) - { - Assert.Throws(() => Guard.Against.InvalidInput(input, nameof(input), func)); - } + [Theory] + [ClassData(typeof(CorrectAsyncClassData))] + public async Task DoesNothingGivenInRangeValueAsync(T input, Func> func) + { + await Guard.Against.InvalidInputAsync(input, nameof(input), func); + } - [Theory] - [ClassData(typeof(CorrectClassData))] - public void ReturnsExpectedValueGivenInRangeValue(T input, Func func) - { - var result = Guard.Against.InvalidInput(input, nameof(input), func); - Assert.Equal(input, result); - } + [Theory] + [ClassData(typeof(IncorrectClassData))] + public void ThrowsGivenOutOfRangeValue(T input, Func func) + { + Assert.Throws(() => Guard.Against.InvalidInput(input, nameof(input), func)); + } - [Theory] - [InlineData(null, "Input parameterName did not satisfy the options (Parameter 'parameterName')")] - [InlineData("Evaluation failed", "Evaluation failed (Parameter 'parameterName')")] - public void ErrorMessageMatchesExpected(string customMessage, string expectedMessage) - { - var exception = Assert.Throws(() => Guard.Against.InvalidInput(10, "parameterName", x => x > 20, customMessage)); - Assert.NotNull(exception); - Assert.NotNull(exception.Message); - Assert.Equal(expectedMessage, exception.Message); - } + [Theory] + [ClassData(typeof(IncorrectAsyncClassData))] + public async Task ThrowsGivenOutOfRangeValueAsync(T input, Func> func) + { + await Assert.ThrowsAsync(async () => await Guard.Against.InvalidInputAsync(input, nameof(input), func)); + } - [Theory] - [InlineData(null, null)] - [InlineData(null, "Please provide correct value")] - [InlineData("SomeParameter", null)] - [InlineData("SomeOtherParameter", "Value must be correct")] - public void ExceptionParamNameMatchesExpected(string expectedParamName, string customMessage) - { - var exception = Assert.Throws(() => Guard.Against.InvalidInput(10, expectedParamName, x => x > 20, customMessage)); - Assert.NotNull(exception); - Assert.Equal(expectedParamName, exception.ParamName); - } + [Theory] + [ClassData(typeof(CorrectClassData))] + public void ReturnsExpectedValueGivenInRangeValue(T input, Func func) + { + var result = Guard.Against.InvalidInput(input, nameof(input), func); + Assert.Equal(input, result); + } + + [Theory] + [ClassData(typeof(CorrectAsyncClassData))] + public async Task ReturnsExpectedValueGivenInRangeValueAsync(T input, Func> func) + { + var result = await Guard.Against.InvalidInputAsync(input, nameof(input), func); + Assert.Equal(input, result); + } + + [Theory] + [InlineData(null, "Input parameterName did not satisfy the options (Parameter 'parameterName')")] + [InlineData("Evaluation failed", "Evaluation failed (Parameter 'parameterName')")] + public void ErrorMessageMatchesExpected(string customMessage, string expectedMessage) + { + var exception = Assert.Throws(() => Guard.Against.InvalidInput(10, "parameterName", x => x > 20, customMessage)); + Assert.NotNull(exception); + Assert.NotNull(exception.Message); + Assert.Equal(expectedMessage, exception.Message); + } + + [Theory] + [InlineData(null, "Input parameterName did not satisfy the options (Parameter 'parameterName')")] + [InlineData("Evaluation failed", "Evaluation failed (Parameter 'parameterName')")] + public async Task ErrorMessageMatchesExpectedAsync(string customMessage, string expectedMessage) + { + var exception = await Assert.ThrowsAsync(async () => await Guard.Against.InvalidInputAsync(10, "parameterName", x => Task.FromResult(x > 20), customMessage)); + Assert.NotNull(exception); + Assert.NotNull(exception.Message); + Assert.Equal(expectedMessage, exception.Message); + } + + [Theory] + [InlineData(null, null)] + [InlineData(null, "Please provide correct value")] + [InlineData("SomeParameter", null)] + [InlineData("SomeOtherParameter", "Value must be correct")] + public void ExceptionParamNameMatchesExpected(string expectedParamName, string customMessage) + { + var exception = Assert.Throws(() => Guard.Against.InvalidInput(10, expectedParamName, x => x > 20, customMessage)); + Assert.NotNull(exception); + Assert.Equal(expectedParamName, exception.ParamName); + } + + [Theory] + [InlineData(null, null)] + [InlineData(null, "Please provide correct value")] + [InlineData("SomeParameter", null)] + [InlineData("SomeOtherParameter", "Value must be correct")] + public async Task ExceptionParamNameMatchesExpectedAsync(string expectedParamName, string customMessage) + { + var exception = await Assert.ThrowsAsync(async () => await Guard.Against.InvalidInputAsync(10, expectedParamName, x => Task.FromResult(x > 20), customMessage)); + Assert.NotNull(exception); + Assert.Equal(expectedParamName, exception.ParamName); + } // TODO: Test decimal types outside of ClassData // See: https://github.com/xunit/xunit/issues/2298 @@ -72,6 +118,21 @@ public IEnumerator GetEnumerator() IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); } + public class CorrectAsyncClassData : IEnumerable { + public IEnumerator GetEnumerator() + { + yield return new object[] { 20, (Func>)((x) => Task.FromResult(x > 10)) }; + yield return new object[] { DateAndTime.Now, (Func>)((x) => Task.FromResult(x > DateTime.MinValue)) }; + yield return new object[] { 20.0f, (Func>)((x) => Task.FromResult(x > 10.0f)) }; + //yield return new object[] { 20.0m, (Func>)((x) => Task.FromResult(x > 10.0m)) }; + yield return new object[] { 20.0, (Func>)((x) => Task.FromResult(x > 10.0)) }; + yield return new object[] { long.MaxValue, (Func>)((x) => Task.FromResult(x > 1)) }; + yield return new object[] { short.MaxValue, (Func>)((x) => Task.FromResult(x > 1)) }; + yield return new object[] { "abcd", (Func>)((x) => Task.FromResult(x == x.ToLower())) }; + } + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + public class IncorrectClassData : IEnumerable { public IEnumerator GetEnumerator() @@ -85,6 +146,22 @@ public IEnumerator GetEnumerator() yield return new object[] { short.MaxValue, (Func)((x) => x < 1) }; yield return new object[] { "abcd", (Func)((x) => x == x.ToUpper()) }; } - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + + public class IncorrectAsyncClassData : IEnumerable + { + public IEnumerator GetEnumerator() + { + yield return new object[] { 20, (Func>)((x) => Task.FromResult(x < 10)) }; + yield return new object[] { DateAndTime.Now, (Func>)((x) => Task.FromResult(x > DateTime.MaxValue)) }; + yield return new object[] { 20.0f, (Func>)((x) => Task.FromResult(x > 30.0f)) }; + //yield return new object[] { 20.0m, (Func)((x) => x > 30.0m)) }; + yield return new object[] { 20.0, (Func>)((x) => Task.FromResult(x > 30.0)) }; + yield return new object[] { long.MaxValue, (Func>)((x) => Task.FromResult(x < 1)) }; + yield return new object[] { short.MaxValue, (Func>)((x) => Task.FromResult(x < 1)) }; + yield return new object[] { "abcd", (Func>)((x) => Task.FromResult(x == x.ToUpper())) }; + } + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } } }