Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/HotChocolate/Core/src/Abstractions/ErrorCodes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ public static class Data
public const string FilteringProjectionFailed = "HC0023";
public const string SortingProjectionFailed = "HC0024";
public const string NoPaginationProviderFound = "HC0025";
public const string MaxFilterOperationsExceeded = "HC0117";

/// <summary>
/// Type does not contain a valid node field. Only `items` and `nodes` are supported
Expand Down
16 changes: 16 additions & 0 deletions src/HotChocolate/Data/src/Data/ErrorHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,22 @@ public static IError SortingVisitor_ListValues(ISortField field, ListValueNode n
.SetExtension(nameof(field), field)
.Build();

public static IError MaxAllowedFilterOperationsExceeded(
IValueNode node,
int filterOperations,
int maxAllowedFilterOperations) =>
ErrorBuilder.New()
.SetMessage(
"The filter argument contains {0} operations, which exceeds the maximum allowed "
+ "number of {1}.",
filterOperations,
maxAllowedFilterOperations)
.AddLocation(node)
.SetCode(ErrorCodes.Data.MaxFilterOperationsExceeded)
.SetExtension(nameof(filterOperations), filterOperations)
.SetExtension(nameof(maxAllowedFilterOperations), maxAllowedFilterOperations)
.Build();

public static IError CreateNonNullError<T>(
ISortField field,
IValueNode value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public static IFilterConventionDescriptor AddDefaults(
descriptor
.AddDefaultOperations()
.BindDefaultTypes(compatibilityMode)
.MaxAllowedFilterOperations(FilterConventionDefinition.DefaultMaxAllowedFilterOperations)
.UseQueryableProvider();

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace HotChocolate.Data.Filters;
/// </summary>
public class FilterConvention
: Convention<FilterConventionDefinition>
, IFilterConvention
, IFilterConvention
{
private const string _inputPostFix = "FilterInput";
private const string _inputTypePostFix = "FilterInputType";
Expand All @@ -28,6 +28,7 @@ public class FilterConvention
private string _argumentName = default!;
private IFilterProvider _provider = default!;
private ITypeInspector _typeInspector = default!;
private int? _maxAllowedFilterOperations;
private bool _useAnd;
private bool _useOr;

Expand Down Expand Up @@ -99,6 +100,7 @@ protected internal override void Complete(IConventionContext context)
_bindings = Definition.Bindings;
_configs = Definition.Configurations;
_argumentName = Definition.ArgumentName;
_maxAllowedFilterOperations = Definition.MaxAllowedFilterOperations;
_useAnd = Definition.UseAnd;
_useOr = Definition.UseOr;

Expand Down Expand Up @@ -234,6 +236,9 @@ public string GetOperationName(int operation)
/// <inheritdoc />
public string GetArgumentName() => _argumentName;

/// <inheritdoc />
public int? GetMaxAllowedFilterOperations() => _maxAllowedFilterOperations;

/// <inheritdoc cref="IFilterConvention"/>
public void ApplyConfigurations(
TypeReference typeReference,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ namespace HotChocolate.Data.Filters;
public class FilterConventionDefinition : IHasScope
{
public static readonly string DefaultArgumentName = "where";
public const int DefaultMaxAllowedFilterOperations = 64;
private string _argumentName = DefaultArgumentName;

public string? Scope { get; set; }
Expand All @@ -33,6 +34,8 @@ public string ArgumentName

public List<Type> ProviderExtensionsTypes { get; } = [];

public int? MaxAllowedFilterOperations { get; set; }

public bool UseOr { get; set; } = true;

public bool UseAnd { get; set; } = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,22 @@ public IFilterConventionDescriptor ArgumentName(string argumentName)
return this;
}

/// <inheritdoc />
public IFilterConventionDescriptor MaxAllowedFilterOperations(int? maxAllowedFilterOperations)
{
if (maxAllowedFilterOperations is < 1)
{
throw new ArgumentOutOfRangeException(
nameof(maxAllowedFilterOperations),
maxAllowedFilterOperations,
"The maximum number of filter operations must be greater than zero.");
}

Definition.MaxAllowedFilterOperations = maxAllowedFilterOperations;

return this;
}

public IFilterConventionDescriptor AddProviderExtension<TExtension>()
where TExtension : class, IFilterProviderExtension
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ Definition is not null &&
filterConvention.Definition.ArgumentName = Definition.ArgumentName;
}

if (Definition.MaxAllowedFilterOperations.HasValue)
{
filterConvention.Definition.MaxAllowedFilterOperations =
Definition.MaxAllowedFilterOperations;
}

if (Definition.Provider is not null)
{
filterConvention.Definition.Provider = Definition.Provider;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ public interface IFilterConvention : IConvention
/// </returns>
string GetArgumentName();

/// <summary>
/// Gets the maximum number of filter operations allowed in a single filter argument.
/// </summary>
int? GetMaxAllowedFilterOperations();

/// <summary>
/// Applies configurations to a filter type.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ IFilterConventionDescriptor Provider<TProvider>(TProvider provider)
/// </exception>
IFilterConventionDescriptor ArgumentName(string argumentName);

/// <summary>
/// Defines the maximum number of filter operations allowed in a single filter argument.
/// </summary>
/// <param name="maxAllowedFilterOperations">
/// The maximum number of filter operations. If <c>null</c>, no limit is applied.
/// </param>
IFilterConventionDescriptor MaxAllowedFilterOperations(int? maxAllowedFilterOperations);

/// <summary>
/// Add a extensions that is applied to <see cref="FilterProvider{TContext}"/>
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,45 @@ public override bool TryCombineOperations(
return false;
}

combined = operations.Dequeue();
combined = CombineOperations(operations, combinator);

while (operations.Count > 0)
return true;
}

private Expression CombineOperations(
Queue<Expression> operations,
FilterCombinator combinator)
{
while (operations.Count > 1)
{
combined = combinator switch
var operationCount = operations.Count;
var pairCount = operationCount / 2;

for (var i = 0; i < pairCount; i++)
{
FilterCombinator.And => Expression.AndAlso(combined, operations.Dequeue()),
FilterCombinator.Or => Expression.OrElse(combined, operations.Dequeue()),
_ => throw ThrowHelper
.Filtering_QueryableCombinator_InvalidCombinator(this, combinator),
};
var left = operations.Dequeue();
var right = operations.Dequeue();

operations.Enqueue(Combine(left, right, combinator));
}

if ((operationCount & 1) == 1)
{
operations.Enqueue(operations.Dequeue());
}
}

return true;
return operations.Dequeue();
}

private BinaryExpression Combine(
Expression left,
Expression right,
FilterCombinator combinator)
=> combinator switch
{
FilterCombinator.And => Expression.AndAlso(left, right),
FilterCombinator.Or => Expression.OrElse(left, right),
_ => throw ThrowHelper.Filtering_QueryableCombinator_InvalidCombinator(this, combinator),
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ public override void ConfigureField(
string argumentName,
IObjectFieldDescriptor descriptor)
{
var maxAllowedFilterOperations = FilterConvention.GetMaxAllowedFilterOperations();
var contextData = descriptor.Extend().Definition.ContextData;
var argumentKey = (VisitFilterArgument)VisitFilterArgumentExecutor;
contextData[ContextVisitFilterArgumentKey] = argumentKey;
Expand All @@ -111,6 +112,11 @@ QueryableFilterContext VisitFilterArgumentExecutor(
IFilterInputType filterInput,
bool inMemory)
{
if (maxAllowedFilterOperations is { } maxAllowed)
{
ValidateFilterOperations(valueNode, filterInput, maxAllowed);
}

var visitorContext = new QueryableFilterContext(filterInput, inMemory);

// rewrite GraphQL input object into expression tree.
Expand All @@ -120,6 +126,89 @@ QueryableFilterContext VisitFilterArgumentExecutor(
}
}

private static void ValidateFilterOperations(
IValueNode valueNode,
IFilterInputType filterInput,
int maxAllowedFilterOperations)
{
var filterOperations = 0;
var stack = new Stack<(IValueNode Value, IInputType Type)>();
stack.Push((valueNode, filterInput));

while (stack.Count > 0)
{
var (value, type) = stack.Pop();

switch (value)
{
case ObjectValueNode objectValue
when type.NamedType() is InputObjectType inputObject:
for (var i = objectValue.Fields.Count - 1; i >= 0; i--)
{
var fieldValue = objectValue.Fields[i];

if (!inputObject.Fields.TryGetField(fieldValue.Name.Value, out var field))
{
continue;
}

if (field is IFilterOperationField and not IAndField and not IOrField)
{
filterOperations++;

if (filterOperations > maxAllowedFilterOperations)
{
throw new GraphQLException(
ErrorHelper.MaxAllowedFilterOperationsExceeded(
fieldValue.Value,
filterOperations,
maxAllowedFilterOperations));
}
}

if (CanContainFilterOperations(field.Type, fieldValue.Value))
{
stack.Push((fieldValue.Value, field.Type));
}
}

break;

case ListValueNode listValue
when TryGetListElementType(type, out var elementType):
for (var i = listValue.Items.Count - 1; i >= 0; i--)
{
stack.Push((listValue.Items[i], elementType));
}

break;
}
}
}

private static bool CanContainFilterOperations(IInputType type, IValueNode value)
=> value switch
{
ObjectValueNode => type.NamedType() is InputObjectType,
ListValueNode => TryGetListElementType(type, out var elementType)
&& elementType.NamedType() is InputObjectType,
_ => false,
};

private static bool TryGetListElementType(
IInputType type,
[NotNullWhen(true)] out IInputType? elementType)
{
if (type.IsListType() && type.ElementType() is IInputType inputType)
{
elementType = inputType;
return true;
}

elementType = null;
return false;
}

/// <inheritdoc />
public override IFilterMetadata? CreateMetaData(
ITypeCompletionContext context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ public FilterProvider(Action<IFilterProviderDescriptor<TContext>> configure)
/// <inheritdoc />
public IReadOnlyCollection<IFilterFieldHandler> FieldHandlers => _fieldHandlers;

protected IFilterConvention FilterConvention
=> _filterConvention ?? throw FilterConvention_ProviderHasToBeInitializedByConvention(GetType(), Scope);

/// <inheritdoc />
protected override FilterProviderDefinition CreateDefinition(IConventionContext context)
{
Expand Down
Loading
Loading