diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs index a06b565d1f3ca7..51b32c18053ac1 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs @@ -603,69 +603,108 @@ private ConstructorCallSite CreateConstructorCallSite( return new ConstructorCallSite(lifetime, serviceIdentifier.ServiceType, constructor, parameterCallSites, serviceIdentifier.ServiceKey); } - Array.Sort(constructors, - (a, b) => b.GetParameters().Length.CompareTo(a.GetParameters().Length)); + // With more than one constructor, select the "best" one: the constructor with the + // most parameters whose arguments can all be resolved, either from the container or + // by falling back to a default parameter value. Constructors are ordered from the + // most parameters to the fewest (preserving declaration order among equal counts), + // and the first fully resolvable one becomes the best match. Every subsequent + // resolvable constructor must have parameters that are a subset of the best one, + // i.e. each of its parameters was already used by the best constructor; otherwise + // the two are ambiguous and an exception is thrown. + int constructorCount = constructors.Length; + var sortedParameters = new ParameterInfo[constructorCount][]; + for (int i = 0; i < constructorCount; i++) + { + ConstructorInfo constructor = constructors[i]; + ParameterInfo[] parameters = constructor.GetParameters(); + + int sortedIndex = i; + while (sortedIndex > 0 && sortedParameters[sortedIndex - 1].Length < parameters.Length) + { + constructors[sortedIndex] = constructors[sortedIndex - 1]; + sortedParameters[sortedIndex] = sortedParameters[sortedIndex - 1]; + sortedIndex--; + } + + constructors[sortedIndex] = constructor; + sortedParameters[sortedIndex] = parameters; + } ConstructorInfo? bestConstructor = null; - HashSet? bestConstructorParameterTypes = null; - for (int i = 0; i < constructors.Length; i++) + List? bestResolvedParameters = null; + List? resolvedParameters = null; + + for (int i = 0; i < constructorCount; i++) { - ParameterInfo[] parameters = constructors[i].GetParameters(); + ConstructorInfo constructor = constructors[i]; + ParameterInfo[] parameters = sortedParameters[i]; + + if (bestConstructor is null) + { + var currentResolvedParameters = new List(parameters.Length); + ServiceCallSite[]? currentParameterCallSites = CreateArgumentCallSites( + serviceIdentifier, + implementationType, + callSiteChain, + parameters, + throwIfCallSiteNotFound: false, + currentResolvedParameters); + + if (currentParameterCallSites is null) + { + continue; + } - ServiceCallSite[]? currentParameterCallSites = CreateArgumentCallSites( + bestConstructor = constructor; + parameterCallSites = currentParameterCallSites; + bestResolvedParameters = currentResolvedParameters; + continue; + } + + Debug.Assert(bestResolvedParameters is not null); + + if (resolvedParameters is null) + { + resolvedParameters = new List(parameters.Length); + } + else + { + resolvedParameters.Clear(); + } + + if (CreateArgumentCallSites( serviceIdentifier, implementationType, callSiteChain, parameters, - throwIfCallSiteNotFound: false); + throwIfCallSiteNotFound: false, + resolvedParameters) is null) + { + continue; + } - if (currentParameterCallSites != null) + // All parameters resolvable; ambiguous unless it is a subset of best. + foreach (ServiceIdentifier id in resolvedParameters) { - if (bestConstructor == null) + if (!bestResolvedParameters.Contains(id)) { - bestConstructor = constructors[i]; - parameterCallSites = currentParameterCallSites; - } - else - { - // Since we're visiting constructors in decreasing order of number of parameters, - // we'll only see ambiguities or supersets once we've seen a 'bestConstructor'. - - if (bestConstructorParameterTypes == null) - { - bestConstructorParameterTypes = new HashSet(); - foreach (ParameterInfo p in bestConstructor.GetParameters()) - { - bestConstructorParameterTypes.Add(p.ParameterType); - } - } - - foreach (ParameterInfo p in parameters) - { - if (!bestConstructorParameterTypes.Contains(p.ParameterType)) - { - // Ambiguous match exception - throw new InvalidOperationException(string.Join( - Environment.NewLine, - SR.Format(SR.AmbiguousConstructorException, implementationType), - bestConstructor, - constructors[i])); - } - } + throw new InvalidOperationException(string.Join( + Environment.NewLine, + SR.Format(SR.AmbiguousConstructorException, implementationType), + bestConstructor, + constructor)); } } } - if (bestConstructor == null) + if (bestConstructor is null) { throw new InvalidOperationException( SR.Format(SR.UnableToActivateTypeException, implementationType)); } - else - { - Debug.Assert(parameterCallSites != null); - return new ConstructorCallSite(lifetime, serviceIdentifier.ServiceType, bestConstructor, parameterCallSites, serviceIdentifier.ServiceKey); - } + + Debug.Assert(parameterCallSites != null); + return new ConstructorCallSite(lifetime, serviceIdentifier.ServiceType, bestConstructor, parameterCallSites, serviceIdentifier.ServiceKey); } finally { @@ -679,80 +718,104 @@ private ConstructorCallSite CreateConstructorCallSite( Type implementationType, CallSiteChain callSiteChain, ParameterInfo[] parameters, - bool throwIfCallSiteNotFound) + bool throwIfCallSiteNotFound, + List? resolvedParameters = null) { var parameterCallSites = new ServiceCallSite[parameters.Length]; for (int index = 0; index < parameters.Length; index++) { - ServiceCallSite? callSite = null; - bool isKeyedParameter = false; - Type parameterType = parameters[index].ParameterType; - foreach (var attribute in parameters[index].GetCustomAttributes(true)) + if (!TryResolveCallSite( + serviceIdentifier, + implementationType, + callSiteChain, + parameters[index], + throwIfCallSiteNotFound, + out ServiceIdentifier parameterServiceIdentifier, + out ServiceCallSite? callSite)) { - if (serviceIdentifier.ServiceKey != null && attribute is ServiceKeyAttribute) - { - // Even though the parameter may be strongly typed, support 'object' if AnyKey is used. + return null; + } - if (serviceIdentifier.ServiceKey == KeyedService.AnyKey) - { - parameterType = typeof(object); - } - else if (parameterType != serviceIdentifier.ServiceKey.GetType() - && parameterType != typeof(object)) - { - throw new InvalidOperationException(SR.InvalidServiceKeyType); - } + Debug.Assert(callSite is not null); + resolvedParameters?.Add(parameterServiceIdentifier); + parameterCallSites[index] = callSite; + } - callSite = new ConstantCallSite(parameterType, serviceIdentifier.ServiceKey); - break; - } + return parameterCallSites; + } - if (attribute is FromKeyedServicesAttribute fromKeyedServicesAttribute) - { - object? serviceKey = fromKeyedServicesAttribute.LookupMode switch - { - ServiceKeyLookupMode.InheritKey => serviceIdentifier.ServiceKey, - ServiceKeyLookupMode.ExplicitKey => fromKeyedServicesAttribute.Key, - ServiceKeyLookupMode.NullKey => null, - _ => null - }; + private bool TryResolveCallSite( + ServiceIdentifier serviceIdentifier, + Type implementationType, + CallSiteChain callSiteChain, + ParameterInfo parameter, + bool throwIfCallSiteNotFound, + out ServiceIdentifier parameterServiceIdentifier, + out ServiceCallSite? callSite) + { + Type parameterType = parameter.ParameterType; + parameterServiceIdentifier = ServiceIdentifier.FromServiceType(parameterType); - if (serviceKey is not null) - { - callSite = GetCallSite(new ServiceIdentifier(serviceKey, parameterType), callSiteChain); - isKeyedParameter = true; - break; - } + foreach (object attribute in parameter.GetCustomAttributes(true)) + { + if (serviceIdentifier.ServiceKey != null && attribute is ServiceKeyAttribute) + { + // Even though the parameter may be strongly typed, support 'object' if AnyKey is used. + if (serviceIdentifier.ServiceKey == KeyedService.AnyKey) + { + parameterType = typeof(object); + } + else if (parameterType != serviceIdentifier.ServiceKey.GetType() && + parameterType != typeof(object)) + { + throw new InvalidOperationException(SR.InvalidServiceKeyType); } - } - if (!isKeyedParameter) - { - callSite ??= GetCallSite(ServiceIdentifier.FromServiceType(parameterType), callSiteChain); + parameterServiceIdentifier = new ServiceIdentifier(serviceIdentifier.ServiceKey, parameterType); + callSite = new ConstantCallSite(parameterType, serviceIdentifier.ServiceKey); + return true; } - if (callSite == null && ParameterDefaultValue.TryGetDefaultValue(parameters[index], out object? defaultValue)) + if (attribute is FromKeyedServicesAttribute fromKeyedServicesAttribute) { - callSite = new ConstantCallSite(parameterType, defaultValue); - } + object? serviceKey = fromKeyedServicesAttribute.LookupMode switch + { + ServiceKeyLookupMode.InheritKey => serviceIdentifier.ServiceKey, + ServiceKeyLookupMode.ExplicitKey => fromKeyedServicesAttribute.Key, + ServiceKeyLookupMode.NullKey => null, + _ => null + }; - if (callSite == null) - { - if (throwIfCallSiteNotFound) + if (serviceKey is not null) { - throw new InvalidOperationException(SR.Format(SR.CannotResolveService, - parameterType, - implementationType)); + parameterServiceIdentifier = new ServiceIdentifier(serviceKey, parameterType); } - - return null; } + } - parameterCallSites[index] = callSite; + ServiceCallSite? parameterCallSite = GetCallSite(parameterServiceIdentifier, callSiteChain); + if (parameterCallSite is not null) + { + callSite = parameterCallSite; + return true; } - return parameterCallSites; + if (ParameterDefaultValue.TryGetDefaultValue(parameter, out object? defaultValue)) + { + callSite = new ConstantCallSite(parameterType, defaultValue); + return true; + } + + if (throwIfCallSiteNotFound) + { + throw new InvalidOperationException(SR.Format(SR.CannotResolveService, + parameterType, + implementationType)); + } + + callSite = null; + return false; } /// diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceLookup/CallSiteFactoryTest.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceLookup/CallSiteFactoryTest.cs index c8ed86dd37cd3d..30045000c0c345 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceLookup/CallSiteFactoryTest.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceLookup/CallSiteFactoryTest.cs @@ -137,6 +137,139 @@ public void CreateCallSite_UsesNullaryConstructorIfServicesCannotBeInjectedIntoO Assert.Empty(ctorCallSite.ParameterCallSites); } + [Fact] + public void ServiceProvider_UsesDeclarationOrderForSameArityConstructorsWithSameParameterTypes() + { + var collection = new ServiceCollection(); + collection.AddTransient(); + collection.AddTransient(); + collection.AddTransient(); + + using ServiceProvider provider = collection.BuildServiceProvider(); + var service = provider.GetRequiredService(); + + Assert.Equal(1, service.SelectedConstructor); + } + + [Fact] + public void CreateCallSite_ThrowsIfMultipleSameArityDisjointConstructorsCanBeResolved() + { + // Arrange + var type = typeof(TypeWithSameArityDisjointConstructors); + var expectedMessage = + string.Join( + Environment.NewLine, + $"Unable to activate type '{type}'. The following constructors are ambiguous:", + GetConstructor(type, new[] { typeof(IFakeService), typeof(IFakeScopedService) }), + GetConstructor(type, new[] { typeof(IFactoryService), typeof(IFakeMultipleService) })); + + var callSiteFactory = GetCallSiteFactory( + new ServiceDescriptor(type, type, ServiceLifetime.Transient), + new ServiceDescriptor(typeof(IFakeService), typeof(FakeService), ServiceLifetime.Transient), + new ServiceDescriptor(typeof(IFakeScopedService), typeof(FakeService), ServiceLifetime.Transient), + new ServiceDescriptor(typeof(IFactoryService), typeof(TransientFactoryService), ServiceLifetime.Transient), + new ServiceDescriptor(typeof(IFakeMultipleService), typeof(FakeService), ServiceLifetime.Transient)); + + // Act and Assert + var ex = Assert.Throws( + () => callSiteFactory(type)); + Assert.Equal(expectedMessage, ex.Message); + } + + [Theory] + [InlineData(typeof(TypeWithCrossArityDisjointConstructorsLongFirst))] + [InlineData(typeof(TypeWithCrossArityDisjointConstructorsShortFirst))] + public void CreateCallSite_ThrowsIfCrossArityDisjointConstructorsCanBeResolved(Type type) + { + // Arrange + var expectedMessage = + string.Join( + Environment.NewLine, + $"Unable to activate type '{type}'. The following constructors are ambiguous:", + GetConstructor(type, new[] { typeof(IFakeService), typeof(IFactoryService), typeof(IFakeScopedService) }), + GetConstructor(type, new[] { typeof(IFakeOuterService) })); + + var callSiteFactory = GetCallSiteFactory( + new ServiceDescriptor(type, type, ServiceLifetime.Transient), + new ServiceDescriptor(typeof(IFakeService), typeof(FakeService), ServiceLifetime.Transient), + new ServiceDescriptor(typeof(IFactoryService), typeof(TransientFactoryService), ServiceLifetime.Transient)); + + // Act and Assert + var ex = Assert.Throws( + () => callSiteFactory(type)); + Assert.Equal(expectedMessage, ex.Message); + } + + [Fact] + public void CreateCallSite_ThrowsIfSameTypeParametersUseDifferentServiceKeys() + { + // Arrange + var type = typeof(TypeWithSameTypeDifferentServiceKeyConstructors); + var callSiteFactory = GetCallSiteFactory( + new ServiceDescriptor(type, type, ServiceLifetime.Transient), + ServiceDescriptor.KeyedTransient("a"), + ServiceDescriptor.KeyedTransient("b"), + new ServiceDescriptor(typeof(IFakeScopedService), typeof(FakeService), ServiceLifetime.Transient)); + + // Act and Assert + var ex = Assert.Throws(() => callSiteFactory(type)); + Assert.StartsWith($"Unable to activate type '{type}'. The following constructors are ambiguous:", ex.Message); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public void CreateCallSite_IgnoresSmallerConstructorWhenBestUsesDefaultAndSmallerNeedsContainer(bool registerScopedService) + { + // Arrange + var type = typeof(TypeWithDefaultInBestAndNonDefaultInSmallerConstructors); + ServiceDescriptor[] descriptors = registerScopedService + ? new[] + { + new ServiceDescriptor(type, type, ServiceLifetime.Transient), + new ServiceDescriptor(typeof(IFakeService), typeof(FakeService), ServiceLifetime.Transient), + new ServiceDescriptor(typeof(IFakeScopedService), typeof(FakeService), ServiceLifetime.Transient), + } + : new[] + { + new ServiceDescriptor(type, type, ServiceLifetime.Transient), + new ServiceDescriptor(typeof(IFakeService), typeof(FakeService), ServiceLifetime.Transient), + }; + var callSiteFactory = GetCallSiteFactory(descriptors); + + // Act + var callSite = callSiteFactory(type); + + // Assert + var constructorCallSite = Assert.IsType(callSite); + Assert.Equal(new[] { typeof(IFakeService), typeof(IFakeScopedService) }, GetParameters(constructorCallSite)); + } + + [Theory] + [InlineData(typeof(TypeWithShortThenLongResolvableConstructors), true)] + [InlineData(typeof(TypeWithShortThenLongUnresolvableConstructors), false)] + public void CreateCallSite_UsesLongestResolvableConstructorWhenDeclaredAfterShorter(Type type, bool expectLongConstructor) + { + // Arrange + var descriptor = new ServiceDescriptor(type, type, ServiceLifetime.Transient); + var callSiteFactory = GetCallSiteFactory( + descriptor, + new ServiceDescriptor(typeof(IFakeService), typeof(FakeService), ServiceLifetime.Transient), + new ServiceDescriptor(typeof(IFactoryService), typeof(TransientFactoryService), ServiceLifetime.Transient)); + + Type[] expectedParameters = expectLongConstructor + ? new[] { typeof(IFakeService), typeof(IFactoryService) } + : new[] { typeof(IFakeService) }; + + // Act + var callSite = callSiteFactory(type); + + // Assert + Assert.Equal(CallSiteResultCacheLocation.Dispose, callSite.Cache.Location); + var constructorCallSite = Assert.IsType(callSite); + Assert.Equal(expectedParameters, GetParameters(constructorCallSite)); + } + [Fact] public void CreateCallSite_Throws_IfClosedTypeDoesNotSatisfyStructGenericConstraint() { @@ -1021,6 +1154,97 @@ private class Class2 { public Class2(Class3 c3) { } } private class Class3 { } private class Class4 { public Class4(Class3 c3) { } } private class Class5 { public Class5(Class2 c2) { } } + private class TypeWithSameArityConstructorsWithSameParameterTypes + { + public TypeWithSameArityConstructorsWithSameParameterTypes(IFakeService fakeService, IFakeMultipleService fakeMultipleService) + { + SelectedConstructor = 1; + } + + public TypeWithSameArityConstructorsWithSameParameterTypes(IFakeMultipleService fakeMultipleService, IFakeService fakeService) + { + SelectedConstructor = 2; + } + + public int SelectedConstructor { get; } + } + + private class TypeWithSameArityDisjointConstructors + { + public TypeWithSameArityDisjointConstructors(IFakeService fakeService, IFakeScopedService fakeScopedService) + { + } + + public TypeWithSameArityDisjointConstructors(IFactoryService factoryService, IFakeMultipleService fakeMultipleService) + { + } + } + + private class TypeWithShortThenLongResolvableConstructors + { + public TypeWithShortThenLongResolvableConstructors(IFakeService fakeService) + { + } + + public TypeWithShortThenLongResolvableConstructors(IFakeService fakeService, IFactoryService factoryService) + { + } + } + + private class TypeWithShortThenLongUnresolvableConstructors + { + public TypeWithShortThenLongUnresolvableConstructors(IFakeService fakeService) + { + } + + public TypeWithShortThenLongUnresolvableConstructors(IFakeService fakeService, IFakeOuterService fakeOuterService) + { + } + } + + private class TypeWithCrossArityDisjointConstructorsLongFirst + { + public TypeWithCrossArityDisjointConstructorsLongFirst(IFakeService fakeService, IFactoryService factoryService, IFakeScopedService fakeScopedService = null) + { + } + + public TypeWithCrossArityDisjointConstructorsLongFirst(IFakeOuterService fakeOuterService = null) + { + } + } + + private class TypeWithCrossArityDisjointConstructorsShortFirst + { + public TypeWithCrossArityDisjointConstructorsShortFirst(IFakeOuterService fakeOuterService = null) + { + } + + public TypeWithCrossArityDisjointConstructorsShortFirst(IFakeService fakeService, IFactoryService factoryService, IFakeScopedService fakeScopedService = null) + { + } + } + + private class TypeWithSameTypeDifferentServiceKeyConstructors + { + public TypeWithSameTypeDifferentServiceKeyConstructors([FromKeyedServices("a")] IFakeService service, IFakeScopedService scoped) + { + } + + public TypeWithSameTypeDifferentServiceKeyConstructors([FromKeyedServices("b")] IFakeService service, IFakeScopedService scoped, int dummy = 0) + { + } + } + + private class TypeWithDefaultInBestAndNonDefaultInSmallerConstructors + { + public TypeWithDefaultInBestAndNonDefaultInSmallerConstructors(IFakeService fakeService, IFakeScopedService fakeScopedService = null) + { + } + + public TypeWithDefaultInBestAndNonDefaultInSmallerConstructors(IFakeScopedService fakeScopedService) + { + } + } private record struct Struct1(int Value) { }