diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index 45fc359bf6b5f0..b4bc49672185bf 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -25,6 +25,7 @@ public static class StepNames public const string GenerateNativeToManagedStub = nameof(GenerateNativeToManagedStub); public const string GenerateManagedToNativeInterfaceImplementation = nameof(GenerateManagedToNativeInterfaceImplementation); public const string GenerateNativeToManagedVTableMethods = nameof(GenerateNativeToManagedVTableMethods); + public const string GenerateNativeToManagedVTableStruct = nameof(GenerateNativeToManagedVTableStruct); public const string GenerateNativeToManagedVTable = nameof(GenerateNativeToManagedVTable); public const string GenerateInterfaceInformation = nameof(GenerateInterfaceInformation); public const string GenerateIUnknownDerivedAttribute = nameof(GenerateIUnknownDerivedAttribute); @@ -122,6 +123,12 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .SelectNormalized(); // Generate the code for the unmanaged-to-managed stubs. + var nativeToManagedVtableStructs = interfaceAndMethodsContexts + .Select(GenerateInterfaceImplementationVtable) + .WithTrackingName(StepNames.GenerateNativeToManagedVTableStruct) + .WithComparer(SyntaxEquivalentComparer.Instance) + .SelectNormalized(); + var nativeToManagedVtableMethods = interfaceAndMethodsContexts .Select(GenerateImplementationVTableMethods) .WithTrackingName(StepNames.GenerateNativeToManagedVTableMethods) @@ -175,16 +182,17 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .Zip(nativeInterfaceInformation) .Zip(managedToNativeInterfaceImplementations) .Zip(nativeToManagedVtableMethods) + .Zip(nativeToManagedVtableStructs) .Zip(nativeToManagedVtables) .Zip(iUnknownDerivedAttributeApplication) .Zip(shadowingMethodDeclarations) .Select(static (data, ct) => { - var ((((((interfaceContext, interfaceInfo), managedToNativeStubs), nativeToManagedStubs), nativeToManagedVtable), iUnknownDerivedAttribute), shadowingMethod) = data; + var (((((((interfaceContext, interfaceInfo), managedToNativeStubs), nativeToManagedStubs), nativeToManagedStructs), nativeToManagedVtable), iUnknownDerivedAttribute), shadowingMethod) = data; using StringWriter source = new(); source.WriteLine("// "); - source.WriteLine("#pragma warning disable CS0612, CS0618"); // Suppress warnings about [Obsolete] member usage in generated code. + source.WriteLine("#pragma warning disable CS0612, CS0618, CS0649"); // Suppress warnings about [Obsolete] and "lack of assignment" in generated code. // If the user has specified 'ManagedObjectWrapper', it means that the COM interface will never be used to marshal a native // object as an RCW (eg. the IDIC vtable will also not be generated, nor any additional supporting code). To reduce binary @@ -208,6 +216,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context) nativeToManagedStubs.WriteTo(source); source.WriteLine(); source.WriteLine(); + nativeToManagedStructs.WriteTo(source); + source.WriteLine(); + source.WriteLine(); nativeToManagedVtable.WriteTo(source); source.WriteLine(); source.WriteLine(); @@ -540,10 +551,106 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(Co .Select(context => context.Stub.Node))); } - private const string CreateManagedVirtualFunctionTableMethodName = "CreateManagedVirtualFunctionTable"; + private static readonly StructDeclarationSyntax InterfaceImplementationVtableTemplate = StructDeclaration("InterfaceImplementationVtable") + .WithModifiers(TokenList(Token(SyntaxKind.FileKeyword), Token(SyntaxKind.UnsafeKeyword))); - private static readonly MethodDeclarationSyntax CreateManagedVirtualFunctionTableMethodTemplate = MethodDeclaration(TypeSyntaxes.VoidStarStar, CreateManagedVirtualFunctionTableMethodName) - .AddModifiers(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.StaticKeyword)); + private static StructDeclarationSyntax GenerateInterfaceImplementationVtable(ComInterfaceAndMethodsContext interfaceMethods, CancellationToken _) + { + StructDeclarationSyntax vtableDeclaration = + InterfaceImplementationVtableTemplate + .AddMembers( + FieldDeclaration( + VariableDeclaration( + FunctionPointerType( + FunctionPointerCallingConvention( + Token(SyntaxKind.UnmanagedKeyword), + FunctionPointerUnmanagedCallingConventionList( + SingletonSeparatedList( + FunctionPointerUnmanagedCallingConvention(Identifier("MemberFunction"))))), + FunctionPointerParameterList( + SeparatedList([ + FunctionPointerParameter(TypeSyntaxes.VoidStar), + FunctionPointerParameter(PointerType(TypeSyntaxes.System_Guid)), + FunctionPointerParameter(TypeSyntaxes.VoidStarStar), + FunctionPointerParameter(ParseTypeName("int"))])))) + .AddVariables(VariableDeclarator("QueryInterface_0"))) + .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword))), + FieldDeclaration( + VariableDeclaration( + FunctionPointerType( + FunctionPointerCallingConvention( + Token(SyntaxKind.UnmanagedKeyword), + FunctionPointerUnmanagedCallingConventionList( + SingletonSeparatedList( + FunctionPointerUnmanagedCallingConvention(Identifier("MemberFunction"))))), + FunctionPointerParameterList( + SeparatedList([ + FunctionPointerParameter(TypeSyntaxes.VoidStar), + FunctionPointerParameter(ParseTypeName("uint"))])))) + .AddVariables(VariableDeclarator("AddRef_1"))) + .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword))), + FieldDeclaration( + VariableDeclaration( + FunctionPointerType( + FunctionPointerCallingConvention( + Token(SyntaxKind.UnmanagedKeyword), + FunctionPointerUnmanagedCallingConventionList( + SingletonSeparatedList( + FunctionPointerUnmanagedCallingConvention(Identifier("MemberFunction"))))), + FunctionPointerParameterList( + SeparatedList([ + FunctionPointerParameter(TypeSyntaxes.VoidStar), + FunctionPointerParameter(ParseTypeName("uint"))])))) + .AddVariables(VariableDeclarator("Release_2"))) + .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword)))) + .AddAttributeLists( + AttributeList( + SingletonSeparatedList( + Attribute( + NameSyntaxes.System_Runtime_InteropServices_StructLayoutAttribute, + AttributeArgumentList( + SingletonSeparatedList( + AttributeArgument( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + TypeSyntaxes.System_Runtime_InteropServices_LayoutKind, + IdentifierName("Sequential"))))))))); + + if (interfaceMethods.Interface.Base is not null) + { + foreach (ComMethodContext inheritedMethod in interfaceMethods.InheritedMethods) + { + FunctionPointerTypeSyntax functionPointerType = VirtualMethodPointerStubGenerator.GenerateUnmanagedFunctionPointerTypeForMethod( + inheritedMethod.GenerationContext, + ComInterfaceGeneratorHelpers.GetGeneratorResolver); + + vtableDeclaration = vtableDeclaration + .AddMembers( + FieldDeclaration( + VariableDeclaration(functionPointerType) + .AddVariables(VariableDeclarator($"{inheritedMethod.MethodInfo.MethodName}_{inheritedMethod.GenerationContext.VtableIndexData.Index}"))) + .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword)))); + } + } + + foreach (ComMethodContext declaredMethod in + interfaceMethods.DeclaredMethods + .Where(context => context.UnmanagedToManagedStub.Diagnostics.All(diag => diag.Descriptor.DefaultSeverity != DiagnosticSeverity.Error))) + { + FunctionPointerTypeSyntax functionPointerType = VirtualMethodPointerStubGenerator.GenerateUnmanagedFunctionPointerTypeForMethod( + declaredMethod.GenerationContext, + ComInterfaceGeneratorHelpers.GetGeneratorResolver); + + vtableDeclaration = vtableDeclaration + .AddMembers( + FieldDeclaration( + VariableDeclaration(functionPointerType) + .AddVariables(VariableDeclarator($"{declaredMethod.MethodInfo.MethodName}_{declaredMethod.GenerationContext.VtableIndexData.Index}"))) + .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword)))); + } + + return vtableDeclaration; + } private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterfaceAndMethodsContext interfaceMethods, CancellationToken _) { @@ -552,25 +659,6 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf return ImplementationInterfaceTemplate; } - const string vtableLocalName = "vtable"; - var interfaceType = interfaceMethods.Interface.Info.Type; - - // void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(, sizeof(void*) * ); - var vtableDeclarationStatement = - Declare( - TypeSyntaxes.VoidStarStar, - vtableLocalName, - CastExpression(TypeSyntaxes.VoidStarStar, - MethodInvocation( - TypeSyntaxes.System_Runtime_CompilerServices_RuntimeHelpers, - IdentifierName("AllocateTypeAssociatedMemory"), - Argument(TypeOfExpression(interfaceType.Syntax)), - Argument( - BinaryExpression( - SyntaxKind.MultiplyExpression, - SizeOfExpression(TypeSyntaxes.VoidStar), - IntLiteral(3 + interfaceMethods.Methods.Length)))))); - BlockSyntax fillBaseInterfaceSlots; @@ -579,38 +667,76 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf // If we don't have a base interface, we need to manually fill in the base iUnknown slots. fillBaseInterfaceSlots = Block() .AddStatements( - // nint v0, v1, v2; - LocalDeclarationStatement(VariableDeclaration(ParseTypeName("nint")) - .AddVariables( - VariableDeclarator("v0"), - VariableDeclarator("v1"), - VariableDeclarator("v2") - )), - // ComWrappers.GetIUnknownImpl(out v0, out v1, out v2); + // ComWrappers.GetIUnknownImpl( + // out *(nint*)&((InterfaceImplementationVtable*)Unsafe.AsPointer(ref Vtable))->QueryInterface_0, + // out *(nint*)&((InterfaceImplementationVtable*)Unsafe.AsPointer(ref Vtable))->AddRef_1, + // out *(nint*)&((InterfaceImplementationVtable*)Unsafe.AsPointer(ref Vtable))->Release_2); MethodInvocationStatement( TypeSyntaxes.System_Runtime_InteropServices_ComWrappers, IdentifierName("GetIUnknownImpl"), - OutArgument(IdentifierName("v0")), - OutArgument(IdentifierName("v1")), - OutArgument(IdentifierName("v2"))), - // m_vtable[0] = (void*)v0; - AssignmentStatement( - IndexExpression( - IdentifierName(vtableLocalName), - Argument(IntLiteral(0))), - CastExpression(TypeSyntaxes.VoidStar, IdentifierName("v0"))), - // m_vtable[1] = (void*)v1; - AssignmentStatement( - IndexExpression( - IdentifierName(vtableLocalName), - Argument(IntLiteral(1))), - CastExpression(TypeSyntaxes.VoidStar, IdentifierName("v1"))), - // m_vtable[2] = (void*)v2; - AssignmentStatement( - IndexExpression( - IdentifierName(vtableLocalName), - Argument(IntLiteral(2))), - CastExpression(TypeSyntaxes.VoidStar, IdentifierName("v2")))); + OutArgument( + PrefixUnaryExpression( + SyntaxKind.PointerIndirectionExpression, + CastExpression( + PointerType(ParseTypeName("nint")), + PrefixUnaryExpression( + SyntaxKind.AddressOfExpression, + MemberAccessExpression( + SyntaxKind.PointerMemberAccessExpression, + ParenthesizedExpression( + CastExpression( + PointerType(ParseTypeName("InterfaceImplementationVtable")), + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + TypeSyntaxes.System_Runtime_CompilerServices_Unsafe, + IdentifierName("AsPointer"))) + .AddArgumentListArguments( + Argument( + RefExpression(IdentifierName("Vtable")))))), + IdentifierName("QueryInterface_0")))))), + OutArgument( + PrefixUnaryExpression( + SyntaxKind.PointerIndirectionExpression, + CastExpression( + PointerType(ParseTypeName("nint")), + PrefixUnaryExpression( + SyntaxKind.AddressOfExpression, + MemberAccessExpression( + SyntaxKind.PointerMemberAccessExpression, + ParenthesizedExpression( + CastExpression( + PointerType(ParseTypeName("InterfaceImplementationVtable")), + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + TypeSyntaxes.System_Runtime_CompilerServices_Unsafe, + IdentifierName("AsPointer"))) + .AddArgumentListArguments( + Argument( + RefExpression(IdentifierName("Vtable")))))), + IdentifierName("AddRef_1")))))), + OutArgument( + PrefixUnaryExpression( + SyntaxKind.PointerIndirectionExpression, + CastExpression( + PointerType(ParseTypeName("nint")), + PrefixUnaryExpression( + SyntaxKind.AddressOfExpression, + MemberAccessExpression( + SyntaxKind.PointerMemberAccessExpression, + ParenthesizedExpression( + CastExpression( + PointerType(ParseTypeName("InterfaceImplementationVtable")), + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + TypeSyntaxes.System_Runtime_CompilerServices_Unsafe, + IdentifierName("AsPointer"))) + .AddArgumentListArguments( + Argument( + RefExpression(IdentifierName("Vtable")))))), + IdentifierName("Release_2")))))))); } else { @@ -628,7 +754,15 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf TypeOfExpression(ParseTypeName(interfaceMethods.Interface.Base.Info.Type.FullTypeName)) .Dot(IdentifierName("TypeHandle")))) .Dot(IdentifierName("ManagedVirtualMethodTable"))), - Argument(IdentifierName(vtableLocalName)), + Argument( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + TypeSyntaxes.System_Runtime_CompilerServices_Unsafe, + IdentifierName("AsPointer"))) + .AddArgumentListArguments( + Argument( + RefExpression(IdentifierName("Vtable"))))), Argument(CastExpression(IdentifierName("nuint"), ParenthesizedExpression( BinaryExpression(SyntaxKind.MultiplyExpression, @@ -636,22 +770,42 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.InheritedMethods.Count() + 3)))))))); } - var vtableSlotAssignments = VirtualMethodPointerStubGenerator.GenerateVirtualMethodTableSlotAssignments( - interfaceMethods.DeclaredMethods - .Where(context => context.UnmanagedToManagedStub.Diagnostics.All(diag => diag.Descriptor.DefaultSeverity != DiagnosticSeverity.Error)) - .Select(context => context.GenerationContext), - vtableLocalName, - ComInterfaceGeneratorHelpers.GetGeneratorResolver); + var validDeclaredMethods = interfaceMethods.DeclaredMethods + .Where(context => context.UnmanagedToManagedStub.Diagnostics.All(diag => diag.Descriptor.DefaultSeverity != DiagnosticSeverity.Error)); + + System.Collections.Generic.List statements = new(); + + foreach (var declaredMethodContext in validDeclaredMethods) + { + statements.Add( + ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("Vtable"), + IdentifierName($"{declaredMethodContext.MethodInfo.MethodName}_{declaredMethodContext.GenerationContext.VtableIndexData.Index}")), + PrefixUnaryExpression( + SyntaxKind.AddressOfExpression, + IdentifierName($"ABI_{declaredMethodContext.GenerationContext.StubMethodSyntaxTemplate.Identifier}"))))); + } return ImplementationInterfaceTemplate .AddMembers( - CreateManagedVirtualFunctionTableMethodTemplate + FieldDeclaration( + VariableDeclaration(ParseTypeName("InterfaceImplementationVtable")) + .AddVariables(VariableDeclarator("Vtable"))) + .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.ReadOnlyKeyword))) + .AddAttributeLists( + AttributeList( + SingletonSeparatedList( + Attribute(NameSyntaxes.System_Runtime_CompilerServices_FixedAddressValueTypeAttribute)))), + ConstructorDeclaration("InterfaceImplementation") + .AddModifiers(Token(SyntaxKind.StaticKeyword)) .WithBody( Block( - vtableDeclarationStatement, fillBaseInterfaceSlots, - vtableSlotAssignments, - ReturnStatement(IdentifierName(vtableLocalName))))); + Block(statements)))); } private static readonly ClassDeclarationSyntax InterfaceInformationTypeTemplate = @@ -677,27 +831,25 @@ private static ClassDeclarationSyntax GenerateInterfaceInformation(ComInterfaceI if (context.Options.HasFlag(ComInterfaceOptions.ManagedObjectWrapper)) { - const string vtableFieldName = "_vtable"; return interfaceInformationType.AddMembers( - // private static void** _vtable; - FieldDeclaration(VariableDeclaration(TypeSyntaxes.VoidStarStar, SingletonSeparatedList(VariableDeclarator(vtableFieldName)))) - .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword)), - // public static void* VirtualMethodTableManagedImplementation => _vtable != null ? _vtable : (_vtable = InterfaceImplementation.CreateManagedVirtualMethodTable()); + // public static void** VirtualMethodTableManagedImplementation => (void**)System.Runtime.CompilerServices.Unsafe.AsPointer(in InterfaceImplementation.Vtable); PropertyDeclaration(TypeSyntaxes.VoidStarStar, "ManagedVirtualMethodTable") .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)) .WithExpressionBody( ArrowExpressionClause( - ConditionalExpression( - BinaryExpression(SyntaxKind.NotEqualsExpression, - IdentifierName(vtableFieldName), - LiteralExpression(SyntaxKind.NullLiteralExpression)), - IdentifierName(vtableFieldName), - ParenthesizedExpression( - AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, - IdentifierName(vtableFieldName), - MethodInvocation( + CastExpression( + PointerType(PointerType(ParseTypeName("void"))), + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + TypeSyntaxes.System_Runtime_CompilerServices_Unsafe, + IdentifierName("AsPointer"))) + .AddArgumentListArguments( + InArgument( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, IdentifierName("InterfaceImplementation"), - IdentifierName(CreateManagedVirtualFunctionTableMethodName))))))) + IdentifierName("Vtable"))))))) .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))); } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VirtualMethodPointerStubGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VirtualMethodPointerStubGenerator.cs index 880022431c0276..a56644cef13f58 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VirtualMethodPointerStubGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VirtualMethodPointerStubGenerator.cs @@ -257,7 +257,7 @@ public static BlockSyntax GenerateVirtualMethodTableSlotAssignments( return Block(statements); } - private static FunctionPointerTypeSyntax GenerateUnmanagedFunctionPointerTypeForMethod( + public static FunctionPointerTypeSyntax GenerateUnmanagedFunctionPointerTypeForMethod( IncrementalMethodStubGenerationContext method, Func generatorResolverCreator) { diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs index 2fe9b606b21ea6..1b03a386163eda 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs @@ -42,6 +42,12 @@ public static class NameSyntaxes private static NameSyntax? _WasmImportLinkageAttribute; public static NameSyntax WasmImportLinkageAttribute => _WasmImportLinkageAttribute ??= ParseName(TypeNames.GlobalAlias + TypeNames.WasmImportLinkageAttribute); + + private static NameSyntax? _System_Runtime_CompilerServices_FixedAddressValueTypeAttribute; + public static NameSyntax System_Runtime_CompilerServices_FixedAddressValueTypeAttribute => _System_Runtime_CompilerServices_FixedAddressValueTypeAttribute ??= ParseName(TypeNames.GlobalAlias + TypeNames.System_Runtime_CompilerServices_FixedAddressValueTypeAttribute); + + private static NameSyntax? _System_Runtime_InteropServices_StructLayoutAttribute; + public static NameSyntax System_Runtime_InteropServices_StructLayoutAttribute => _System_Runtime_InteropServices_StructLayoutAttribute ??= ParseName(TypeNames.GlobalAlias + TypeNames.System_Runtime_InteropServices_StructLayoutAttribute); } public static class TypeSyntaxes @@ -126,6 +132,9 @@ public static class TypeSyntaxes private static TypeSyntax? _System_Runtime_CompilerServices_Unsafe; public static TypeSyntax System_Runtime_CompilerServices_Unsafe => _System_Runtime_CompilerServices_Unsafe ??= ParseTypeName(TypeNames.GlobalAlias + TypeNames.System_Runtime_CompilerServices_Unsafe); + private static TypeSyntax? _System_Runtime_InteropServices_LayoutKind; + public static TypeSyntax System_Runtime_InteropServices_LayoutKind => _System_Runtime_InteropServices_LayoutKind ??= ParseTypeName(TypeNames.GlobalAlias + TypeNames.System_Runtime_InteropServices_LayoutKind); + private static TypeSyntax? _CallConvCdecl; private static TypeSyntax? _CallConvFastcall; private static TypeSyntax? _CallConvMemberFunction; @@ -246,6 +255,8 @@ public static string MarshalEx(InteropGenerationOptions options) public const string System_Runtime_CompilerServices_DisableRuntimeMarshallingAttribute = "System.Runtime.CompilerServices.DisableRuntimeMarshallingAttribute"; + public const string System_Runtime_CompilerServices_FixedAddressValueTypeAttribute = "System.Runtime.CompilerServices.FixedAddressValueTypeAttribute"; + public const string DefaultDllImportSearchPathsAttribute = "System.Runtime.InteropServices.DefaultDllImportSearchPathsAttribute"; public const string DllImportSearchPath = "System.Runtime.InteropServices.DllImportSearchPath"; @@ -307,6 +318,8 @@ public static string MarshalEx(InteropGenerationOptions options) public const string System_Runtime_InteropServices_NFloat = "System.Runtime.InteropServices.NFloat"; + public const string System_Runtime_InteropServices_LayoutKind = "System.Runtime.InteropServices.LayoutKind"; + public const string CallConvCdeclName = "System.Runtime.CompilerServices.CallConvCdecl"; public const string CallConvFastcallName = "System.Runtime.CompilerServices.CallConvFastcall"; public const string CallConvStdcallName = "System.Runtime.CompilerServices.CallConvStdcall"; diff --git a/src/libraries/System.Runtime.InteropServices/tests/TrimmingTests/CCWPreinitializationNativeAot.cs b/src/libraries/System.Runtime.InteropServices/tests/TrimmingTests/CCWPreinitializationNativeAot.cs new file mode 100644 index 00000000000000..ba15424c8a5d96 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/TrimmingTests/CCWPreinitializationNativeAot.cs @@ -0,0 +1,38 @@ +using System; +using System.Reflection; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; + +var comTypeInfo = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(IComInterface).TypeHandle)!; +Type cwType = comTypeInfo.Implementation; +unsafe +{ + nint* vtable = (nint*)comTypeInfo.ManagedVirtualMethodTable; + + if (HasCctor(cwType)) + return -1; + + ComWrappers.GetIUnknownImpl( + out nint queryInterface, + out nint addRef, + out nint release); + if (vtable[0] != queryInterface || vtable[1] != addRef || vtable[2] != release) + return -2; +} + +return 100; + +[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2070:UnrecognizedReflectionPattern", + Justification = "Yep, we don't want to keep the cctor if it wasn't kept")] +static bool HasCctor(Type type) +{ + return type.GetConstructor(BindingFlags.NonPublic | BindingFlags.Static, null, Type.EmptyTypes, null) != null; +} + +[GeneratedComInterface] +[Guid("ad358058-2b72-4801-8d98-043d44dc42c4")] +partial interface IComInterface +{ + int Method(); +} \ No newline at end of file diff --git a/src/libraries/System.Runtime.InteropServices/tests/TrimmingTests/System.Runtime.InteropServices.TrimmingTests.proj b/src/libraries/System.Runtime.InteropServices/tests/TrimmingTests/System.Runtime.InteropServices.TrimmingTests.proj index a812b392537fad..daf43dbb12a7b1 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TrimmingTests/System.Runtime.InteropServices.TrimmingTests.proj +++ b/src/libraries/System.Runtime.InteropServices/tests/TrimmingTests/System.Runtime.InteropServices.TrimmingTests.proj @@ -17,6 +17,11 @@ win-x64;browser-wasm + + + osx-x64;linux-x64;browser-wasm + IlcTrimMetadata +