Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -399,4 +399,5 @@ FodyWeavers.xsd
*.msp

# JetBrains Rider
.idea/
*.sln.iml
53 changes: 42 additions & 11 deletions src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@ internal sealed class AsyncToSyncRewriter(SemanticModel semanticModel, bool disa
private const string Func = "Func";
private const string Generic = "Generic";
private const string System = "System";
private const string Tasks = "Tasks";
private const string Threading = "Threading";

// Type names
private const string Memory = nameof(Memory<>);
private const string IAsyncResult = nameof(global::System.IAsyncResult);
private const string Object = "object";
private const string TaskName = nameof(Task);
private const string TimeProviderTaskExtensions = "TimeProviderTaskExtensions";

// Members
private const string CompletedTask = nameof(Task.CompletedTask);
Expand All @@ -41,6 +45,8 @@ internal sealed class AsyncToSyncRewriter(SemanticModel semanticModel, bool disa
private const string SystemFunc = $"{System}.{Func}";
private const string IEnumerable = $"{System}.{Collections}.{Generic}.{nameof(IEnumerable<>)}";
private const string IEnumerator = $"{System}.{Collections}.{Generic}.{nameof(IEnumerator<>)}";
private const string TaskFullyQualified = $"{System}.{Threading}.{Tasks}.{TaskName}";
private const string TimeProviderTaskExtensionsFullyQualified = $"{System}.{Threading}.{Tasks}.{TimeProviderTaskExtensions}";

private static readonly SymbolDisplayFormat GlobalDisplayFormat = new(
globalNamespaceStyle: SymbolDisplayGlobalNamespaceStyle.Included,
Expand Down Expand Up @@ -1679,17 +1685,6 @@ private static bool CanDropEmptyStatement(StatementSyntax statement)
_ => false,
};

private static bool EndsWithAsync(ExpressionSyntax expression) => ReplaceAsync(expression) is not null;

private static string? ReplaceAsync(ExpressionSyntax expression) => expression switch
{
IdentifierNameSyntax { Identifier: { ValueText: not WaitAsync } z } when TryStripAsync(z.ValueText, out var newName) => newName,
MemberAccessExpressionSyntax m when ReplaceAsync(m.Name) is { } newName => newName,
InvocationExpressionSyntax ie => ReplaceAsync(ie.Expression),
GenericNameSyntax gn when TryStripAsync(gn.Identifier.Text, out var newName) => newName,
_ => null,
};

private static TypeSyntax GetReturnType(TypeSyntax returnType, INamedTypeSymbol symbol) => (returnType switch
{
IdentifierNameSyntax => ProcessSymbol(symbol),
Expand Down Expand Up @@ -1830,6 +1825,42 @@ private static List<SyntaxTrivia> RemoveFirstEndIf(SyntaxTriviaList list)
return newLeadingTrivia;
}

private bool EndsWithAsync(ExpressionSyntax expression) => expression switch
{
IdentifierNameSyntax id => ReplaceAsync(id) is not null,
MemberAccessExpressionSyntax m => EndsWithAsync(m.Name) || EndsWithAsync(m.Expression),
InvocationExpressionSyntax ie => EndsWithAsync(ie.Expression),
GenericNameSyntax gn => ReplaceAsync(gn) is not null,
_ => false,
};

private string? ReplaceAsync(ExpressionSyntax expression) => expression switch
{
IdentifierNameSyntax id => TryReplaceIdentifier(id),
MemberAccessExpressionSyntax m when ReplaceAsync(m.Name) is { } newName => newName,
InvocationExpressionSyntax ie => ReplaceAsync(ie.Expression),
GenericNameSyntax gn when TryStripAsync(gn.Identifier.Text, out var newName) => newName,
_ => null,
};

private string? TryReplaceIdentifier(IdentifierNameSyntax id)
{
if (id.Identifier.ValueText is WaitAsync)
{
var symbol = semanticModel.GetSymbolInfo(id).Symbol as IMethodSymbol;
if (symbol?.ContainingType?.ToDisplayString() is { } containingType)
{
if (containingType.StartsWith(TaskFullyQualified, StringComparison.Ordinal) ||
containingType.StartsWith(TimeProviderTaskExtensionsFullyQualified, StringComparison.Ordinal))
{
return null;
}
}
}

return TryStripAsync(id.Identifier.ValueText, out var newName) ? newName : null;
}

private InvocationExpressionSyntax UnwrapExtension(InvocationExpressionSyntax ies, bool changeMemoryToSpan, IMethodSymbol reducedFrom, ExpressionSyntax expression)
{
var arguments = ies.ArgumentList.Arguments;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
//HintName: Test.Class.MethodAsync.g.cs
// <auto-generated/>
#nullable enable
namespace Test;
public partial class Class
{
public void Method()
{
_ = reader.Read();
}
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
//HintName: Test.Class.MethodAsync.g.cs
var semaphore = new global::System.Threading.SemaphoreSlim(1, 1);

semaphore.Wait();

try
{
}
finally
// <auto-generated/>
#nullable enable
namespace Test;
public partial class Class
{
semaphore.Release();
public void Method()
{
semaphore.Wait();

try
{
global::System.Threading.Thread.Sleep(100);
}
finally
{
semaphore.Release();
}
}
}
16 changes: 16 additions & 0 deletions tests/Generator.Tests/TaskTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,22 @@ public async Task MethodAsync(XmlReader reader, CancellationToken ct)
}
""".Verify();

[Fact]
public Task DropWaitAsyncFullSource() => """
namespace Test;

public partial class Class
{
private XmlReader reader;

[CreateSyncVersion]
public async Task MethodAsync(CancellationToken ct = default)
{
_ = await reader.ReadAsync().WaitAsync(ct);
}
}
""".Verify(sourceType: SourceType.Full);

[Fact]
public Task DropWaitAsyncStatement() => """
[CreateSyncVersion]
Expand Down
28 changes: 19 additions & 9 deletions tests/Generator.Tests/TypeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -275,16 +275,26 @@ public Task CreateNullableType()
[Fact]
public Task SemaphoreSlimWaitAndRelease()
=> """
var semaphore = new SemaphoreSlim(1, 1);

await semaphore.WaitAsync();
namespace Test;

try
{
}
finally
public partial class Class
{
semaphore.Release();
private SemaphoreSlim semaphore = new(1, 1);

[CreateSyncVersion]
public async Task MethodAsync(CancellationToken ct = default)
{
await semaphore.WaitAsync(ct);

try
{
await Task.Delay(100, ct);
}
finally
{
semaphore.Release();
}
}
}
""".Verify(sourceType: SourceType.MethodBody);
""".Verify(sourceType: SourceType.Full);
}