diff --git a/src/coreclr/vm/interopconverter.cpp b/src/coreclr/vm/interopconverter.cpp index d74eb6c780fbd0..fa49b2f4179125 100644 --- a/src/coreclr/vm/interopconverter.cpp +++ b/src/coreclr/vm/interopconverter.cpp @@ -188,18 +188,29 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, ComIpType ReqIpType, ComIpType if (ReqIpType & ComIpType_Dispatch) { hr = SafeQueryInterface(pUnk, IID_IDispatch, &pvObj); - pUnk->Release(); + if (SUCCEEDED(hr)) + { + pUnk->Release(); + FetchedIpType = ComIpType_Dispatch; + } + else if (ReqIpType & ComIpType_Unknown) + { + hr = S_OK; + pvObj = pUnk; + FetchedIpType = ComIpType_Unknown; + } } else { pvObj = pUnk; + FetchedIpType = ComIpType_Unknown; } if (FAILED(hr)) COMPlusThrowHR(hr); if (pFetchedIpType != NULL) - *pFetchedIpType = ReqIpType; + *pFetchedIpType = FetchedIpType; RETURN pvObj; } diff --git a/src/tests/Interop/PInvoke/Variant/VariantTest.BuiltInCom.cs b/src/tests/Interop/PInvoke/Variant/VariantTest.BuiltInCom.cs new file mode 100644 index 00000000000000..a589fd1da7e545 --- /dev/null +++ b/src/tests/Interop/PInvoke/Variant/VariantTest.BuiltInCom.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; +using TestLibrary; +using static VariantNative; + +#pragma warning disable CS0612, CS0618 +partial class Test +{ + public static int Main() + { + bool builtInComDisabled=false; + var comConfig = AppContext.GetData("System.Runtime.InteropServices.BuiltInComInterop.IsSupported"); + if(comConfig != null && !bool.Parse(comConfig.ToString())) + { + builtInComDisabled=true; + } + + Console.WriteLine($"Built-in COM Disabled?: {builtInComDisabled}"); + try + { + TestByValue(!builtInComDisabled); + TestByRef(!builtInComDisabled); + TestOut(); + TestFieldByValue(!builtInComDisabled); + TestFieldByRef(!builtInComDisabled); + } + catch (Exception e) + { + Console.WriteLine($"Test failed: {e}"); + return 101; + } + return 100; + } +} diff --git a/src/tests/Interop/PInvoke/Variant/VariantTest.ComWrappers.cs b/src/tests/Interop/PInvoke/Variant/VariantTest.ComWrappers.cs new file mode 100644 index 00000000000000..a4dcc422cbd7dd --- /dev/null +++ b/src/tests/Interop/PInvoke/Variant/VariantTest.ComWrappers.cs @@ -0,0 +1,114 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using TestLibrary; +using static VariantNative; +using ComTypes = System.Runtime.InteropServices.ComTypes; + +#pragma warning disable CS0612, CS0618 +partial class Test +{ + public static int Main() + { + bool testComMarshal=true; + ComWrappers.RegisterForMarshalling(new ComWrappersImpl()); + try + { + TestByValue(testComMarshal); + TestByRef(testComMarshal); + TestOut(); + TestFieldByValue(testComMarshal); + TestFieldByRef(testComMarshal); + } + catch (Exception e) + { + Console.WriteLine($"Test failed: {e}"); + return 101; + } + return 100; + } +} + +internal unsafe class ComWrappersImpl : ComWrappers +{ + private static readonly ComInterfaceEntry* wrapperEntry; + + static ComWrappersImpl() + { + var vtblRaw = (IntPtr*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(IDispatchVtbl), sizeof(IntPtr) * 7); + GetIUnknownImpl(out vtblRaw[0], out vtblRaw[1], out vtblRaw[2]); + + vtblRaw[3] = (IntPtr)(delegate* unmanaged)&IDispatchVtbl.GetTypeInfoCountInternal; + vtblRaw[4] = (IntPtr)(delegate* unmanaged)&IDispatchVtbl.GetTypeInfoInternal; + vtblRaw[5] = (IntPtr)(delegate* unmanaged)&IDispatchVtbl.GetIDsOfNamesInternal; + vtblRaw[6] = (IntPtr)(delegate* unmanaged)&IDispatchVtbl.InvokeInternal; + + wrapperEntry = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(IDispatchVtbl), sizeof(ComInterfaceEntry)); + wrapperEntry->IID = IDispatchVtbl.IID_IDispatch; + wrapperEntry->Vtable = (IntPtr)vtblRaw; + } + + protected override unsafe ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) + { + // Always return the same table mappings. + count = 1; + return wrapperEntry; + } + + protected override object CreateObject(IntPtr externalComObject, CreateObjectFlags flags) + { + throw new NotImplementedException(); + } + + protected override void ReleaseObjects(IEnumerable objects) + { + throw new NotImplementedException(); + } +} +public struct IDispatchVtbl +{ + internal static readonly Guid IID_IDispatch = new Guid("00020400-0000-0000-C000-000000000046"); + + [UnmanagedCallersOnly] + public static int GetTypeInfoCountInternal(IntPtr thisPtr, IntPtr i) + { + return 0; // S_OK; + } + + [UnmanagedCallersOnly] + public static int GetTypeInfoInternal(IntPtr thisPtr, int itinfo, int lcid, IntPtr i) + { + return 0; // S_OK; + } + + [UnmanagedCallersOnly] + public static int GetIDsOfNamesInternal( + IntPtr thisPtr, + IntPtr iid, + IntPtr names, + int namesCount, + int lcid, + IntPtr dispIds) + { + return 0; // S_OK; + } + + [UnmanagedCallersOnly] + public static int InvokeInternal( + IntPtr thisPtr, + int dispIdMember, + IntPtr riid, + int lcid, + ComTypes.INVOKEKIND wFlags, + IntPtr pDispParams, + IntPtr VarResult, + IntPtr pExcepInfo, + IntPtr puArgErr) + { + return 0; // S_OK; + } +} diff --git a/src/tests/Interop/PInvoke/Variant/VariantTest.cs b/src/tests/Interop/PInvoke/Variant/VariantTest.cs index f4214d4d0d89e9..e15814a9cb6e0c 100644 --- a/src/tests/Interop/PInvoke/Variant/VariantTest.cs +++ b/src/tests/Interop/PInvoke/Variant/VariantTest.cs @@ -7,7 +7,7 @@ using static VariantNative; #pragma warning disable CS0612, CS0618 -class Test +partial class Test { private const byte NumericValue = 15; @@ -19,7 +19,7 @@ class Test private static readonly DateTime DateValue = new DateTime(2018, 11, 6); - private unsafe static void TestByValue() + private unsafe static void TestByValue(bool hasComSupport) { Assert.IsTrue(Marshal_ByValue_Byte((byte)NumericValue, NumericValue)); Assert.IsTrue(Marshal_ByValue_SByte((sbyte)NumericValue, (sbyte)NumericValue)); @@ -41,82 +41,89 @@ private unsafe static void TestByValue() Assert.IsTrue(Marshal_ByValue_Null(DBNull.Value)); Assert.IsTrue(Marshal_ByValue_Missing(System.Reflection.Missing.Value)); Assert.IsTrue(Marshal_ByValue_Empty(null)); - Assert.IsTrue(Marshal_ByValue_Object(new object())); - Assert.IsTrue(Marshal_ByValue_Object_IUnknown(new UnknownWrapper(new object()))); + if (hasComSupport) + { + Assert.IsTrue(Marshal_ByValue_Object(new object())); + Assert.IsTrue(Marshal_ByValue_Object_IUnknown(new UnknownWrapper(new object()))); + } + Assert.Throws(() => Marshal_ByValue_Invalid(TimeSpan.Zero)); Assert.Throws(() => Marshal_ByValue_Invalid(new CustomStruct())); Assert.Throws(() => Marshal_ByValue_Invalid(new VariantWrapper(CharValue))); } - private unsafe static void TestByRef() + private unsafe static void TestByRef(bool hasComSupport) { object obj; obj = (byte)NumericValue; Assert.IsTrue(Marshal_ByRef_Byte(ref obj, NumericValue)); - + obj = (sbyte)NumericValue; Assert.IsTrue(Marshal_ByRef_SByte(ref obj, (sbyte)NumericValue)); - + obj = (short)NumericValue; Assert.IsTrue(Marshal_ByRef_Int16(ref obj, NumericValue)); - + obj = (ushort)NumericValue; Assert.IsTrue(Marshal_ByRef_UInt16(ref obj, NumericValue)); - + obj = (int)NumericValue; Assert.IsTrue(Marshal_ByRef_Int32(ref obj, NumericValue)); - + obj = (uint)NumericValue; Assert.IsTrue(Marshal_ByRef_UInt32(ref obj, NumericValue)); - + obj = (long)NumericValue; Assert.IsTrue(Marshal_ByRef_Int64(ref obj, NumericValue)); - + obj = (ulong)NumericValue; Assert.IsTrue(Marshal_ByRef_UInt64(ref obj, NumericValue)); - + obj = (float)NumericValue; Assert.IsTrue(Marshal_ByRef_Single(ref obj, NumericValue)); - + obj = (double)NumericValue; Assert.IsTrue(Marshal_ByRef_Double(ref obj, NumericValue)); - + obj = StringValue; Assert.IsTrue(Marshal_ByRef_String(ref obj, StringValue)); obj = new BStrWrapper(null); Assert.IsTrue(Marshal_ByRef_String(ref obj, null)); - + obj = CharValue; Assert.IsTrue(Marshal_ByRef_Char(ref obj, CharValue)); - + obj = true; Assert.IsTrue(Marshal_ByRef_Boolean(ref obj, true)); - + obj = DateValue; Assert.IsTrue(Marshal_ByRef_DateTime(ref obj, DateValue)); - + obj = DecimalValue; Assert.IsTrue(Marshal_ByRef_Decimal(ref obj, DecimalValue)); obj = new CurrencyWrapper(DecimalValue); Assert.IsTrue(Marshal_ByRef_Currency(ref obj, DecimalValue)); - + obj = DBNull.Value; Assert.IsTrue(Marshal_ByRef_Null(ref obj)); - + obj = System.Reflection.Missing.Value; Assert.IsTrue(Marshal_ByRef_Missing(ref obj)); - + obj = null; Assert.IsTrue(Marshal_ByRef_Empty(ref obj)); - - obj = new object(); - Assert.IsTrue(Marshal_ByRef_Object(ref obj)); - obj = new UnknownWrapper(new object()); - Assert.IsTrue(Marshal_ByRef_Object_IUnknown(ref obj)); + if (hasComSupport) + { + obj = new object(); + Assert.IsTrue(Marshal_ByRef_Object(ref obj)); + + obj = new UnknownWrapper(new object()); + Assert.IsTrue(Marshal_ByRef_Object_IUnknown(ref obj)); + } obj = DecimalValue; Assert.IsTrue(Marshal_ChangeVariantType(ref obj, NumericValue)); @@ -130,164 +137,152 @@ private unsafe static void TestOut() Assert.IsTrue(obj is int); Assert.AreEqual(NumericValue, (int)obj); } - - private unsafe static void TestFieldByValue() + + private unsafe static void TestFieldByValue(bool hasComSupport) { ObjectWrapper wrapper = new ObjectWrapper(); wrapper.value = (byte)NumericValue; Assert.IsTrue(Marshal_Struct_ByValue_Byte(wrapper, NumericValue)); - + wrapper.value = (sbyte)NumericValue; Assert.IsTrue(Marshal_Struct_ByValue_SByte(wrapper, (sbyte)NumericValue)); - + wrapper.value = (short)NumericValue; Assert.IsTrue(Marshal_Struct_ByValue_Int16(wrapper, NumericValue)); - + wrapper.value = (ushort)NumericValue; Assert.IsTrue(Marshal_Struct_ByValue_UInt16(wrapper, NumericValue)); - + wrapper.value = (int)NumericValue; Assert.IsTrue(Marshal_Struct_ByValue_Int32(wrapper, NumericValue)); - + wrapper.value = (uint)NumericValue; Assert.IsTrue(Marshal_Struct_ByValue_UInt32(wrapper, NumericValue)); - + wrapper.value = (long)NumericValue; Assert.IsTrue(Marshal_Struct_ByValue_Int64(wrapper, NumericValue)); - + wrapper.value = (ulong)NumericValue; Assert.IsTrue(Marshal_Struct_ByValue_UInt64(wrapper, NumericValue)); - + wrapper.value = (float)NumericValue; Assert.IsTrue(Marshal_Struct_ByValue_Single(wrapper, NumericValue)); - + wrapper.value = (double)NumericValue; Assert.IsTrue(Marshal_Struct_ByValue_Double(wrapper, NumericValue)); - + wrapper.value = StringValue; Assert.IsTrue(Marshal_Struct_ByValue_String(wrapper, StringValue)); wrapper.value = new BStrWrapper(null); Assert.IsTrue(Marshal_Struct_ByValue_String(wrapper, null)); - + wrapper.value = CharValue; Assert.IsTrue(Marshal_Struct_ByValue_Char(wrapper, CharValue)); - + wrapper.value = true; Assert.IsTrue(Marshal_Struct_ByValue_Boolean(wrapper, true)); - + wrapper.value = DateValue; Assert.IsTrue(Marshal_Struct_ByValue_DateTime(wrapper, DateValue)); - + wrapper.value = DecimalValue; Assert.IsTrue(Marshal_Struct_ByValue_Decimal(wrapper, DecimalValue)); wrapper.value = new CurrencyWrapper(DecimalValue); Assert.IsTrue(Marshal_Struct_ByValue_Currency(wrapper, DecimalValue)); - + wrapper.value = DBNull.Value; Assert.IsTrue(Marshal_Struct_ByValue_Null(wrapper)); - + wrapper.value = System.Reflection.Missing.Value; Assert.IsTrue(Marshal_Struct_ByValue_Missing(wrapper)); - + wrapper.value = null; Assert.IsTrue(Marshal_Struct_ByValue_Empty(wrapper)); - - wrapper.value = new object(); - Assert.IsTrue(Marshal_Struct_ByValue_Object(wrapper)); - - wrapper.value = new UnknownWrapper(new object()); - Assert.IsTrue(Marshal_Struct_ByValue_Object_IUnknown(wrapper)); + + if (hasComSupport) + { + wrapper.value = new object(); + Assert.IsTrue(Marshal_Struct_ByValue_Object(wrapper)); + + wrapper.value = new UnknownWrapper(new object()); + Assert.IsTrue(Marshal_Struct_ByValue_Object_IUnknown(wrapper)); + } } - private unsafe static void TestFieldByRef() + private unsafe static void TestFieldByRef(bool hasComSupport) { ObjectWrapper wrapper = new ObjectWrapper(); wrapper.value = (byte)NumericValue; Assert.IsTrue(Marshal_Struct_ByRef_Byte(ref wrapper, NumericValue)); - + wrapper.value = (sbyte)NumericValue; Assert.IsTrue(Marshal_Struct_ByRef_SByte(ref wrapper, (sbyte)NumericValue)); - + wrapper.value = (short)NumericValue; Assert.IsTrue(Marshal_Struct_ByRef_Int16(ref wrapper, NumericValue)); - + wrapper.value = (ushort)NumericValue; Assert.IsTrue(Marshal_Struct_ByRef_UInt16(ref wrapper, NumericValue)); - + wrapper.value = (int)NumericValue; Assert.IsTrue(Marshal_Struct_ByRef_Int32(ref wrapper, NumericValue)); - + wrapper.value = (uint)NumericValue; Assert.IsTrue(Marshal_Struct_ByRef_UInt32(ref wrapper, NumericValue)); - + wrapper.value = (long)NumericValue; Assert.IsTrue(Marshal_Struct_ByRef_Int64(ref wrapper, NumericValue)); - + wrapper.value = (ulong)NumericValue; Assert.IsTrue(Marshal_Struct_ByRef_UInt64(ref wrapper, NumericValue)); - + wrapper.value = (float)NumericValue; Assert.IsTrue(Marshal_Struct_ByRef_Single(ref wrapper, NumericValue)); - + wrapper.value = (double)NumericValue; Assert.IsTrue(Marshal_Struct_ByRef_Double(ref wrapper, NumericValue)); - + wrapper.value = StringValue; Assert.IsTrue(Marshal_Struct_ByRef_String(ref wrapper, StringValue)); wrapper.value = new BStrWrapper(null); Assert.IsTrue(Marshal_Struct_ByRef_String(ref wrapper, null)); - + wrapper.value = CharValue; Assert.IsTrue(Marshal_Struct_ByRef_Char(ref wrapper, CharValue)); - + wrapper.value = true; Assert.IsTrue(Marshal_Struct_ByRef_Boolean(ref wrapper, true)); - + wrapper.value = DateValue; Assert.IsTrue(Marshal_Struct_ByRef_DateTime(ref wrapper, DateValue)); - + wrapper.value = DecimalValue; Assert.IsTrue(Marshal_Struct_ByRef_Decimal(ref wrapper, DecimalValue)); - + wrapper.value = new CurrencyWrapper(DecimalValue); Assert.IsTrue(Marshal_Struct_ByRef_Currency(ref wrapper, DecimalValue)); - + wrapper.value = DBNull.Value; Assert.IsTrue(Marshal_Struct_ByRef_Null(ref wrapper)); - + wrapper.value = System.Reflection.Missing.Value; Assert.IsTrue(Marshal_Struct_ByRef_Missing(ref wrapper)); - + wrapper.value = null; Assert.IsTrue(Marshal_Struct_ByRef_Empty(ref wrapper)); - - wrapper.value = new object(); - Assert.IsTrue(Marshal_Struct_ByRef_Object(ref wrapper)); - - wrapper.value = new UnknownWrapper(new object()); - Assert.IsTrue(Marshal_Struct_ByRef_Object_IUnknown(ref wrapper)); - } - public static int Main() - { - try - { - TestByValue(); - TestByRef(); - TestOut(); - TestFieldByValue(); - TestFieldByRef(); - } - catch (Exception e) + if (hasComSupport) { - Console.WriteLine($"Test failed: {e}"); - return 101; + wrapper.value = new object(); + Assert.IsTrue(Marshal_Struct_ByRef_Object(ref wrapper)); + + wrapper.value = new UnknownWrapper(new object()); + Assert.IsTrue(Marshal_Struct_ByRef_Object_IUnknown(ref wrapper)); } - return 100; } } diff --git a/src/tests/Interop/PInvoke/Variant/VariantTest.csproj b/src/tests/Interop/PInvoke/Variant/VariantTest.csproj index e1b97c6c5c301a..13d0f08d40a065 100644 --- a/src/tests/Interop/PInvoke/Variant/VariantTest.csproj +++ b/src/tests/Interop/PInvoke/Variant/VariantTest.csproj @@ -6,7 +6,9 @@ true - + + + diff --git a/src/tests/Interop/PInvoke/Variant/VariantTestBuiltInComDisabled.csproj b/src/tests/Interop/PInvoke/Variant/VariantTestBuiltInComDisabled.csproj new file mode 100644 index 00000000000000..87612cc3fe5d2d --- /dev/null +++ b/src/tests/Interop/PInvoke/Variant/VariantTestBuiltInComDisabled.csproj @@ -0,0 +1,19 @@ + + + Exe + true + + true + + + + + + + + + + + + + diff --git a/src/tests/Interop/PInvoke/Variant/VariantTestComWrappers.csproj b/src/tests/Interop/PInvoke/Variant/VariantTestComWrappers.csproj new file mode 100644 index 00000000000000..254ec39cf21d27 --- /dev/null +++ b/src/tests/Interop/PInvoke/Variant/VariantTestComWrappers.csproj @@ -0,0 +1,19 @@ + + + Exe + true + + true + + + + + + + + + + + + +