From 43e49cfa1bcce2e5399942fd3ab77cfc7b3fcbd1 Mon Sep 17 00:00:00 2001 From: Andrew Arnott Date: Tue, 17 Aug 2021 13:11:54 -0600 Subject: [PATCH 1/2] Fix BOOL-bool conversion to not be lossy --- src/Microsoft.Windows.CsWin32/Generator.cs | 32 ++++++++++++++----- test/GenerationSandbox.Tests/BasicTests.cs | 32 +++++++++++++++++++ .../GeneratorTests.cs | 24 ++++++++++---- 3 files changed, 74 insertions(+), 14 deletions(-) diff --git a/src/Microsoft.Windows.CsWin32/Generator.cs b/src/Microsoft.Windows.CsWin32/Generator.cs index 1acb5dd1..1f2d75c9 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.cs @@ -3471,8 +3471,8 @@ private StructDeclarationSyntax DeclareTypeDefBOOLStruct(TypeDefinition typeDef) FieldDefinition fieldDef = this.Reader.GetFieldDefinition(typeDef.GetFields().Single()); var fieldAttributes = fieldDef.GetCustomAttributes(); - string fieldName = this.Reader.GetString(fieldDef.Name); - VariableDeclaratorSyntax fieldDeclarator = VariableDeclarator(Identifier("value")); + IdentifierNameSyntax fieldName = IdentifierName("value"); + VariableDeclaratorSyntax fieldDeclarator = VariableDeclarator(fieldName.Identifier); (TypeSyntax FieldType, SyntaxList AdditionalMembers) fieldInfo = this.ReinterpretFieldType(fieldDef, fieldDef.DecodeSignature(SignatureHandleProvider.Instance, null).ToTypeSyntax(this.fieldTypeSettings, fieldAttributes).Type, fieldAttributes); SyntaxList members = List(); @@ -3488,9 +3488,17 @@ private StructDeclarationSyntax DeclareTypeDefBOOLStruct(TypeDefinition typeDef) .WithExpressionBody(ArrowExpressionClause(fieldAccessExpression)).WithSemicolonToken(SemicolonWithLineFeed) .AddModifiers(TokenWithSpace(this.Visibility))); - // BOOL(bool value) => this.value = value ? 1 : 0; + static InvocationExpressionSyntax UnsafeAs(SyntaxKind fromType, SyntaxKind toType, IdentifierNameSyntax localSource) => + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(nameof(Unsafe)), + GenericName(nameof(Unsafe.As), TypeArgumentList().AddArguments(PredefinedType(Token(fromType)), PredefinedType(Token(toType))))), + ArgumentList().AddArguments(Argument(localSource).WithRefKindKeyword(Token(SyntaxKind.RefKeyword)))); + + // BOOL(bool value) => this.value = Unsafe.As(ref value); IdentifierNameSyntax valueParameter = IdentifierName("value"); - ExpressionSyntax boolToInt = ConditionalExpression(valueParameter, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(1)), LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0))); + ExpressionSyntax boolToInt = UnsafeAs(SyntaxKind.BoolKeyword, SyntaxKind.SByteKeyword, valueParameter); members = members.Add(ConstructorDeclaration(name.Identifier) .AddModifiers(TokenWithSpace(this.Visibility)) .AddParameterListParameters(Parameter(valueParameter.Identifier).WithType(PredefinedType(TokenWithSpace(SyntaxKind.BoolKeyword)))) @@ -3504,12 +3512,20 @@ private StructDeclarationSyntax DeclareTypeDefBOOLStruct(TypeDefinition typeDef) .WithExpressionBody(ArrowExpressionClause(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, fieldAccessExpression, valueParameter).WithOperatorToken(TokenWithSpaces(SyntaxKind.EqualsToken)))) .WithSemicolonToken(SemicolonWithLineFeed)); - // public static implicit operator bool(BOOL value) => value.value != 0 ? true : false; + // public static implicit operator bool(BOOL value) + // { + // sbyte v = checked((sbyte)value.value); + // return Unsafe.As(ref v); + // } + IdentifierNameSyntax localVarName = IdentifierName("v"); + var implicitBOOLtoBoolBody = Block().AddStatements( + LocalDeclarationStatement(VariableDeclaration(PredefinedType(Token(SyntaxKind.SByteKeyword)))).AddDeclarationVariables( + VariableDeclarator(localVarName.Identifier).WithInitializer(EqualsValueClause(CheckedExpression(SyntaxKind.CheckedExpression, CastExpression(PredefinedType(Token(SyntaxKind.SByteKeyword)), MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, valueParameter, fieldName)))))), + ReturnStatement(UnsafeAs(SyntaxKind.SByteKeyword, SyntaxKind.BoolKeyword, localVarName))); members = members.Add(ConversionOperatorDeclaration(Token(SyntaxKind.ImplicitKeyword), PredefinedType(Token(SyntaxKind.BoolKeyword))) .AddParameterListParameters(Parameter(valueParameter.Identifier).WithType(name.WithTrailingTrivia(TriviaList(Space)))) - .WithExpressionBody(ArrowExpressionClause(BinaryExpression(SyntaxKind.NotEqualsExpression, MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, valueParameter, IdentifierName(fieldName)), LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0))))) - .AddModifiers(TokenWithSpace(SyntaxKind.PublicKeyword), TokenWithSpace(SyntaxKind.StaticKeyword)) // operators MUST be public - .WithSemicolonToken(SemicolonWithLineFeed)); + .WithBody(implicitBOOLtoBoolBody) + .AddModifiers(TokenWithSpace(SyntaxKind.PublicKeyword), TokenWithSpace(SyntaxKind.StaticKeyword))); // operators MUST be public // public static implicit operator BOOL(bool value) => new BOOL(value); members = members.Add(ConversionOperatorDeclaration(Token(SyntaxKind.ImplicitKeyword), name) diff --git a/test/GenerationSandbox.Tests/BasicTests.cs b/test/GenerationSandbox.Tests/BasicTests.cs index 58eed729..5c453103 100644 --- a/test/GenerationSandbox.Tests/BasicTests.cs +++ b/test/GenerationSandbox.Tests/BasicTests.cs @@ -4,6 +4,7 @@ using System; using System.ComponentModel; using System.IO; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using Windows.Win32; using Windows.Win32.Foundation; @@ -51,6 +52,37 @@ public void Bool() Assert.False(default(BOOL)); } + [Theory] + [InlineData(3)] + [InlineData(-1)] + public void NotLossyConversionBetweenBoolAndBOOL(int ordinal) + { + BOOL nativeBool = new BOOL(ordinal); + bool managedBool = nativeBool; + BOOL roundtrippedNativeBool = managedBool; + Assert.Equal(nativeBool, roundtrippedNativeBool); + } + + [Theory] + [InlineData(3)] + [InlineData(-1)] + public void NotLossyConversionBetweenBoolAndBOOL_Ctors(int ordinal) + { + BOOL nativeBool = new BOOL(ordinal); + bool managedBool = nativeBool; + BOOL roundtrippedNativeBool = new BOOL(managedBool); + Assert.Equal(nativeBool, roundtrippedNativeBool); + } + + [Fact] + public void BOOLEqualsComparesExactValue() + { + BOOL b1 = new BOOL(1); + BOOL b2 = new BOOL(2); + Assert.Equal(b1, b1); + Assert.NotEqual(b1, b2); + } + [Fact] public void BSTR_ToString() { diff --git a/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs b/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs index 1fdd897b..c5398a5d 100644 --- a/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs +++ b/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs @@ -861,9 +861,13 @@ internal readonly partial struct BOOL private readonly int value; internal int Value => this.value; - internal BOOL(bool value) => this.value = value ? 1 : 0; + internal BOOL(bool value) => this.value = Unsafe.As(ref value); internal BOOL(int value) => this.value = value; - public static implicit operator bool(BOOL value) => value.Value != 0; + public static implicit operator bool(BOOL value) + { + sbyte v = checked((sbyte)value.value); + return Unsafe.As(ref v); + } public static implicit operator BOOL(bool value) => new BOOL(value); public static explicit operator BOOL(int value) => new BOOL(value); } @@ -1140,9 +1144,13 @@ internal readonly partial struct BOOL private readonly int value; internal int Value => this.value; - internal BOOL(bool value) => this.value = value ? 1 : 0; + internal BOOL(bool value) => this.value = Unsafe.As(ref value); internal BOOL(int value) => this.value = value; - public static implicit operator bool(BOOL value) => value.Value != 0; + public static implicit operator bool(BOOL value) + { + sbyte v = checked((sbyte)value.value); + return Unsafe.As(ref v); + } public static implicit operator BOOL(bool value) => new BOOL(value); public static explicit operator BOOL(int value) => new BOOL(value); } @@ -1460,9 +1468,13 @@ internal readonly partial struct BOOL private readonly int value; internal int Value => this.value; - internal BOOL(bool value) => this.value = value ? 1 : 0; + internal BOOL(bool value) => this.value = Unsafe.As(ref value); internal BOOL(int value) => this.value = value; - public static implicit operator bool(BOOL value) => value.Value != 0; + public static implicit operator bool(BOOL value) + { + sbyte v = checked((sbyte)value.value); + return Unsafe.As(ref v); + } public static implicit operator BOOL(bool value) => new BOOL(value); public static explicit operator BOOL(int value) => new BOOL(value); } From 65fec5957dabe8e55b0f28051c7cd9e9954cc125 Mon Sep 17 00:00:00 2001 From: Andrew Arnott Date: Tue, 17 Aug 2021 13:30:00 -0600 Subject: [PATCH 2/2] Return `BOOL` instead of `bool` from methods Fixes #362 --- src/Microsoft.Windows.CsWin32/Generator.cs | 4 ++-- .../Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs | 11 ++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/Microsoft.Windows.CsWin32/Generator.cs b/src/Microsoft.Windows.CsWin32/Generator.cs index 1f2d75c9..80e27de9 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.cs @@ -51,7 +51,7 @@ public class Generator : IDisposable internal static readonly Dictionary AdditionalBclInteropStructsMarshaled = new Dictionary(StringComparer.Ordinal) { - { "BOOL", PredefinedType(TokenWithSpace(SyntaxKind.BoolKeyword)) }, + ////{ "BOOL", PredefinedType(TokenWithSpace(SyntaxKind.BoolKeyword)) }, }; internal static readonly Dictionary BclInteropSafeHandles = new Dictionary(StringComparer.Ordinal) @@ -1330,7 +1330,7 @@ internal void RequestConstant(FieldDefinitionHandle fieldDefHandle) ArgumentList().AddArguments(Argument(CastExpression(releaseMethodParameterType.Type, MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ThisExpression(), IdentifierName("handle")))))); BlockSyntax? releaseBlock = null; if (!(releaseMethodReturnType.Type is PredefinedTypeSyntax { Keyword: { RawKind: (int)SyntaxKind.BoolKeyword } } || - releaseMethodReturnType.Type is IdentifierNameSyntax { Identifier: { ValueText: "BOOL" } })) + releaseMethodReturnType.Type is QualifiedNameSyntax { Right: { Identifier: { ValueText: "BOOL" } } })) { switch (releaseMethodReturnType.Type) { diff --git a/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs b/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs index c5398a5d..234b2141 100644 --- a/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs +++ b/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs @@ -329,16 +329,17 @@ public void CreateFileUsesSafeHandles() && createFileMethod.ParameterList.Parameters.Last().Type?.ToString() == "SafeHandle"); } + /// + /// GetMessage should return BOOL rather than bool because it actually returns any of THREE values. + /// [Fact] - public void BOOL_ReturnTypeBecomes_Boolean() + public void GetMessageW_ReturnsBOOL() { this.generator = this.CreateGenerator(); - Assert.True(this.generator.TryGenerate("WinUsb_FlushPipe", CancellationToken.None)); + Assert.True(this.generator.TryGenerate("GetMessage", CancellationToken.None)); this.CollectGeneratedCode(this.generator); this.AssertNoDiagnostics(); - MethodDeclarationSyntax? createFileMethod = this.FindGeneratedMethod("WinUsb_FlushPipe").FirstOrDefault(); - Assert.NotNull(createFileMethod); - Assert.Equal(SyntaxKind.BoolKeyword, Assert.IsType(createFileMethod!.ReturnType).Keyword.Kind()); + Assert.All(this.FindGeneratedMethod("GetMessage"), method => Assert.True(method.ReturnType is QualifiedNameSyntax { Right: { Identifier: { ValueText: "BOOL" } } })); } [Theory, PairwiseData]