Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions src/Microsoft.Windows.CsWin32/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public class Generator : IDisposable

internal static readonly Dictionary<string, TypeSyntax> AdditionalBclInteropStructsMarshaled = new Dictionary<string, TypeSyntax>(StringComparer.Ordinal)
{
{ "BOOL", PredefinedType(TokenWithSpace(SyntaxKind.BoolKeyword)) },
////{ "BOOL", PredefinedType(TokenWithSpace(SyntaxKind.BoolKeyword)) },
};

internal static readonly Dictionary<string, TypeSyntax> BclInteropSafeHandles = new Dictionary<string, TypeSyntax>(StringComparer.Ordinal)
Expand Down Expand Up @@ -1330,7 +1330,7 @@ internal void RequestConstant(FieldDefinitionHandle fieldDefHandle)
ArgumentList().AddArguments(Argument(CastExpression(releaseMethodParameterType.Type, MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ThisExpression(), IdentifierName("handle"))))));
BlockSyntax? releaseBlock = null;
if (!(releaseMethodReturnType.Type is PredefinedTypeSyntax { Keyword: { RawKind: (int)SyntaxKind.BoolKeyword } } ||
releaseMethodReturnType.Type is IdentifierNameSyntax { Identifier: { ValueText: "BOOL" } }))
releaseMethodReturnType.Type is QualifiedNameSyntax { Right: { Identifier: { ValueText: "BOOL" } } }))
{
switch (releaseMethodReturnType.Type)
{
Expand Down Expand Up @@ -3471,8 +3471,8 @@ private StructDeclarationSyntax DeclareTypeDefBOOLStruct(TypeDefinition typeDef)

FieldDefinition fieldDef = this.Reader.GetFieldDefinition(typeDef.GetFields().Single());
var fieldAttributes = fieldDef.GetCustomAttributes();
string fieldName = this.Reader.GetString(fieldDef.Name);
VariableDeclaratorSyntax fieldDeclarator = VariableDeclarator(Identifier("value"));
IdentifierNameSyntax fieldName = IdentifierName("value");
VariableDeclaratorSyntax fieldDeclarator = VariableDeclarator(fieldName.Identifier);
(TypeSyntax FieldType, SyntaxList<MemberDeclarationSyntax> AdditionalMembers) fieldInfo =
this.ReinterpretFieldType(fieldDef, fieldDef.DecodeSignature(SignatureHandleProvider.Instance, null).ToTypeSyntax(this.fieldTypeSettings, fieldAttributes).Type, fieldAttributes);
SyntaxList<MemberDeclarationSyntax> members = List<MemberDeclarationSyntax>();
Expand All @@ -3488,9 +3488,17 @@ private StructDeclarationSyntax DeclareTypeDefBOOLStruct(TypeDefinition typeDef)
.WithExpressionBody(ArrowExpressionClause(fieldAccessExpression)).WithSemicolonToken(SemicolonWithLineFeed)
.AddModifiers(TokenWithSpace(this.Visibility)));

// BOOL(bool value) => this.value = value ? 1 : 0;
static InvocationExpressionSyntax UnsafeAs(SyntaxKind fromType, SyntaxKind toType, IdentifierNameSyntax localSource) =>
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(nameof(Unsafe)),
GenericName(nameof(Unsafe.As), TypeArgumentList().AddArguments(PredefinedType(Token(fromType)), PredefinedType(Token(toType))))),
ArgumentList().AddArguments(Argument(localSource).WithRefKindKeyword(Token(SyntaxKind.RefKeyword))));

// BOOL(bool value) => this.value = Unsafe.As<bool, sbyte>(ref value);
IdentifierNameSyntax valueParameter = IdentifierName("value");
ExpressionSyntax boolToInt = ConditionalExpression(valueParameter, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(1)), LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)));
ExpressionSyntax boolToInt = UnsafeAs(SyntaxKind.BoolKeyword, SyntaxKind.SByteKeyword, valueParameter);
members = members.Add(ConstructorDeclaration(name.Identifier)
.AddModifiers(TokenWithSpace(this.Visibility))
.AddParameterListParameters(Parameter(valueParameter.Identifier).WithType(PredefinedType(TokenWithSpace(SyntaxKind.BoolKeyword))))
Expand All @@ -3504,12 +3512,20 @@ private StructDeclarationSyntax DeclareTypeDefBOOLStruct(TypeDefinition typeDef)
.WithExpressionBody(ArrowExpressionClause(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, fieldAccessExpression, valueParameter).WithOperatorToken(TokenWithSpaces(SyntaxKind.EqualsToken))))
.WithSemicolonToken(SemicolonWithLineFeed));

// public static implicit operator bool(BOOL value) => value.value != 0 ? true : false;
// public static implicit operator bool(BOOL value)
// {
// sbyte v = checked((sbyte)value.value);
// return Unsafe.As<sbyte, bool>(ref v);
// }
IdentifierNameSyntax localVarName = IdentifierName("v");
var implicitBOOLtoBoolBody = Block().AddStatements(
LocalDeclarationStatement(VariableDeclaration(PredefinedType(Token(SyntaxKind.SByteKeyword)))).AddDeclarationVariables(
VariableDeclarator(localVarName.Identifier).WithInitializer(EqualsValueClause(CheckedExpression(SyntaxKind.CheckedExpression, CastExpression(PredefinedType(Token(SyntaxKind.SByteKeyword)), MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, valueParameter, fieldName)))))),
ReturnStatement(UnsafeAs(SyntaxKind.SByteKeyword, SyntaxKind.BoolKeyword, localVarName)));
members = members.Add(ConversionOperatorDeclaration(Token(SyntaxKind.ImplicitKeyword), PredefinedType(Token(SyntaxKind.BoolKeyword)))
.AddParameterListParameters(Parameter(valueParameter.Identifier).WithType(name.WithTrailingTrivia(TriviaList(Space))))
.WithExpressionBody(ArrowExpressionClause(BinaryExpression(SyntaxKind.NotEqualsExpression, MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, valueParameter, IdentifierName(fieldName)), LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)))))
.AddModifiers(TokenWithSpace(SyntaxKind.PublicKeyword), TokenWithSpace(SyntaxKind.StaticKeyword)) // operators MUST be public
.WithSemicolonToken(SemicolonWithLineFeed));
.WithBody(implicitBOOLtoBoolBody)
.AddModifiers(TokenWithSpace(SyntaxKind.PublicKeyword), TokenWithSpace(SyntaxKind.StaticKeyword))); // operators MUST be public

// public static implicit operator BOOL(bool value) => new BOOL(value);
members = members.Add(ConversionOperatorDeclaration(Token(SyntaxKind.ImplicitKeyword), name)
Expand Down
32 changes: 32 additions & 0 deletions test/GenerationSandbox.Tests/BasicTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.ComponentModel;
using System.IO;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Windows.Win32;
using Windows.Win32.Foundation;
Expand Down Expand Up @@ -51,6 +52,37 @@ public void Bool()
Assert.False(default(BOOL));
}

[Theory]
[InlineData(3)]
[InlineData(-1)]
public void NotLossyConversionBetweenBoolAndBOOL(int ordinal)
{
BOOL nativeBool = new BOOL(ordinal);
bool managedBool = nativeBool;
BOOL roundtrippedNativeBool = managedBool;
Assert.Equal(nativeBool, roundtrippedNativeBool);
}

[Theory]
[InlineData(3)]
[InlineData(-1)]
public void NotLossyConversionBetweenBoolAndBOOL_Ctors(int ordinal)
{
BOOL nativeBool = new BOOL(ordinal);
bool managedBool = nativeBool;
BOOL roundtrippedNativeBool = new BOOL(managedBool);
Assert.Equal(nativeBool, roundtrippedNativeBool);
}

[Fact]
public void BOOLEqualsComparesExactValue()
{
BOOL b1 = new BOOL(1);
BOOL b2 = new BOOL(2);
Assert.Equal(b1, b1);
Assert.NotEqual(b1, b2);
}

[Fact]
public void BSTR_ToString()
{
Expand Down
35 changes: 24 additions & 11 deletions test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -329,16 +329,17 @@ public void CreateFileUsesSafeHandles()
&& createFileMethod.ParameterList.Parameters.Last().Type?.ToString() == "SafeHandle");
}

/// <summary>
/// GetMessage should return BOOL rather than bool because it actually returns any of THREE values.
/// </summary>
[Fact]
public void BOOL_ReturnTypeBecomes_Boolean()
public void GetMessageW_ReturnsBOOL()
{
this.generator = this.CreateGenerator();
Assert.True(this.generator.TryGenerate("WinUsb_FlushPipe", CancellationToken.None));
Assert.True(this.generator.TryGenerate("GetMessage", CancellationToken.None));
this.CollectGeneratedCode(this.generator);
this.AssertNoDiagnostics();
MethodDeclarationSyntax? createFileMethod = this.FindGeneratedMethod("WinUsb_FlushPipe").FirstOrDefault();
Assert.NotNull(createFileMethod);
Assert.Equal(SyntaxKind.BoolKeyword, Assert.IsType<PredefinedTypeSyntax>(createFileMethod!.ReturnType).Keyword.Kind());
Assert.All(this.FindGeneratedMethod("GetMessage"), method => Assert.True(method.ReturnType is QualifiedNameSyntax { Right: { Identifier: { ValueText: "BOOL" } } }));
}

[Theory, PairwiseData]
Expand Down Expand Up @@ -861,9 +862,13 @@ internal readonly partial struct BOOL
private readonly int value;

internal int Value => this.value;
internal BOOL(bool value) => this.value = value ? 1 : 0;
internal BOOL(bool value) => this.value = Unsafe.As<bool,sbyte>(ref value);
internal BOOL(int value) => this.value = value;
public static implicit operator bool(BOOL value) => value.Value != 0;
public static implicit operator bool(BOOL value)
{
sbyte v = checked((sbyte)value.value);
return Unsafe.As<sbyte,bool>(ref v);
}
public static implicit operator BOOL(bool value) => new BOOL(value);
public static explicit operator BOOL(int value) => new BOOL(value);
}
Expand Down Expand Up @@ -1140,9 +1145,13 @@ internal readonly partial struct BOOL
private readonly int value;

internal int Value => this.value;
internal BOOL(bool value) => this.value = value ? 1 : 0;
internal BOOL(bool value) => this.value = Unsafe.As<bool,sbyte>(ref value);
internal BOOL(int value) => this.value = value;
public static implicit operator bool(BOOL value) => value.Value != 0;
public static implicit operator bool(BOOL value)
{
sbyte v = checked((sbyte)value.value);
return Unsafe.As<sbyte,bool>(ref v);
}
public static implicit operator BOOL(bool value) => new BOOL(value);
public static explicit operator BOOL(int value) => new BOOL(value);
}
Expand Down Expand Up @@ -1460,9 +1469,13 @@ internal readonly partial struct BOOL
private readonly int value;

internal int Value => this.value;
internal BOOL(bool value) => this.value = value ? 1 : 0;
internal BOOL(bool value) => this.value = Unsafe.As<bool,sbyte>(ref value);
internal BOOL(int value) => this.value = value;
public static implicit operator bool(BOOL value) => value.Value != 0;
public static implicit operator bool(BOOL value)
{
sbyte v = checked((sbyte)value.value);
return Unsafe.As<sbyte,bool>(ref v);
}
public static implicit operator BOOL(bool value) => new BOOL(value);
public static explicit operator BOOL(int value) => new BOOL(value);
}
Expand Down