diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Reflection/ConstructorInvoker.cs b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Reflection/ConstructorInvoker.cs index 13eadbffb4580e..4ae10a2fe866e8 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Reflection/ConstructorInvoker.cs +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Reflection/ConstructorInvoker.cs @@ -3,6 +3,7 @@ using Internal.Reflection.Core.Execution; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.Reflection.Runtime.MethodInfos; using static System.Reflection.DynamicInvokeInfo; @@ -11,11 +12,13 @@ namespace System.Reflection public sealed class ConstructorInvoker { private readonly MethodBaseInvoker _methodBaseInvoker; + private readonly int _parameterCount; private readonly RuntimeTypeHandle _declaringTypeHandle; internal ConstructorInvoker(RuntimeConstructorInfo constructor) { _methodBaseInvoker = constructor.MethodInvoker; + _parameterCount = constructor.GetParametersNoCopy().Length; _declaringTypeHandle = constructor.DeclaringType.TypeHandle; } @@ -32,6 +35,11 @@ public static ConstructorInvoker Create(ConstructorInfo constructor) [DebuggerGuidedStepThrough] public object Invoke() { + if (_parameterCount != 0) + { + ThrowForArgCountMismatch(); + } + object result = _methodBaseInvoker.CreateInstanceWithFewArgs(new Span()); DebugAnnotations.PreviousCallContainsDebuggerStepInCode(); return result; @@ -40,7 +48,12 @@ public object Invoke() [DebuggerGuidedStepThrough] public object Invoke(object? arg1) { - object result = _methodBaseInvoker.CreateInstanceWithFewArgs(new Span(ref arg1)); + if (_parameterCount != 1) + { + ThrowForArgCountMismatch(); + } + + object result = _methodBaseInvoker.CreateInstanceWithFewArgs(new Span(ref arg1, _parameterCount)); DebugAnnotations.PreviousCallContainsDebuggerStepInCode(); return result; } @@ -48,10 +61,15 @@ public object Invoke(object? arg1) [DebuggerGuidedStepThrough] public object Invoke(object? arg1, object? arg2) { + if (_parameterCount != 2) + { + ThrowForArgCountMismatch(); + } + StackAllocatedArguments argStorage = default; argStorage._args.Set(0, arg1); argStorage._args.Set(1, arg2); - object result = _methodBaseInvoker.CreateInstanceWithFewArgs(argStorage._args.AsSpan(2)); + object result = _methodBaseInvoker.CreateInstanceWithFewArgs(argStorage._args.AsSpan(_parameterCount)); DebugAnnotations.PreviousCallContainsDebuggerStepInCode(); return result; } @@ -59,11 +77,16 @@ public object Invoke(object? arg1, object? arg2) [DebuggerGuidedStepThrough] public object Invoke(object? arg1, object? arg2, object? arg3) { + if (_parameterCount != 3) + { + ThrowForArgCountMismatch(); + } + StackAllocatedArguments argStorage = default; argStorage._args.Set(0, arg1); argStorage._args.Set(1, arg2); argStorage._args.Set(2, arg3); - object result = _methodBaseInvoker.CreateInstanceWithFewArgs(argStorage._args.AsSpan(3)); + object result = _methodBaseInvoker.CreateInstanceWithFewArgs(argStorage._args.AsSpan(_parameterCount)); DebugAnnotations.PreviousCallContainsDebuggerStepInCode(); return result; } @@ -71,12 +94,17 @@ public object Invoke(object? arg1, object? arg2, object? arg3) [DebuggerGuidedStepThrough] public object Invoke(object? arg1, object? arg2, object? arg3, object? arg4) { + if (_parameterCount != 4) + { + ThrowForArgCountMismatch(); + } + StackAllocatedArguments argStorage = default; argStorage._args.Set(0, arg1); argStorage._args.Set(1, arg2); argStorage._args.Set(2, arg3); argStorage._args.Set(3, arg4); - object result = _methodBaseInvoker.CreateInstanceWithFewArgs(argStorage._args.AsSpan(4)); + object result = _methodBaseInvoker.CreateInstanceWithFewArgs(argStorage._args.AsSpan(_parameterCount)); DebugAnnotations.PreviousCallContainsDebuggerStepInCode(); return result; } @@ -88,5 +116,11 @@ public object Invoke(Span arguments) DebugAnnotations.PreviousCallContainsDebuggerStepInCode(); return result; } + + [DoesNotReturn] + private static void ThrowForArgCountMismatch() + { + throw new TargetParameterCountException(SR.Arg_ParmCnt); + } } } diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Reflection/MethodInvoker.cs b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Reflection/MethodInvoker.cs index 82951425e37093..c200a702c7a430 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Reflection/MethodInvoker.cs +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Reflection/MethodInvoker.cs @@ -3,6 +3,7 @@ using Internal.Reflection.Core.Execution; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.Reflection.Runtime.MethodInfos; using static System.Reflection.DynamicInvokeInfo; @@ -11,10 +12,12 @@ namespace System.Reflection public sealed class MethodInvoker { private readonly MethodBaseInvoker _methodBaseInvoker; + private readonly int _parameterCount; internal MethodInvoker(RuntimeMethodInfo method) { _methodBaseInvoker = method.MethodInvoker; + _parameterCount = method.GetParametersNoCopy().Length; } internal MethodInvoker(RuntimeConstructorInfo constructor) @@ -44,6 +47,11 @@ public static MethodInvoker Create(MethodBase method) [DebuggerGuidedStepThrough] public object? Invoke(object? obj) { + if (_parameterCount != 0) + { + ThrowForArgCountMismatch(); + } + object? result = _methodBaseInvoker.InvokeDirectWithFewArgs(obj, new Span()); DebugAnnotations.PreviousCallContainsDebuggerStepInCode(); return result; @@ -52,7 +60,12 @@ public static MethodInvoker Create(MethodBase method) [DebuggerGuidedStepThrough] public object? Invoke(object? obj, object? arg1) { - object? result = _methodBaseInvoker.InvokeDirectWithFewArgs(obj, new Span(ref arg1)); + if (_parameterCount != 1) + { + ThrowForArgCountMismatch(); + } + + object? result = _methodBaseInvoker.InvokeDirectWithFewArgs(obj, new Span(ref arg1, _parameterCount)); DebugAnnotations.PreviousCallContainsDebuggerStepInCode(); return result; } @@ -60,11 +73,16 @@ public static MethodInvoker Create(MethodBase method) [DebuggerGuidedStepThrough] public object? Invoke(object? obj, object? arg1, object? arg2) { + if (_parameterCount != 2) + { + ThrowForArgCountMismatch(); + } + StackAllocatedArguments argStorage = default; argStorage._args.Set(0, arg1); argStorage._args.Set(1, arg2); - object? result = _methodBaseInvoker.InvokeDirectWithFewArgs(obj, argStorage._args.AsSpan(2)); + object? result = _methodBaseInvoker.InvokeDirectWithFewArgs(obj, argStorage._args.AsSpan(_parameterCount)); DebugAnnotations.PreviousCallContainsDebuggerStepInCode(); return result; } @@ -72,12 +90,17 @@ public static MethodInvoker Create(MethodBase method) [DebuggerGuidedStepThrough] public object? Invoke(object? obj, object? arg1, object? arg2, object? arg3) { + if (_parameterCount != 3) + { + ThrowForArgCountMismatch(); + } + StackAllocatedArguments argStorage = default; argStorage._args.Set(0, arg1); argStorage._args.Set(1, arg2); argStorage._args.Set(2, arg3); - object? result = _methodBaseInvoker.InvokeDirectWithFewArgs(obj, argStorage._args.AsSpan(3)); + object? result = _methodBaseInvoker.InvokeDirectWithFewArgs(obj, argStorage._args.AsSpan(_parameterCount)); DebugAnnotations.PreviousCallContainsDebuggerStepInCode(); return result; } @@ -85,13 +108,18 @@ public static MethodInvoker Create(MethodBase method) [DebuggerGuidedStepThrough] public object? Invoke(object? obj, object? arg1, object? arg2, object? arg3, object? arg4) { + if (_parameterCount != 4) + { + ThrowForArgCountMismatch(); + } + StackAllocatedArguments argStorage = default; argStorage._args.Set(0, arg1); argStorage._args.Set(1, arg2); argStorage._args.Set(2, arg3); argStorage._args.Set(3, arg4); - object? result = _methodBaseInvoker.InvokeDirectWithFewArgs(obj, argStorage._args.AsSpan(4)); + object? result = _methodBaseInvoker.InvokeDirectWithFewArgs(obj, argStorage._args.AsSpan(_parameterCount)); DebugAnnotations.PreviousCallContainsDebuggerStepInCode(); return result; } @@ -103,5 +131,11 @@ public static MethodInvoker Create(MethodBase method) DebugAnnotations.PreviousCallContainsDebuggerStepInCode(); return result; } + + [DoesNotReturn] + private static void ThrowForArgCountMismatch() + { + throw new TargetParameterCountException(SR.Arg_ParmCnt); + } } } diff --git a/src/libraries/Common/src/Extensions/ParameterDefaultValue/ParameterDefaultValue.netstandard.cs b/src/libraries/Common/src/Extensions/ParameterDefaultValue/ParameterDefaultValue.netstandard.cs index 089c64afe03b00..78486aa1e37ef8 100644 --- a/src/libraries/Common/src/Extensions/ParameterDefaultValue/ParameterDefaultValue.netstandard.cs +++ b/src/libraries/Common/src/Extensions/ParameterDefaultValue/ParameterDefaultValue.netstandard.cs @@ -2,9 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Diagnostics.CodeAnalysis; using System.Reflection; -using System.Runtime.Serialization; namespace Microsoft.Extensions.Internal { diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/ActivatorUtilities.cs b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/ActivatorUtilities.cs index 43e8f5627a2d28..c4258ca7877428 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/ActivatorUtilities.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/ActivatorUtilities.cs @@ -17,6 +17,11 @@ namespace Microsoft.Extensions.DependencyInjection /// public static class ActivatorUtilities { +#if NET8_0_OR_GREATER + // Maximum number of fixed arguments for ConstructorInvoker.Invoke(arg1, etc). + private const int FixedArgumentThreshold = 4; +#endif + private static readonly MethodInfo GetServiceInfo = GetMethodInfo>((sp, t, r, c) => GetService(sp, t, r, c)); @@ -140,7 +145,6 @@ public static ObjectFactory CreateFactory( return CreateFactoryReflection(instanceType, argumentTypes); } #endif - CreateFactoryInternal(instanceType, argumentTypes, out ParameterExpression provider, out ParameterExpression argumentArray, out Expression factoryExpressionBody); var factoryLambda = Expression.Lambda>( @@ -174,7 +178,6 @@ public static ObjectFactory return (serviceProvider, arguments) => (T)factory(serviceProvider, arguments); } #endif - CreateFactoryInternal(typeof(T), argumentTypes, out ParameterExpression provider, out ParameterExpression argumentArray, out Expression factoryExpressionBody); var factoryLambda = Expression.Lambda>( @@ -235,16 +238,22 @@ private static MethodInfo GetMethodInfo(Expression expr) return mc.Method; } - private static object? GetService(IServiceProvider sp, Type type, Type requiredBy, bool isDefaultParameterRequired) + private static object? GetService(IServiceProvider sp, Type type, Type requiredBy, bool hasDefaultValue) { object? service = sp.GetService(type); - if (service == null && !isDefaultParameterRequired) + if (service is null && !hasDefaultValue) { - throw new InvalidOperationException(SR.Format(SR.UnableToResolveService, type, requiredBy)); + ThrowHelperUnableToResolveService(type, requiredBy); } return service; } + [DoesNotReturn] + private static void ThrowHelperUnableToResolveService(Type type, Type requiredBy) + { + throw new InvalidOperationException(SR.Format(SR.UnableToResolveService, type, requiredBy)); + } + private static BlockExpression BuildFactoryExpression( ConstructorInfo constructor, int?[] parameterMap, @@ -289,53 +298,114 @@ private static BlockExpression BuildFactoryExpression( } #if NETSTANDARD2_1_OR_GREATER || NETCOREAPP + [DoesNotReturn] + private static void ThrowHelperArgumentNullExceptionServiceProvider() + { + throw new ArgumentNullException("serviceProvider"); + } + private static ObjectFactory CreateFactoryReflection( [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type instanceType, Type?[] argumentTypes) { FindApplicableConstructor(instanceType, argumentTypes, out ConstructorInfo constructor, out int?[] parameterMap); + Type declaringType = constructor.DeclaringType!; + +#if NET8_0_OR_GREATER + ConstructorInvoker invoker = ConstructorInvoker.Create(constructor); ParameterInfo[] constructorParameters = constructor.GetParameters(); if (constructorParameters.Length == 0) { return (IServiceProvider serviceProvider, object?[]? arguments) => - constructor.Invoke(BindingFlags.DoNotWrapExceptions, binder: null, parameters: null, culture: null); + invoker.Invoke(); } - FactoryParameterContext[] parameters = new FactoryParameterContext[constructorParameters.Length]; + // Gather some metrics to determine what fast path to take, if any. + bool useFixedValues = constructorParameters.Length <= FixedArgumentThreshold; + bool hasAnyDefaultValues = false; + int matchedArgCount = 0; + int matchedArgCountWithMap = 0; for (int i = 0; i < constructorParameters.Length; i++) { - ParameterInfo constructorParameter = constructorParameters[i]; - bool hasDefaultValue = ParameterDefaultValue.TryGetDefaultValue(constructorParameter, out object? defaultValue); + hasAnyDefaultValues |= constructorParameters[i].HasDefaultValue; - parameters[i] = new FactoryParameterContext(constructorParameter.ParameterType, hasDefaultValue, defaultValue, parameterMap[i] ?? -1); + if (parameterMap[i] is not null) + { + matchedArgCount++; + if (parameterMap[i] == i) + { + matchedArgCountWithMap++; + } + } } - Type declaringType = constructor.DeclaringType!; - return (IServiceProvider serviceProvider, object?[]? arguments) => + // No fast path; contains default values or arg mapping. + if (hasAnyDefaultValues || matchedArgCount != matchedArgCountWithMap) { - if (serviceProvider is null) + return InvokeCanonical(); + } + + if (matchedArgCount == 0) + { + // All injected; use a fast path. + Type[] types = GetParameterTypes(); + return useFixedValues ? + (serviceProvider, arguments) => ReflectionFactoryServiceOnlyFixed(invoker, types, declaringType, serviceProvider) : + (serviceProvider, arguments) => ReflectionFactoryServiceOnlySpan(invoker, types, declaringType, serviceProvider); + } + + if (matchedArgCount == constructorParameters.Length) + { + // All direct with no mappings; use a fast path. + return (serviceProvider, arguments) => ReflectionFactoryDirect(invoker, serviceProvider, arguments); + } + + return InvokeCanonical(); + + ObjectFactory InvokeCanonical() + { + FactoryParameterContext[] parameters = GetFactoryParameterContext(); + return useFixedValues ? + (serviceProvider, arguments) => ReflectionFactoryCanonicalFixed(invoker, parameters, declaringType, serviceProvider, arguments) : + (serviceProvider, arguments) => ReflectionFactoryCanonicalSpan(invoker, parameters, declaringType, serviceProvider, arguments); + } + + Type[] GetParameterTypes() + { + Type[] types = new Type[constructorParameters.Length]; + for (int i = 0; i < constructorParameters.Length; i++) { - throw new ArgumentNullException(nameof(serviceProvider)); + types[i] = constructorParameters[i].ParameterType; } + return types; + } +#else + ParameterInfo[] constructorParameters = constructor.GetParameters(); + if (constructorParameters.Length == 0) + { + return (IServiceProvider serviceProvider, object?[]? arguments) => + constructor.Invoke(BindingFlags.DoNotWrapExceptions, binder: null, parameters: null, culture: null); + } + + FactoryParameterContext[] parameters = GetFactoryParameterContext(); + return (serviceProvider, arguments) => ReflectionFactoryCanonical(constructor, parameters, declaringType, serviceProvider, arguments); +#endif // NET8_0_OR_GREATER - object?[] constructorArguments = new object?[parameters.Length]; - for (int i = 0; i < parameters.Length; i++) + FactoryParameterContext[] GetFactoryParameterContext() + { + FactoryParameterContext[] parameters = new FactoryParameterContext[constructorParameters.Length]; + for (int i = 0; i < constructorParameters.Length; i++) { - ref FactoryParameterContext parameter = ref parameters[i]; - constructorArguments[i] = ((parameter.ArgumentIndex != -1) - // Throws an NullReferenceException if arguments is null. Consistent with expression-based factory. - ? arguments![parameter.ArgumentIndex] - : GetService( - serviceProvider, - parameter.ParameterType, - declaringType, - parameter.HasDefaultValue)) ?? parameter.DefaultValue; + ParameterInfo constructorParameter = constructorParameters[i]; + bool hasDefaultValue = ParameterDefaultValue.TryGetDefaultValue(constructorParameter, out object? defaultValue); + parameters[i] = new FactoryParameterContext(constructorParameter.ParameterType, hasDefaultValue, defaultValue, parameterMap[i] ?? -1); } - return constructor.Invoke(BindingFlags.DoNotWrapExceptions, binder: null, constructorArguments, culture: null); - }; + return parameters; + } } +#endif // NETSTANDARD2_1_OR_GREATER || NETCOREAPP private readonly struct FactoryParameterContext { @@ -352,7 +422,6 @@ public FactoryParameterContext(Type parameterType, bool hasDefaultValue, object? public object? DefaultValue { get; } public int ArgumentIndex { get; } } -#endif private static void FindApplicableConstructor( [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type instanceType, @@ -360,11 +429,11 @@ private static void FindApplicableConstructor( out ConstructorInfo matchingConstructor, out int?[] matchingParameterMap) { - ConstructorInfo? constructorInfo = null; - int?[]? parameterMap = null; + ConstructorInfo? constructorInfo; + int?[]? parameterMap; - if (!TryFindPreferredConstructor(instanceType, argumentTypes, ref constructorInfo, ref parameterMap) && - !TryFindMatchingConstructor(instanceType, argumentTypes, ref constructorInfo, ref parameterMap)) + if (!TryFindPreferredConstructor(instanceType, argumentTypes, out constructorInfo, out parameterMap) && + !TryFindMatchingConstructor(instanceType, argumentTypes, out constructorInfo, out parameterMap)) { throw new InvalidOperationException(SR.Format(SR.CtorNotLocated, instanceType)); } @@ -377,9 +446,12 @@ private static void FindApplicableConstructor( private static bool TryFindMatchingConstructor( [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type instanceType, Type?[] argumentTypes, - [NotNullWhen(true)] ref ConstructorInfo? matchingConstructor, - [NotNullWhen(true)] ref int?[]? parameterMap) + [NotNullWhen(true)] out ConstructorInfo? matchingConstructor, + [NotNullWhen(true)] out int?[]? parameterMap) { + matchingConstructor = null; + parameterMap = null; + foreach (ConstructorInfo? constructor in instanceType.GetConstructors()) { if (TryCreateParameterMap(constructor.GetParameters(), argumentTypes, out int?[] tempParameterMap)) @@ -407,10 +479,13 @@ private static bool TryFindMatchingConstructor( private static bool TryFindPreferredConstructor( [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type instanceType, Type?[] argumentTypes, - [NotNullWhen(true)] ref ConstructorInfo? matchingConstructor, - [NotNullWhen(true)] ref int?[]? parameterMap) + [NotNullWhen(true)] out ConstructorInfo? matchingConstructor, + [NotNullWhen(true)] out int?[]? parameterMap) { bool seenPreferred = false; + matchingConstructor = null; + parameterMap = null; + foreach (ConstructorInfo? constructor in instanceType.GetConstructors()) { if (constructor.IsDefined(typeof(ActivatorUtilitiesConstructorAttribute), false)) @@ -642,5 +717,268 @@ private static void ThrowMarkedCtorDoesNotTakeAllProvidedArguments() { throw new InvalidOperationException(SR.Format(SR.MarkedCtorMissingArgumentTypes, nameof(ActivatorUtilitiesConstructorAttribute))); } + +#if NET8_0_OR_GREATER // Use the faster ConstructorInvoker which also has alloc-free APIs when <= 4 parameters. + private static object ReflectionFactoryServiceOnlyFixed( + ConstructorInvoker invoker, + Type[] parameterTypes, + Type declaringType, + IServiceProvider serviceProvider) + { + Debug.Assert(parameterTypes.Length >= 1 && parameterTypes.Length <= FixedArgumentThreshold); + Debug.Assert(FixedArgumentThreshold == 4); + + if (serviceProvider is null) + ThrowHelperArgumentNullExceptionServiceProvider(); + + switch (parameterTypes.Length) + { + case 1: + return invoker.Invoke( + GetService(serviceProvider, parameterTypes[0], declaringType, false)); + + case 2: + return invoker.Invoke( + GetService(serviceProvider, parameterTypes[0], declaringType, false), + GetService(serviceProvider, parameterTypes[1], declaringType, false)); + + case 3: + return invoker.Invoke( + GetService(serviceProvider, parameterTypes[0], declaringType, false), + GetService(serviceProvider, parameterTypes[1], declaringType, false), + GetService(serviceProvider, parameterTypes[2], declaringType, false)); + + case 4: + return invoker.Invoke( + GetService(serviceProvider, parameterTypes[0], declaringType, false), + GetService(serviceProvider, parameterTypes[1], declaringType, false), + GetService(serviceProvider, parameterTypes[2], declaringType, false), + GetService(serviceProvider, parameterTypes[3], declaringType, false)); + } + + return null!; + } + + private static object ReflectionFactoryServiceOnlySpan( + ConstructorInvoker invoker, + Type[] parameterTypes, + Type declaringType, + IServiceProvider serviceProvider) + { + if (serviceProvider is null) + ThrowHelperArgumentNullExceptionServiceProvider(); + + object?[] arguments = new object?[parameterTypes.Length]; + for (int i = 0; i < parameterTypes.Length; i++) + { + arguments[i] = GetService(serviceProvider, parameterTypes[i], declaringType, false); + } + + return invoker.Invoke(arguments.AsSpan()); + } + + private static object ReflectionFactoryCanonicalFixed( + ConstructorInvoker invoker, + FactoryParameterContext[] parameters, + Type declaringType, + IServiceProvider serviceProvider, + object?[]? arguments) + { + Debug.Assert(parameters.Length >= 1 && parameters.Length <= FixedArgumentThreshold); + Debug.Assert(FixedArgumentThreshold == 4); + + if (serviceProvider is null) + ThrowHelperArgumentNullExceptionServiceProvider(); + + ref FactoryParameterContext parameter1 = ref parameters[0]; + + switch (parameters.Length) + { + case 1: + return invoker.Invoke( + ((parameter1.ArgumentIndex != -1) + // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. + ? arguments![parameter1.ArgumentIndex] + : GetService( + serviceProvider, + parameter1.ParameterType, + declaringType, + parameter1.HasDefaultValue)) ?? parameter1.DefaultValue); + case 2: + { + ref FactoryParameterContext parameter2 = ref parameters[1]; + + return invoker.Invoke( + ((parameter1.ArgumentIndex != -1) + // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. + ? arguments![parameter1.ArgumentIndex] + : GetService( + serviceProvider, + parameter1.ParameterType, + declaringType, + parameter1.HasDefaultValue)) ?? parameter1.DefaultValue, + ((parameter2.ArgumentIndex != -1) + // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. + ? arguments![parameter2.ArgumentIndex] + : GetService( + serviceProvider, + parameter2.ParameterType, + declaringType, + parameter2.HasDefaultValue)) ?? parameter2.DefaultValue); + } + case 3: + { + ref FactoryParameterContext parameter2 = ref parameters[1]; + ref FactoryParameterContext parameter3 = ref parameters[2]; + + return invoker.Invoke( + ((parameter1.ArgumentIndex != -1) + // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. + ? arguments![parameter1.ArgumentIndex] + : GetService( + serviceProvider, + parameter1.ParameterType, + declaringType, + parameter1.HasDefaultValue)) ?? parameter1.DefaultValue, + ((parameter2.ArgumentIndex != -1) + // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. + ? arguments![parameter2.ArgumentIndex] + : GetService( + serviceProvider, + parameter2.ParameterType, + declaringType, + parameter2.HasDefaultValue)) ?? parameter2.DefaultValue, + ((parameter3.ArgumentIndex != -1) + // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. + ? arguments![parameter3.ArgumentIndex] + : GetService( + serviceProvider, + parameter3.ParameterType, + declaringType, + parameter3.HasDefaultValue)) ?? parameter3.DefaultValue); + } + case 4: + { + ref FactoryParameterContext parameter2 = ref parameters[1]; + ref FactoryParameterContext parameter3 = ref parameters[2]; + ref FactoryParameterContext parameter4 = ref parameters[3]; + + return invoker.Invoke( + ((parameter1.ArgumentIndex != -1) + // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. + ? arguments![parameter1.ArgumentIndex] + : GetService( + serviceProvider, + parameter1.ParameterType, + declaringType, + parameter1.HasDefaultValue)) ?? parameter1.DefaultValue, + ((parameter2.ArgumentIndex != -1) + // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. + ? arguments![parameter2.ArgumentIndex] + : GetService( + serviceProvider, + parameter2.ParameterType, + declaringType, + parameter2.HasDefaultValue)) ?? parameter2.DefaultValue, + ((parameter3.ArgumentIndex != -1) + // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. + ? arguments![parameter3.ArgumentIndex] + : GetService( + serviceProvider, + parameter3.ParameterType, + declaringType, + parameter3.HasDefaultValue)) ?? parameter3.DefaultValue, + ((parameter4.ArgumentIndex != -1) + // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. + ? arguments![parameter4.ArgumentIndex] + : GetService( + serviceProvider, + parameter4.ParameterType, + declaringType, + parameter4.HasDefaultValue)) ?? parameter4.DefaultValue); + } + + } + + return null!; + } + + private static object ReflectionFactoryCanonicalSpan( + ConstructorInvoker invoker, + FactoryParameterContext[] parameters, + Type declaringType, + IServiceProvider serviceProvider, + object?[]? arguments) + { + if (serviceProvider is null) + ThrowHelperArgumentNullExceptionServiceProvider(); + + object?[] constructorArguments = new object?[parameters.Length]; + for (int i = 0; i < parameters.Length; i++) + { + ref FactoryParameterContext parameter = ref parameters[i]; + constructorArguments[i] = ((parameter.ArgumentIndex != -1) + // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. + ? arguments![parameter.ArgumentIndex] + : GetService( + serviceProvider, + parameter.ParameterType, + declaringType, + parameter.HasDefaultValue)) ?? parameter.DefaultValue; + } + + return invoker.Invoke(constructorArguments.AsSpan()); + } + + private static object ReflectionFactoryDirect( + ConstructorInvoker invoker, + IServiceProvider serviceProvider, + object?[]? arguments) + { + if (serviceProvider is null) + ThrowHelperArgumentNullExceptionServiceProvider(); + + if (arguments is null) + ThrowHelperNullReferenceException(); //AsSpan() will not throw NullReferenceException. + + return invoker.Invoke(arguments.AsSpan()); + } + + /// + /// For consistency with the expression-based factory, throw NullReferenceException. + /// + [DoesNotReturn] + private static void ThrowHelperNullReferenceException() + { + throw new NullReferenceException(); + } +#elif NETSTANDARD2_1_OR_GREATER || NETCOREAPP + private static object ReflectionFactoryCanonical( + ConstructorInfo constructor, + FactoryParameterContext[] parameters, + Type declaringType, + IServiceProvider serviceProvider, + object?[]? arguments) + { + if (serviceProvider is null) + ThrowHelperArgumentNullExceptionServiceProvider(); + + object?[] constructorArguments = new object?[parameters.Length]; + for (int i = 0; i < parameters.Length; i++) + { + ref FactoryParameterContext parameter = ref parameters[i]; + constructorArguments[i] = ((parameter.ArgumentIndex != -1) + // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. + ? arguments![parameter.ArgumentIndex] + : GetService( + serviceProvider, + parameter.ParameterType, + declaringType, + parameter.HasDefaultValue)) ?? parameter.DefaultValue; + } + + return constructor.Invoke(BindingFlags.DoNotWrapExceptions, binder: null, constructorArguments, culture: null); + } +#endif // NET8_0_OR_GREATER } } diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/ActivatorUtilitiesTests.cs b/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/ActivatorUtilitiesTests.cs index 860768f0e4612c..dcc847052b2547 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/ActivatorUtilitiesTests.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/ActivatorUtilitiesTests.cs @@ -195,7 +195,7 @@ public void TypeActivatorRethrowsOriginalExceptionFromConstructor(CreateInstance CreateInstance(createFunc, provider: serviceProvider)); var ex2 = Assert.Throws(() => - CreateInstance(createFunc, provider: serviceProvider, args: new[] { new FakeService() })); + CreateInstance(createFunc, provider: serviceProvider, args: new object[] { new FakeService() })); // Assert Assert.Equal(nameof(ClassWithThrowingEmptyCtor), ex1.Message); diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ActivatorUtilitiesTests.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ActivatorUtilitiesTests.cs index 4c065b61bb2856..7572e6977a4c49 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ActivatorUtilitiesTests.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ActivatorUtilitiesTests.cs @@ -94,7 +94,8 @@ public void TypeActivatorThrowsOnNullProvider() public void FactoryActivatorThrowsOnNullProvider() { var f = ActivatorUtilities.CreateFactory(typeof(ClassWithA), new Type[0]); - Assert.Throws(() => f(serviceProvider: null, null)); + Exception ex = Assert.Throws(() => f(serviceProvider: null, null)); + Assert.Contains("serviceProvider", ex.ToString()); } [Fact] @@ -179,7 +180,7 @@ public void CreateInstance_ClassWithABC_MultipleCtorsWithSameLength_ThrowsAmbigu } [Fact] - public void CreateFactory_CreatesFactoryMethod() + public void CreateFactory_CreatesFactoryMethod_4Types_3Injected() { var factory1 = ActivatorUtilities.CreateFactory(typeof(ClassWithABCS), new Type[] { typeof(B) }); var factory2 = ActivatorUtilities.CreateFactory(new Type[] { typeof(B) }); @@ -194,9 +195,42 @@ public void CreateFactory_CreatesFactoryMethod() Assert.IsType(factory1); Assert.IsType(item1); + ClassWithABCS obj = (ClassWithABCS)item1; + Assert.NotNull(obj.A); + Assert.NotNull(obj.B); + Assert.NotNull(obj.C); + Assert.NotNull(obj.S); Assert.IsType>(factory2); Assert.IsType(item2); + + Assert.NotNull(item2.A); + Assert.NotNull(item2.B); + Assert.NotNull(item2.C); + Assert.NotNull(item2.S); + } + + [Fact] + public void CreateFactory_CreatesFactoryMethod_5Types_5Injected() + { + // Inject 5 types which is a threshold for whether fixed or Span<> invoker args are used by reflection. + var factory = ActivatorUtilities.CreateFactory(Type.EmptyTypes); + + var services = new ServiceCollection(); + services.AddSingleton(new A()); + services.AddSingleton(new B()); + services.AddSingleton(new C()); + services.AddSingleton(new S()); + services.AddSingleton(new Z()); + using var provider = services.BuildServiceProvider(); + ClassWithABCSZ item = factory(provider, null); + + Assert.IsType>(factory); + Assert.NotNull(item.A); + Assert.NotNull(item.B); + Assert.NotNull(item.C); + Assert.NotNull(item.S); + Assert.NotNull(item.Z); } [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] @@ -365,6 +399,7 @@ internal class A { } internal class B { } internal class C { } internal class S { } + internal class Z { } internal class ClassWithABCS : ClassWithABC { @@ -373,6 +408,12 @@ internal class ClassWithABCS : ClassWithABC public ClassWithABCS(A a, C c, S s) : this(a, null, c, s) { } } + internal class ClassWithABCSZ : ClassWithABCS + { + public Z Z { get; } + public ClassWithABCSZ(A a, B b, C c, S s, Z z) : base(a, b, c, s) { Z = z; } + } + internal class ClassWithABC_FirstConstructorWithAttribute : ClassWithABC { [ActivatorUtilitiesConstructor] diff --git a/src/libraries/System.Private.CoreLib/src/System/Reflection/ConstructorInvoker.cs b/src/libraries/System.Private.CoreLib/src/System/Reflection/ConstructorInvoker.cs index 7ddea4119ed1e4..67fc0d31919759 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Reflection/ConstructorInvoker.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Reflection/ConstructorInvoker.cs @@ -46,23 +46,63 @@ private ConstructorInvoker(RuntimeConstructorInfo constructor, RuntimeType[] arg Initialize(argumentTypes, out _strategy, out _invokerArgFlags, out _needsByRefStrategy); } - public object Invoke() => Invoke(null, null, null, null); - public object Invoke(object? arg1) => Invoke(arg1, null, null, null); - public object Invoke(object? arg1, object? arg2) => Invoke(arg1, arg2, null, null); - public object Invoke(object? arg1, object? arg2, object? arg3) => Invoke(arg1, arg2, arg3, null); - public object Invoke(object? arg1, object? arg2, object? arg3, object? arg4) + public object Invoke() { - if ((_invocationFlags & (InvocationFlags.NoInvoke | InvocationFlags.ContainsStackPointers)) != 0) + if (_argCount != 0) { - _method.ThrowNoInvokeException(); + MethodBaseInvoker.ThrowTargetParameterCountException(); + } + + return InvokeImpl(null, null, null, null); + } + + public object Invoke(object? arg1) + { + if (_argCount != 1) + { + MethodBaseInvoker.ThrowTargetParameterCountException(); + } + + return InvokeImpl(arg1, null, null, null); + } + + public object Invoke(object? arg1, object? arg2) + { + if (_argCount != 2) + { + MethodBaseInvoker.ThrowTargetParameterCountException(); + } + + return InvokeImpl(arg1, arg2, null, null); + } + + public object Invoke(object? arg1, object? arg2, object? arg3) + { + if (_argCount !=3) + { + MethodBaseInvoker.ThrowTargetParameterCountException(); } - // Allow additional non-used arguments to simplify caller's logic. - if (_argCount > MaxStackAllocArgCount) + return InvokeImpl(arg1, arg2, arg3, null); + } + + public object Invoke(object? arg1, object? arg2, object? arg3, object? arg4) + { + if (_argCount != 4) { MethodBaseInvoker.ThrowTargetParameterCountException(); } + return InvokeImpl(arg1, arg2, arg3, arg4); + } + + private object InvokeImpl(object? arg1, object? arg2, object? arg3, object? arg4) + { + if ((_invocationFlags & (InvocationFlags.NoInvoke | InvocationFlags.ContainsStackPointers)) != 0) + { + _method.ThrowNoInvokeException(); + } + switch (_argCount) { case 4: @@ -99,21 +139,27 @@ public object Invoke(object? arg1, object? arg2, object? arg3, object? arg4) public object Invoke(Span arguments) { + int argLen = arguments.Length; + if (argLen != _argCount) + { + MethodBaseInvoker.ThrowTargetParameterCountException(); + } + if (!_needsByRefStrategy) { // Switch to fast path if possible. switch (_argCount) { case 0: - return Invoke(null, null, null, null); + return InvokeImpl(null, null, null, null); case 1: - return Invoke(arguments[0], null, null, null); + return InvokeImpl(arguments[0], null, null, null); case 2: - return Invoke(arguments[0], arguments[1], null, null); + return InvokeImpl(arguments[0], arguments[1], null, null); case 3: - return Invoke(arguments[0], arguments[1], arguments[2], null); + return InvokeImpl(arguments[0], arguments[1], arguments[2], null); case 4: - return Invoke(arguments[0], arguments[1], arguments[2], arguments[3]); + return InvokeImpl(arguments[0], arguments[1], arguments[2], arguments[3]); default: break; } @@ -124,12 +170,7 @@ public object Invoke(Span arguments) _method.ThrowNoInvokeException(); } - if (arguments.Length != _argCount) - { - throw new TargetParameterCountException(SR.Arg_ParmCnt); - } - - if (arguments.Length > MaxStackAllocArgCount) + if (argLen > MaxStackAllocArgCount) { return InvokeWithManyArgs(arguments); } diff --git a/src/libraries/System.Private.CoreLib/src/System/Reflection/MethodInvoker.cs b/src/libraries/System.Private.CoreLib/src/System/Reflection/MethodInvoker.cs index 76483f646c53a1..b5496c37c0cc84 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Reflection/MethodInvoker.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Reflection/MethodInvoker.cs @@ -60,23 +60,63 @@ private MethodInvoker(MethodBase method, RuntimeType[] argumentTypes) Initialize(argumentTypes, out _strategy, out _invokerArgFlags, out _needsByRefStrategy); } - public object? Invoke(object? obj) => Invoke(obj, null, null, null, null); - public object? Invoke(object? obj, object? arg1) => Invoke(obj, arg1, null, null, null); - public object? Invoke(object? obj, object? arg1, object? arg2) => Invoke(obj, arg1, arg2, null, null); - public object? Invoke(object? obj, object? arg1, object? arg2, object? arg3) => Invoke(obj, arg1, arg2, arg3, null); - public object? Invoke(object? obj, object? arg1, object? arg2, object? arg3, object? arg4) + public object? Invoke(object? obj) { - if ((_invocationFlags & (InvocationFlags.NoInvoke | InvocationFlags.ContainsStackPointers)) != 0) + if (_argCount != 0) { - ThrowForBadInvocationFlags(); + MethodBaseInvoker.ThrowTargetParameterCountException(); + } + + return InvokeImpl(obj, null, null, null, null); + } + + public object? Invoke(object? obj, object? arg1) + { + if (_argCount != 1) + { + MethodBaseInvoker.ThrowTargetParameterCountException(); + } + + return InvokeImpl(obj, arg1, null, null, null); + } + + public object? Invoke(object? obj, object? arg1, object? arg2) + { + if (_argCount != 2) + { + MethodBaseInvoker.ThrowTargetParameterCountException(); + } + + return InvokeImpl(obj, arg1, arg2, null, null); + } + + public object? Invoke(object? obj, object? arg1, object? arg2, object? arg3) + { + if (_argCount != 3) + { + MethodBaseInvoker.ThrowTargetParameterCountException(); } - // Allow additional non-used arguments to simplify caller's logic. - if (_argCount > MaxStackAllocArgCount) + return InvokeImpl(obj, arg1, arg2, arg3, null); + } + + public object? Invoke(object? obj, object? arg1, object? arg2, object? arg3, object? arg4) + { + if (_argCount != 4) { MethodBaseInvoker.ThrowTargetParameterCountException(); } + return InvokeImpl(obj, arg1, arg2, arg3, arg4); + } + + private object? InvokeImpl(object? obj, object? arg1, object? arg2, object? arg3, object? arg4) + { + if ((_invocationFlags & (InvocationFlags.NoInvoke | InvocationFlags.ContainsStackPointers)) != 0) + { + ThrowForBadInvocationFlags(); + } + if (!_isStatic) { ValidateInvokeTarget(obj, _method); @@ -118,21 +158,27 @@ private MethodInvoker(MethodBase method, RuntimeType[] argumentTypes) public object? Invoke(object? obj, Span arguments) { + int argLen = arguments.Length; + if (argLen != _argCount) + { + MethodBaseInvoker.ThrowTargetParameterCountException(); + } + if (!_needsByRefStrategy) { // Switch to fast path if possible. switch (_argCount) { case 0: - return Invoke(obj, null, null, null, null); + return InvokeImpl(obj, null, null, null, null); case 1: - return Invoke(obj, arguments[0], null, null, null); + return InvokeImpl(obj, arguments[0], null, null, null); case 2: - return Invoke(obj, arguments[0], arguments[1], null, null); + return InvokeImpl(obj, arguments[0], arguments[1], null, null); case 3: - return Invoke(obj, arguments[0], arguments[1], arguments[2], null); + return InvokeImpl(obj, arguments[0], arguments[1], arguments[2], null); case 4: - return Invoke(obj, arguments[0], arguments[1], arguments[2], arguments[3]); + return InvokeImpl(obj, arguments[0], arguments[1], arguments[2], arguments[3]); default: break; } @@ -143,17 +189,12 @@ private MethodInvoker(MethodBase method, RuntimeType[] argumentTypes) ThrowForBadInvocationFlags(); } - if (arguments.Length != _argCount) - { - throw new TargetParameterCountException(SR.Arg_ParmCnt); - } - if (!_isStatic) { ValidateInvokeTarget(obj, _method); } - if (arguments.Length > MaxStackAllocArgCount) + if (argLen > MaxStackAllocArgCount) { return InvokeWithManyArgs(obj, arguments); } diff --git a/src/libraries/System.Private.CoreLib/src/System/Reflection/RuntimeConstructorInfo.cs b/src/libraries/System.Private.CoreLib/src/System/Reflection/RuntimeConstructorInfo.cs index af741f377b55b7..ebd93b6591ef63 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Reflection/RuntimeConstructorInfo.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Reflection/RuntimeConstructorInfo.cs @@ -115,7 +115,7 @@ internal void ThrowNoInvokeException() int argCount = (parameters is null) ? 0 : parameters.Length; if (ArgumentTypes.Length != argCount) { - throw new TargetParameterCountException(SR.Arg_ParmCnt); + MethodBaseInvoker.ThrowTargetParameterCountException(); } if ((InvocationFlags & InvocationFlags.RunClassConstructor) != 0) @@ -147,7 +147,7 @@ public override object Invoke(BindingFlags invokeAttr, Binder? binder, object?[] int argCount = (parameters is null) ? 0 : parameters.Length; if (ArgumentTypes.Length != argCount) { - throw new TargetParameterCountException(SR.Arg_ParmCnt); + MethodBaseInvoker.ThrowTargetParameterCountException(); } switch (argCount) diff --git a/src/libraries/System.Private.CoreLib/src/System/Reflection/RuntimeMethodInfo.cs b/src/libraries/System.Private.CoreLib/src/System/Reflection/RuntimeMethodInfo.cs index d22bf589cf3e0d..e29038997421a9 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Reflection/RuntimeMethodInfo.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Reflection/RuntimeMethodInfo.cs @@ -118,7 +118,7 @@ internal void ThrowNoInvokeException() int argCount = (parameters is null) ? 0 : parameters.Length; if (ArgumentTypes.Length != argCount) { - throw new TargetParameterCountException(SR.Arg_ParmCnt); + MethodBaseInvoker.ThrowTargetParameterCountException(); } switch (argCount) diff --git a/src/libraries/System.Reflection/tests/ConstructorInvokerTests.cs b/src/libraries/System.Reflection/tests/ConstructorInvokerTests.cs index 056bb592687163..86fa1c4012c270 100644 --- a/src/libraries/System.Reflection/tests/ConstructorInvokerTests.cs +++ b/src/libraries/System.Reflection/tests/ConstructorInvokerTests.cs @@ -64,6 +64,80 @@ public void Args_5() Assert.Equal("12345", ((TestClass)invoker.Invoke(new Span(new object[] { "1", "2", "3", "4", "5" })))._args); } + [Fact] + public void Args_0_Extra_Throws() + { + ConstructorInvoker invoker = ConstructorInvoker.Create(typeof(TestClass).GetConstructor(new Type[] { })); + Assert.Throws(() => invoker.Invoke(42)); + } + + [Fact] + public void Args_1_Extra_Throws() + { + ConstructorInvoker invoker = ConstructorInvoker.Create(typeof(TestClass).GetConstructor(new Type[] { typeof(string) })); + Assert.Throws(() => invoker.Invoke("1", 42)); + } + + [Fact] + public void Args_2_Extra_Throws() + { + ConstructorInvoker invoker = ConstructorInvoker.Create(typeof(TestClass).GetConstructor( + new Type[] { typeof(string), typeof(string) })); + + Assert.Throws(() => invoker.Invoke("1", "2", 42)); + } + + [Fact] + public void Args_3_Extra_Throws() + { + ConstructorInvoker invoker = ConstructorInvoker.Create(typeof(TestClass).GetConstructor( + new Type[] { typeof(string), typeof(string), typeof(string) })); + + Assert.Throws(() => invoker.Invoke("1", "2", "3", 42)); + } + + [Fact] + public void Args_Span_Extra_Throws() + { + ConstructorInvoker invoker = ConstructorInvoker.Create(typeof(TestClass).GetConstructor(new Type[] { })); + Assert.Throws(() => invoker.Invoke(new Span(new object[]{"1", "2"}))); + } + + [Fact] + public void Args_1_NotEnoughArgs_Throws() + { + ConstructorInvoker invoker = ConstructorInvoker.Create(typeof(TestClass).GetConstructor(new Type[] { typeof(string) })); + Assert.Throws(invoker.Invoke); + } + + [Fact] + public void Args_2_NotEnoughArgs_Throws() + { + ConstructorInvoker invoker = ConstructorInvoker.Create(typeof(TestClass).GetConstructor( + new Type[] { typeof(string), typeof(string) })); + + Assert.Throws(invoker.Invoke); + Assert.Throws(() => invoker.Invoke("1")); + } + + [Fact] + public void Args_3_NotEnoughArgs_Throws() + { + ConstructorInvoker invoker = ConstructorInvoker.Create(typeof(TestClass).GetConstructor( + new Type[] { typeof(string), typeof(string), typeof(string) })); + + Assert.Throws(invoker.Invoke); + Assert.Throws(() => invoker.Invoke("1")); + Assert.Throws(() => invoker.Invoke("1", "2")); + } + + [Fact] + public void Args_Span_NotEnoughArgs_Throws() + { + ConstructorInvoker invoker = ConstructorInvoker.Create(typeof(TestClass).GetConstructor(new Type[] { typeof(string) })); + Assert.Throws(() => invoker.Invoke(new Span())); + } + [Fact] public void ThrowsNonWrappedException_0() { diff --git a/src/libraries/System.Reflection/tests/MethodInvokerTests.cs b/src/libraries/System.Reflection/tests/MethodInvokerTests.cs index 0eb20799d2cb7c..97c9865a64fdaf 100644 --- a/src/libraries/System.Reflection/tests/MethodInvokerTests.cs +++ b/src/libraries/System.Reflection/tests/MethodInvokerTests.cs @@ -62,6 +62,72 @@ public void Args_5() Assert.Equal("12345", invoker.Invoke(obj: null, new Span(new object[] { "1", "2", "3", "4", "5" }))); } + [Fact] + public void Args_0_Extra_Throws() + { + MethodInvoker invoker = MethodInvoker.Create(typeof(TestClass).GetMethod(nameof(TestClass.Args_0))); + Assert.Throws(() => invoker.Invoke(obj: null, 42)); + } + + [Fact] + public void Args_1_Extra_Throws() + { + MethodInvoker invoker = MethodInvoker.Create(typeof(TestClass).GetMethod(nameof(TestClass.Args_1))); + Assert.Throws(() => invoker.Invoke(obj: null, "1", 42)); + } + + [Fact] + public void Args_2_Extra_Throws() + { + MethodInvoker invoker = MethodInvoker.Create(typeof(TestClass).GetMethod(nameof(TestClass.Args_2))); + Assert.Throws(() => invoker.Invoke(obj: null, "1", "2", 42)); + } + + [Fact] + public void Args_3_Extra_Throws() + { + MethodInvoker invoker = MethodInvoker.Create(typeof(TestClass).GetMethod(nameof(TestClass.Args_3))); + Assert.Throws(() => invoker.Invoke(obj: null, "1", "2", "3", 42)); + } + + [Fact] + public void Args_Span_Extra_Throws() + { + MethodInvoker invoker = MethodInvoker.Create(typeof(TestClass).GetMethod(nameof(TestClass.Args_1))); + Assert.Throws(() => invoker.Invoke(obj: null, new Span(new object[] { "1", "2" }))); + } + + [Fact] + public void Args_1_NotEnoughArgs_Throws() + { + MethodInvoker invoker = MethodInvoker.Create(typeof(TestClass).GetMethod(nameof(TestClass.Args_1))); + Assert.Throws(() => invoker.Invoke(obj: null)); + } + + [Fact] + public void Args_2_NotEnoughArgs_Throws() + { + MethodInvoker invoker = MethodInvoker.Create(typeof(TestClass).GetMethod(nameof(TestClass.Args_2))); + Assert.Throws(() => invoker.Invoke(obj: null)); + Assert.Throws(() => invoker.Invoke("1")); + } + + [Fact] + public void Args_3_NotEnoughArgs_Throws() + { + MethodInvoker invoker = MethodInvoker.Create(typeof(TestClass).GetMethod(nameof(TestClass.Args_3))); + Assert.Throws(() => invoker.Invoke(obj: null)); + Assert.Throws(() => invoker.Invoke(obj: null, "1")); + Assert.Throws(() => invoker.Invoke(obj: null, "1", "2")); + } + + [Fact] + public void Args_Span_NotEnoughArgs_Throws() + { + MethodInvoker invoker = MethodInvoker.Create(typeof(TestClass).GetMethod(nameof(TestClass.Args_1))); + Assert.Throws(() => invoker.Invoke(obj: null, new Span())); + } + [Fact] public void Args_ByRef() {