Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,14 @@ private void WriteResolverBindingExtendsWith(

WriteResolverBindingDescriptor(type, resolver);

if (resolver.SubscribeWith is not null)
{
Writer.WriteIndentedLine(
"configuration.SubscribeWith = \"{0}\";",
GeneratorUtils.EscapeForStringLiteral(resolver.SubscribeWith));
Writer.WriteIndentedLine("configuration.SourceType = context.ThisType;");
}

if (resolver.Kind is ResolverKind.BatchResolver)
{
// For batch resolvers, the return type is a list (e.g. List<string>).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ public bool TryHandle(GeneratorSyntaxContext context, [NotNullWhen(true)] out Sy
Resolver? nodeResolver = null;
var i = 0;

CollectSubscribeWithNames(
members,
includeInternalMembers,
out var subscribeSourceNames,
out var subscribeWithLookup);

foreach (var member in members)
{
if (member.IsIgnored())
Expand All @@ -83,13 +89,23 @@ or Accessibility.ProtectedOrInternal
continue;
}

// Skip methods that are referenced by a sibling [Subscribe(With = nameof(...))] attribute.
// These methods produce the event stream and must not be exposed as their own GraphQL fields.
if (subscribeSourceNames?.Contains(methodSymbol.Name) == true)
{
continue;
}

if (!isOperationType && hasNodeResolverAttribute)
{
nodeResolver = CreateNodeResolver(context, classSymbol, methodSymbol, ref diagnostics);
continue;
}

resolvers[i++] = CreateResolver(context, classSymbol, methodSymbol);
string? subscribeWith = null;
subscribeWithLookup?.TryGetValue(methodSymbol, out subscribeWith);

resolvers[i++] = CreateResolver(context, classSymbol, methodSymbol, subscribeWith);
continue;
}

Expand Down Expand Up @@ -277,14 +293,16 @@ private static bool IsOperationType(
private static Resolver CreateResolver(
GeneratorSyntaxContext context,
INamedTypeSymbol resolverType,
IMethodSymbol resolverMethod)
=> CreateResolver(context.SemanticModel.Compilation, resolverType, resolverMethod);
IMethodSymbol resolverMethod,
string? subscribeWith = null)
=> CreateResolver(context.SemanticModel.Compilation, resolverType, resolverMethod, subscribeWith: subscribeWith);

public static Resolver CreateResolver(
Compilation compilation,
INamedTypeSymbol resolverType,
IMethodSymbol resolverMethod,
string? resolverTypeName = null)
string? resolverTypeName = null,
string? subscribeWith = null)
{
var parameters = resolverMethod.Parameters;
var buffer = new ResolverParameter[parameters.Length];
Expand Down Expand Up @@ -324,7 +342,78 @@ public static Resolver CreateResolver(
? ResolverKind.BatchResolver
: compilation.IsConnectionType(resolverMethod.ReturnType)
? ResolverKind.ConnectionResolver
: ResolverKind.Default);
: ResolverKind.Default,
subscribeWith: subscribeWith);
}

private static void CollectSubscribeWithNames(
ImmutableArray<ISymbol> members,
bool includeInternalMembers,
out HashSet<string>? subscribeSourceNames,
out Dictionary<IMethodSymbol, string>? subscribeWithLookup)
{
subscribeSourceNames = null;
subscribeWithLookup = null;

foreach (var member in members)
{
if (member.IsIgnored())
{
continue;
}

if (member is not IMethodSymbol { MethodKind: MethodKind.Ordinary } method)
{
continue;
}

if (!IsVisibleResolverMember(method, includeInternalMembers))
{
continue;
}

if (method.Skip())
{
continue;
}

if (!TryGetSubscribeWith(method, out var with))
{
continue;
}

subscribeSourceNames ??= [];
subscribeWithLookup ??= new Dictionary<IMethodSymbol, string>(SymbolEqualityComparer.Default);
subscribeSourceNames.Add(with);
subscribeWithLookup[method] = with;
}
}

private static bool TryGetSubscribeWith(
IMethodSymbol methodSymbol,
[NotNullWhen(true)] out string? with)
{
foreach (var attribute in methodSymbol.GetAttributes())
{
if (attribute.AttributeClass?.ToDisplayString() != SubscribeAttribute)
{
continue;
}

foreach (var namedArg in attribute.NamedArguments)
{
if (namedArg.Key == "With"
&& namedArg.Value.Value is string value
&& !string.IsNullOrEmpty(value))
{
with = value;
return true;
}
}
}

with = null;
return false;
}

private static Resolver CreateNodeResolver(
Expand Down
13 changes: 11 additions & 2 deletions src/HotChocolate/Core/src/Types.Analyzers/Models/Resolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ public Resolver(
ImmutableArray<MemberBinding> bindings,
SchemaTypeReference schemaTypeRef,
ResolverKind kind = ResolverKind.Default,
FieldFlags flags = FieldFlags.None)
FieldFlags flags = FieldFlags.None,
string? subscribeWith = null)
{
TypeName = typeName;
Member = member;
Expand All @@ -31,6 +32,7 @@ public Resolver(
Bindings = bindings;
Kind = kind;
Flags = flags;
SubscribeWith = subscribeWith;

if (description is MethodDescription m && parameters.Length == m.ParameterDescriptions.Length)
{
Expand Down Expand Up @@ -97,6 +99,12 @@ public bool IsPure

public ImmutableArray<AttributeData> DescriptorAttributes { get; }

/// <summary>
/// The name of the sibling method that produces the subscription event stream
/// when this resolver is annotated with <c>[Subscribe(With = nameof(...))]</c>.
/// </summary>
public string? SubscribeWith { get; }

public Resolver WithSchemaTypeName(SchemaTypeReference schemaTypeRef)
=> new Resolver(
TypeName,
Expand All @@ -108,5 +116,6 @@ public Resolver WithSchemaTypeName(SchemaTypeReference schemaTypeRef)
Bindings,
schemaTypeRef,
Kind,
Flags);
Flags,
SubscribeWith);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public static class WellKnownAttributes
public const string QueryAttribute = "HotChocolate.QueryAttribute";
public const string MutationAttribute = "HotChocolate.MutationAttribute";
public const string SubscriptionAttribute = "HotChocolate.SubscriptionAttribute";
public const string SubscribeAttribute = "HotChocolate.Types.SubscribeAttribute";
public const string NodeResolverAttribute = "HotChocolate.Types.Relay.NodeResolverAttribute";
public const string ParentAttribute = "HotChocolate.ParentAttribute";
public const string EventMessageAttribute = "HotChocolate.EventMessageAttribute";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,51 @@ public async Task Schema_Snapshot()
.MatchSnapshotAsync();
}

[Fact]
public async Task Subscription_With_Subscribe_With_Delivers_Message_From_Stream()
{
// arrange
var executor = await new ServiceCollection()
.AddGraphQLServer()
.AddIntegrationTestTypes()
.AddPagingArguments()
.BuildRequestExecutorAsync();

// act
await using var subscriptionResult = await executor.ExecuteAsync(
"subscription { onProductAdded(categoryId: 42) }");

// assert
var stream = subscriptionResult.ExpectResponseStream();
await foreach (var result in stream.ReadResultsAsync())
{
result.MatchInlineSnapshot(
"""
{
"data": {
"onProductAdded": 42
}
}
""");
break;
}
}

[Fact]
public async Task Subscription_With_Public_Subscribe_Source_Is_Not_Exposed_As_Field()
{
var schema = await new ServiceCollection()
.AddGraphQLServer()
.AddIntegrationTestTypes()
.AddPagingArguments()
.BuildSchemaAsync();

var subscription = schema.Types.GetType<ObjectType>("Subscription");
Assert.Equal(
["onProductAdded", "onProductPriceChanged"],
subscription.Fields.Where(f => !f.IsIntrospectionField).Select(f => f.Name).ToArray());
}

[Fact]
public async Task Maps_NullOrdering_From_PagingOptions_To_PagingArguments()
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
namespace HotChocolate.Types;

[SubscriptionType]
public static partial class Subscription
{
[Subscribe(With = nameof(SubscribeToOnProductAdded))]
public static Task<int> OnProductAdded([EventMessage] int productId)
=> Task.FromResult(productId);

private static async IAsyncEnumerable<int> SubscribeToOnProductAdded(int categoryId)
{
await Task.Yield();
yield return categoryId;
}

[Subscribe(With = nameof(SubscribeToOnProductPriceChanged))]
public static Task<int> OnProductPriceChanged([EventMessage] int newPrice)
=> Task.FromResult(newPrice);

public static async IAsyncEnumerable<int> SubscribeToOnProductPriceChanged(int productId)
{
await Task.Yield();
yield return productId;
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
schema {
query: Query
subscription: Subscription
}

type Query {
Expand Down Expand Up @@ -60,6 +61,11 @@ type Query {
issue8057Entity(id: ID!): Issue8057Entity @cost(weight: "10")
}

type Subscription {
onProductAdded(categoryId: Int!): Int! @cost(weight: "10")
onProductPriceChanged(productId: Int!): Int! @cost(weight: "10")
}

type Book implements Product {
title: String!
id: String!
Expand Down
52 changes: 52 additions & 0 deletions src/HotChocolate/Core/test/Types.Analyzers.Tests/OperationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,58 @@ public static int GetTest(string arg)
""").MatchMarkdownAsync();
}

[Fact]
public async Task Subscription_With_Subscribe_With_Excludes_Stream_Method()
{
await TestHelper.GetGeneratedSourceSnapshot(
"""
using System.Collections.Generic;
using System.Threading.Tasks;
using HotChocolate;
using HotChocolate.Types;

namespace TestNamespace;

[SubscriptionType]
public static partial class Subscription
{
[Subscribe(With = nameof(SubscribeToOnProductAdded))]
public static Task<int> OnProductAdded([EventMessage] int productId)
=> Task.FromResult(productId);

private static async IAsyncEnumerable<int> SubscribeToOnProductAdded(int categoryId)
{
await Task.Yield();
yield return categoryId;
}
}
""").MatchMarkdownAsync();
}

[Fact]
public async Task Subscription_Ignored_Method_Does_Not_Suppress_Public_Resolver()
{
await TestHelper.GetGeneratedSourceSnapshot(
"""
using System.Collections.Generic;
using System.Threading.Tasks;
using HotChocolate;
using HotChocolate.Types;

namespace TestNamespace;

[SubscriptionType]
public static partial class Subscription
{
public static int OnFoo() => 42;

[GraphQLIgnore]
[Subscribe(With = nameof(OnFoo))]
public static Task<int> NotARealResolver() => Task.FromResult(0);
}
""").MatchMarkdownAsync();
}

[Fact]
public async Task Lookup_With_Generic_ID_Attribute()
{
Expand Down
Loading
Loading