diff --git a/google/cloud/firestore_v1/_pipeline_stages.py b/google/cloud/firestore_v1/_pipeline_stages.py index 79519d96e..7cd0efed4 100644 --- a/google/cloud/firestore_v1/_pipeline_stages.py +++ b/google/cloud/firestore_v1/_pipeline_stages.py @@ -23,11 +23,12 @@ from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1.base_vector_query import DistanceMeasure from google.cloud.firestore_v1.pipeline_expressions import ( - Accumulator, + AggregateFunction, Expr, - ExprWithAlias, + AliasedAggregate, + AliasedExpr, Field, - FilterCondition, + BooleanExpr, Selectable, Ordering, ) @@ -164,8 +165,8 @@ class Aggregate(Stage): def __init__( self, - *args: ExprWithAlias[Accumulator], - accumulators: Sequence[ExprWithAlias[Accumulator]] = (), + *args: AliasedExpr[AggregateFunction], + accumulators: Sequence[AliasedAggregate] = (), groups: Sequence[str | Selectable] = (), ): super().__init__() @@ -350,6 +351,26 @@ def _pb_args(self) -> list[Value]: return [f._to_pb() for f in self.fields] +class Replace(Stage): + """Replaces the document content with the value of a specified field.""" + + class Mode(Enum): + FULL_REPLACE = 0 + MERGE_PREFER_NEXT = 1 + MERGE_PREFER_PARENT = 2 + + def __repr__(self): + return f"Replace.Mode.{self.name.upper()}" + + def __init__(self, field: Selectable | str, mode: Mode | str = Mode.FULL_REPLACE): + super().__init__() + self.field = Field(field) if isinstance(field, str) else field + self.mode = self.Mode[mode.upper()] if isinstance(mode, str) else mode + + def _pb_args(self): + return [self.field._to_pb(), Value(string_value=self.mode.name.lower())] + + class Sample(Stage): """Performs pseudo-random sampling of documents.""" @@ -439,7 +460,7 @@ def _pb_options(self): class Where(Stage): """Filters documents based on a specified condition.""" - def __init__(self, condition: FilterCondition): + def __init__(self, condition: BooleanExpr): super().__init__() self.condition = condition diff --git a/google/cloud/firestore_v1/base_aggregation.py b/google/cloud/firestore_v1/base_aggregation.py index 89e4edd0e..ba60e1314 100644 --- a/google/cloud/firestore_v1/base_aggregation.py +++ b/google/cloud/firestore_v1/base_aggregation.py @@ -34,9 +34,9 @@ from google.cloud.firestore_v1.types import ( StructuredAggregationQuery, ) -from google.cloud.firestore_v1.pipeline_expressions import Accumulator +from google.cloud.firestore_v1.pipeline_expressions import AggregateFunction from google.cloud.firestore_v1.pipeline_expressions import Count -from google.cloud.firestore_v1.pipeline_expressions import ExprWithAlias +from google.cloud.firestore_v1.pipeline_expressions import AliasedExpr from google.cloud.firestore_v1.pipeline_expressions import Field # Types needed only for Type Hints @@ -86,7 +86,7 @@ def _to_protobuf(self): @abc.abstractmethod def _to_pipeline_expr( self, autoindexer: Iterable[int] - ) -> ExprWithAlias[Accumulator]: + ) -> AliasedExpr[AggregateFunction]: """ Convert this instance to a pipeline expression for use with pipeline.aggregate() @@ -162,7 +162,7 @@ def _to_protobuf(self): return aggregation_pb def _to_pipeline_expr(self, autoindexer: Iterable[int]): - return Field.of(self.field_ref).avg().as_(self._pipeline_alias(autoindexer)) + return Field.of(self.field_ref).average().as_(self._pipeline_alias(autoindexer)) def _query_response_to_result( diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index b1d8a4cf3..ac917c5d2 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -23,11 +23,10 @@ from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest from google.cloud.firestore_v1.pipeline_result import PipelineResult from google.cloud.firestore_v1.pipeline_expressions import ( - Accumulator, + AliasedAggregate, Expr, - ExprWithAlias, Field, - FilterCondition, + BooleanExpr, Selectable, ) from google.cloud.firestore_v1 import _helpers @@ -220,14 +219,14 @@ def select(self, *selections: str | Selectable) -> "_BasePipeline": """ return self._append(stages.Select(*selections)) - def where(self, condition: FilterCondition) -> "_BasePipeline": + def where(self, condition: BooleanExpr) -> "_BasePipeline": """ Filters the documents from previous stages to only include those matching - the specified `FilterCondition`. + the specified `BooleanExpr`. This stage allows you to apply conditions to the data, similar to a "WHERE" clause in SQL. You can filter documents based on their field values, using - implementations of `FilterCondition`, typically including but not limited to: + implementations of `BooleanExpr`, typically including but not limited to: - field comparators: `eq`, `lt` (less than), `gt` (greater than), etc. - logical operators: `And`, `Or`, `Not`, etc. - advanced functions: `regex_matches`, `array_contains`, etc. @@ -252,7 +251,7 @@ def where(self, condition: FilterCondition) -> "_BasePipeline": Args: - condition: The `FilterCondition` to apply. + condition: The `BooleanExpr` to apply. Returns: A new Pipeline object with this stage appended to the stage list @@ -343,6 +342,54 @@ def sort(self, *orders: stages.Ordering) -> "_BasePipeline": """ return self._append(stages.Sort(*orders)) + def replace( + self, + field: Selectable, + mode: stages.Replace.Mode = stages.Replace.Mode.FULL_REPLACE, + ) -> "_BasePipeline": + """ + Replaces the entire document content with the value of a specified field, + typically a map. + + This stage allows you to emit a map value as the new document structure. + Each key of the map becomes a field in the output document, containing the + corresponding value. + + Example: + Input document: + ```json + { + "name": "John Doe Jr.", + "parents": { + "father": "John Doe Sr.", + "mother": "Jane Doe" + } + } + ``` + + >>> from google.cloud.firestore_v1.pipeline_expressions import Field + >>> pipeline = client.pipeline().collection("people") + >>> # Emit the 'parents' map as the document + >>> pipeline = pipeline.replace(Field.of("parents")) + + Output document: + ```json + { + "father": "John Doe Sr.", + "mother": "Jane Doe" + } + ``` + + Args: + field: The `Selectable` field containing the map whose content will + replace the document. + mode: The replacement mode + + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.Replace(field, mode)) + def sample(self, limit_or_options: int | stages.SampleOptions) -> "_BasePipeline": """ Performs a pseudo-random sampling of the documents from the previous stage. @@ -531,7 +578,7 @@ def limit(self, limit: int) -> "_BasePipeline": def aggregate( self, - *accumulators: ExprWithAlias[Accumulator], + *accumulators: AliasedAggregate, groups: Sequence[str | Selectable] = (), ) -> "_BasePipeline": """ @@ -541,7 +588,7 @@ def aggregate( This stage allows you to calculate aggregate values (like sum, average, count, min, max) over a set of documents. - - **Accumulators:** Define the aggregation calculations using `Accumulator` + - **AggregateFunctions:** Define the aggregation calculations using `AggregateFunction` expressions (e.g., `sum()`, `avg()`, `count()`, `min()`, `max()`) combined with `as_()` to name the result field. - **Groups:** Optionally specify fields (by name or `Selectable`) to group @@ -569,7 +616,7 @@ def aggregate( Args: - *accumulators: One or more `ExprWithAlias[Accumulator]` expressions defining + *accumulators: One or more `AliasedAggregate` expressions defining the aggregations to perform and their output names. groups: An optional sequence of field names (str) or `Selectable` expressions to group by before aggregating. diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 67c4f27fd..797572b1b 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -1151,7 +1151,7 @@ def pipeline(self): # Filters for filter_ in self._field_filters: ppl = ppl.where( - pipeline_expressions.FilterCondition._from_query_filter_pb( + pipeline_expressions.BooleanExpr._from_query_filter_pb( filter_, self._client ) ) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index a04e39ca5..30953357c 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -116,7 +116,83 @@ def _to_pb(self) -> Value: def _cast_to_expr_or_convert_to_constant(o: Any) -> "Expr": return o if isinstance(o, Expr) else Constant(o) - def add(self, other: Expr | float) -> "Add": + class expose_as_static: + """ + Decorator to mark instance methods to be exposed as static methods as well as instance + methods. + + When called statically, the first argument is converted to a Field expression if needed. + + Example: + >>> Field.of("test").add(5) + >>> Function.add("test", 5) + """ + + def __init__(self, instance_func): + self.instance_func = instance_func + + def static_func(self, first_arg, *other_args, **kwargs): + first_expr = Field.of(first_arg) if not isinstance(first_arg, Expr) else first_arg + return self.instance_func(first_expr, *other_args, **kwargs) + + def __get__(self, instance, owner): + if instance is None: + return self.static_func.__get__(instance, owner) + else: + return self.instance_func.__get__(instance, owner) + + @staticmethod + def array(elements: list[Expr | CONSTANT_TYPE]) -> "Expr": + """Creates an expression that creates a Firestore array value from an input list. + + Example: + >>> Expr.array(["bar", Field.of("baz")]) + + Args: + elements: THe input list to evaluate in the expression + + Returns: + A new `Expr` representing the array function. + """ + return Array([Expr._cast_to_expr_or_convert_to_constant(e) for e in elements]) + + @staticmethod + def map(elements: dict[str, Expr | CONSTANT_TYPE]) -> "Expr": + """Creates an expression that creates a Firestore map value from an input dict. + + Example: + >>> Expr.map({"foo": "bar", "baz": Field.of("baz")}) + + Args: + elements: THe input dict to evaluate in the expression + + Returns: + A new `Expr` representing the map function. + """ + return Map({Constant.of(k): Expr._cast_to_expr_or_convert_to_constant(v) for k, v in elements.items()}) + + @staticmethod + def conditional(conditional: BooleanExpr, then_expr: Expr, else_expr: Expr) -> "Expr": + """ + Creates a conditional expression that evaluates to a 'then' expression if a condition is true + and an 'else' expression if the condition is false. + + Example: + >>> # If 'age' is greater than 18, return "Adult"; otherwise, return "Minor". + >>> Expr.conditional(Field.of("age").greater_than(18), Constant.of("Adult"), Constant.of("Minor")); + + Args: + conditional: The condition to evaluate. + then_expr: The expression to return if the condition is true. + else_expr: The expression to return if the condition is false + + Returns: + A new `Expr` representing the conditional expression. + """ + return Conditional(conditional, then_expr, else_expr) + + @expose_as_static + def add(self, other: Expr | float) -> "Expr": """Creates an expression that adds this expression to another expression or constant. Example: @@ -133,7 +209,8 @@ def add(self, other: Expr | float) -> "Add": """ return Add(self, self._cast_to_expr_or_convert_to_constant(other)) - def subtract(self, other: Expr | float) -> "Subtract": + @expose_as_static + def subtract(self, other: Expr | float) -> "Expr": """Creates an expression that subtracts another expression or constant from this expression. Example: @@ -150,7 +227,8 @@ def subtract(self, other: Expr | float) -> "Subtract": """ return Subtract(self, self._cast_to_expr_or_convert_to_constant(other)) - def multiply(self, other: Expr | float) -> "Multiply": + @expose_as_static + def multiply(self, other: Expr | float) -> "Expr": """Creates an expression that multiplies this expression by another expression or constant. Example: @@ -167,7 +245,8 @@ def multiply(self, other: Expr | float) -> "Multiply": """ return Multiply(self, self._cast_to_expr_or_convert_to_constant(other)) - def divide(self, other: Expr | float) -> "Divide": + @expose_as_static + def divide(self, other: Expr | float) -> "Expr": """Creates an expression that divides this expression by another expression or constant. Example: @@ -184,7 +263,8 @@ def divide(self, other: Expr | float) -> "Divide": """ return Divide(self, self._cast_to_expr_or_convert_to_constant(other)) - def mod(self, other: Expr | float) -> "Mod": + @expose_as_static + def mod(self, other: Expr | float) -> "Expr": """Creates an expression that calculates the modulo (remainder) to another expression or constant. Example: @@ -201,7 +281,135 @@ def mod(self, other: Expr | float) -> "Mod": """ return Mod(self, self._cast_to_expr_or_convert_to_constant(other)) - def logical_max(self, other: Expr | CONSTANT_TYPE) -> "LogicalMax": + @expose_as_static + def abs(self) -> "Abs": + """Creates an expression that calculates the absolute value of this expression. + + Example: + >>> # Get the absolute value of the 'change' field. + >>> Field.of("change").abs() + + Returns: + A new `Expr` representing the absolute value. + """ + return Abs(self) + + @expose_as_static + def ceil(self) -> "Ceil": + """Creates an expression that calculates the ceiling of this expression. + + Example: + >>> # Get the ceiling of the 'value' field. + >>> Field.of("value").ceil() + + Returns: + A new `Expr` representing the ceiling value. + """ + return Ceil(self) + + @expose_as_static + def exp(self) -> "Exp": + """Creates an expression that computes e to the power of this expression. + + Example: + >>> # Compute e to the power of the 'value' field + >>> Field.of("value").exp() + + Returns: + A new `Expr` representing the exponential value. + """ + return Exp(self) + + @expose_as_static + def floor(self) -> "Floor": + """Creates an expression that calculates the floor of this expression. + + Example: + >>> # Get the floor of the 'value' field. + >>> Field.of("value").floor() + + Returns: + A new `Expr` representing the floor value. + """ + return Floor(self) + + @expose_as_static + def ln(self) -> "Ln": + """Creates an expression that calculates the natural logarithm of this expression. + + Example: + >>> # Get the natural logarithm of the 'value' field. + >>> Field.of("value").ln() + + Returns: + A new `Expr` representing the natural logarithm. + """ + return Ln(self) + + @expose_as_static + def log(self, base: Expr | float) -> "Log": + """Creates an expression that calculates the logarithm of this expression with a given base. + + Example: + >>> # Get the logarithm of 'value' with base 2. + >>> Field.of("value").log(2) + >>> # Get the logarithm of 'value' with base from 'base_field'. + >>> Field.of("value").log(Field.of("base_field")) + + Args: + base: The base of the logarithm. + + Returns: + A new `Expr` representing the logarithm. + """ + return Log(self, self._cast_to_expr_or_convert_to_constant(base)) + + @expose_as_static + def pow(self, exponent: Expr | float) -> "Pow": + """Creates an expression that calculates this expression raised to the power of the exponent. + + Example: + >>> # Raise 'base_val' to the power of 2. + >>> Field.of("base_val").pow(2) + >>> # Raise 'base_val' to the power of 'exponent_val'. + >>> Field.of("base_val").pow(Field.of("exponent_val")) + + Args: + exponent: The exponent. + + Returns: + A new `Expr` representing the power operation. + """ + return Pow(self, self._cast_to_expr_or_convert_to_constant(exponent)) + + @expose_as_static + def round(self) -> "Round": + """Creates an expression that rounds this expression to the nearest integer. + + Example: + >>> # Round the 'value' field. + >>> Field.of("value").round() + + Returns: + A new `Expr` representing the rounded value. + """ + return Round(self) + + @expose_as_static + def sqrt(self) -> "Sqrt": + """Creates an expression that calculates the square root of this expression. + + Example: + >>> # Get the square root of the 'area' field. + >>> Field.of("area").sqrt() + + Returns: + A new `Expr` representing the square root. + """ + return Sqrt(self) + + @expose_as_static + def logical_maximum(self, other: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that returns the larger value between this expression and another expression or constant, based on Firestore's value type ordering. @@ -210,19 +418,20 @@ def logical_max(self, other: Expr | CONSTANT_TYPE) -> "LogicalMax": Example: >>> # Returns the larger value between the 'discount' field and the 'cap' field. - >>> Field.of("discount").logical_max(Field.of("cap")) + >>> Field.of("discount").logical_maximum(Field.of("cap")) >>> # Returns the larger value between the 'value' field and 10. - >>> Field.of("value").logical_max(10) + >>> Field.of("value").logical_maximum(10) Args: other: The other expression or constant value to compare with. Returns: - A new `Expr` representing the logical max operation. + A new `Expr` representing the logical maximum operation. """ - return LogicalMax(self, self._cast_to_expr_or_convert_to_constant(other)) + return LogicalMaximum(self, self._cast_to_expr_or_convert_to_constant(other)) - def logical_min(self, other: Expr | CONSTANT_TYPE) -> "LogicalMin": + @expose_as_static + def logical_minimum(self, other: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that returns the smaller value between this expression and another expression or constant, based on Firestore's value type ordering. @@ -231,27 +440,28 @@ def logical_min(self, other: Expr | CONSTANT_TYPE) -> "LogicalMin": Example: >>> # Returns the smaller value between the 'discount' field and the 'floor' field. - >>> Field.of("discount").logical_min(Field.of("floor")) + >>> Field.of("discount").logical_minimum(Field.of("floor")) >>> # Returns the smaller value between the 'value' field and 10. - >>> Field.of("value").logical_min(10) + >>> Field.of("value").logical_minimum(10) Args: other: The other expression or constant value to compare with. Returns: - A new `Expr` representing the logical min operation. + A new `Expr` representing the logical minimum operation. """ - return LogicalMin(self, self._cast_to_expr_or_convert_to_constant(other)) + return LogicalMinimum(self, self._cast_to_expr_or_convert_to_constant(other)) - def eq(self, other: Expr | CONSTANT_TYPE) -> "Eq": + @expose_as_static + def equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to another expression or constant value. Example: >>> # Check if the 'age' field is equal to 21 - >>> Field.of("age").eq(21) + >>> Field.of("age").equal(21) >>> # Check if the 'city' field is equal to "London" - >>> Field.of("city").eq("London") + >>> Field.of("city").equal("London") Args: other: The expression or constant value to compare for equality. @@ -259,17 +469,18 @@ def eq(self, other: Expr | CONSTANT_TYPE) -> "Eq": Returns: A new `Expr` representing the equality comparison. """ - return Eq(self, self._cast_to_expr_or_convert_to_constant(other)) + return Equal(self, self._cast_to_expr_or_convert_to_constant(other)) - def neq(self, other: Expr | CONSTANT_TYPE) -> "Neq": + @expose_as_static + def not_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is not equal to another expression or constant value. Example: >>> # Check if the 'status' field is not equal to "completed" - >>> Field.of("status").neq("completed") + >>> Field.of("status").not_equal("completed") >>> # Check if the 'country' field is not equal to "USA" - >>> Field.of("country").neq("USA") + >>> Field.of("country").not_equal("USA") Args: other: The expression or constant value to compare for inequality. @@ -277,17 +488,18 @@ def neq(self, other: Expr | CONSTANT_TYPE) -> "Neq": Returns: A new `Expr` representing the inequality comparison. """ - return Neq(self, self._cast_to_expr_or_convert_to_constant(other)) + return NotEqual(self, self._cast_to_expr_or_convert_to_constant(other)) - def gt(self, other: Expr | CONSTANT_TYPE) -> "Gt": + @expose_as_static + def greater_than(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is greater than another expression or constant value. Example: >>> # Check if the 'age' field is greater than the 'limit' field - >>> Field.of("age").gt(Field.of("limit")) + >>> Field.of("age").greater_than(Field.of("limit")) >>> # Check if the 'price' field is greater than 100 - >>> Field.of("price").gt(100) + >>> Field.of("price").greater_than(100) Args: other: The expression or constant value to compare for greater than. @@ -295,17 +507,18 @@ def gt(self, other: Expr | CONSTANT_TYPE) -> "Gt": Returns: A new `Expr` representing the greater than comparison. """ - return Gt(self, self._cast_to_expr_or_convert_to_constant(other)) + return GreaterThan(self, self._cast_to_expr_or_convert_to_constant(other)) - def gte(self, other: Expr | CONSTANT_TYPE) -> "Gte": + @expose_as_static + def greater_than_or_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is greater than or equal to another expression or constant value. Example: >>> # Check if the 'quantity' field is greater than or equal to field 'requirement' plus 1 - >>> Field.of("quantity").gte(Field.of('requirement').add(1)) + >>> Field.of("quantity").greater_than_or_equal(Field.of('requirement').add(1)) >>> # Check if the 'score' field is greater than or equal to 80 - >>> Field.of("score").gte(80) + >>> Field.of("score").greater_than_or_equal(80) Args: other: The expression or constant value to compare for greater than or equal to. @@ -313,17 +526,20 @@ def gte(self, other: Expr | CONSTANT_TYPE) -> "Gte": Returns: A new `Expr` representing the greater than or equal to comparison. """ - return Gte(self, self._cast_to_expr_or_convert_to_constant(other)) + return GreaterThanOrEqual( + self, self._cast_to_expr_or_convert_to_constant(other) + ) - def lt(self, other: Expr | CONSTANT_TYPE) -> "Lt": + @expose_as_static + def less_than(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is less than another expression or constant value. Example: >>> # Check if the 'age' field is less than 'limit' - >>> Field.of("age").lt(Field.of('limit')) + >>> Field.of("age").less_than(Field.of('limit')) >>> # Check if the 'price' field is less than 50 - >>> Field.of("price").lt(50) + >>> Field.of("price").less_than(50) Args: other: The expression or constant value to compare for less than. @@ -331,17 +547,18 @@ def lt(self, other: Expr | CONSTANT_TYPE) -> "Lt": Returns: A new `Expr` representing the less than comparison. """ - return Lt(self, self._cast_to_expr_or_convert_to_constant(other)) + return LessThan(self, self._cast_to_expr_or_convert_to_constant(other)) - def lte(self, other: Expr | CONSTANT_TYPE) -> "Lte": + @expose_as_static + def less_than_or_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is less than or equal to another expression or constant value. Example: >>> # Check if the 'quantity' field is less than or equal to 20 - >>> Field.of("quantity").lte(Constant.of(20)) + >>> Field.of("quantity").less_than_or_equal(Constant.of(20)) >>> # Check if the 'score' field is less than or equal to 70 - >>> Field.of("score").lte(70) + >>> Field.of("score").less_than_or_equal(70) Args: other: The expression or constant value to compare for less than or equal to. @@ -349,15 +566,16 @@ def lte(self, other: Expr | CONSTANT_TYPE) -> "Lte": Returns: A new `Expr` representing the less than or equal to comparison. """ - return Lte(self, self._cast_to_expr_or_convert_to_constant(other)) + return LessThanOrEqual(self, self._cast_to_expr_or_convert_to_constant(other)) - def in_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "In": + @expose_as_static + def equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to any of the provided values or expressions. Example: >>> # Check if the 'category' field is either "Electronics" or value of field 'primaryType' - >>> Field.of("category").in_any(["Electronics", Field.of("primaryType")]) + >>> Field.of("category").equal_any(["Electronics", Field.of("primaryType")]) Args: array: The values or expressions to check against. @@ -365,15 +583,16 @@ def in_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "In": Returns: A new `Expr` representing the 'IN' comparison. """ - return In(self, [self._cast_to_expr_or_convert_to_constant(v) for v in array]) + return EqualAny(self, [self._cast_to_expr_or_convert_to_constant(v) for v in array]) - def not_in_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "Not": + @expose_as_static + def not_equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": """Creates an expression that checks if this expression is not equal to any of the provided values or expressions. Example: >>> # Check if the 'status' field is neither "pending" nor "cancelled" - >>> Field.of("status").not_in_any(["pending", "cancelled"]) + >>> Field.of("status").not_equal_any(["pending", "cancelled"]) Args: array: The values or expressions to check against. @@ -381,9 +600,46 @@ def not_in_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "Not": Returns: A new `Expr` representing the 'NOT IN' comparison. """ - return Not(self.in_any(array)) + return NotEqualAny(self, [self._cast_to_expr_or_convert_to_constant(v) for v in array]) + + @expose_as_static + def array_get(self, index: Expr | int) -> "Expr": + """Creates an expression that indexes into an array from the beginning or end + and returns the element. If the index exceeds the array length, an error is + returned. A negative index, starts from the end. + + Example: + >>> # Return the value in the tags field array at index specified by field 'favoriteTag'. + >>> Field.of("tags").array_get(Field.of("favoriteTag")) + + Args: + index: The index of the element to return. + + Returns: + A new `Expr` representing the operation. + """ + return ArrayGet(self, self._cast_to_expr_or_convert_to_constant(index)) + + @expose_as_static + def array_concat(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "Expr": + """Creates an expression that concatenates an array expression with another array. + + Example: + >>> # Combine the 'tags' array with a new array and an array field + >>> Field.of("tags").array_concat(["newTag1", "newTag2", Field.of("otherTag")]) + + Args: + array: The list of constants or expressions to concat with. + + Returns: + A new `Expr` representing the concatenated array. + """ + return ArrayConcat( + self, [self._cast_to_expr_or_convert_to_constant(o) for o in array] + ) - def array_contains(self, element: Expr | CONSTANT_TYPE) -> "ArrayContains": + @expose_as_static + def array_contains(self, element: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if an array contains a specific element or value. Example: @@ -400,9 +656,10 @@ def array_contains(self, element: Expr | CONSTANT_TYPE) -> "ArrayContains": """ return ArrayContains(self, self._cast_to_expr_or_convert_to_constant(element)) + @expose_as_static def array_contains_all( self, elements: Sequence[Expr | CONSTANT_TYPE] - ) -> "ArrayContainsAll": + ) -> "BooleanExpr": """Creates an expression that checks if an array contains all the specified elements. Example: @@ -421,9 +678,10 @@ def array_contains_all( self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements] ) + @expose_as_static def array_contains_any( self, elements: Sequence[Expr | CONSTANT_TYPE] - ) -> "ArrayContainsAny": + ) -> "BooleanExpr": """Creates an expression that checks if an array contains any of the specified elements. Example: @@ -443,7 +701,8 @@ def array_contains_any( self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements] ) - def array_length(self) -> "ArrayLength": + @expose_as_static + def array_length(self) -> "Expr": """Creates an expression that calculates the length of an array. Example: @@ -455,7 +714,8 @@ def array_length(self) -> "ArrayLength": """ return ArrayLength(self) - def array_reverse(self) -> "ArrayReverse": + @expose_as_static + def array_reverse(self) -> "Expr": """Creates an expression that returns the reversed content of an array. Example: @@ -467,7 +727,8 @@ def array_reverse(self) -> "ArrayReverse": """ return ArrayReverse(self) - def is_nan(self) -> "IsNaN": + @expose_as_static + def is_nan(self) -> "BooleanExpr": """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number). Example: @@ -479,7 +740,8 @@ def is_nan(self) -> "IsNaN": """ return IsNaN(self) - def exists(self) -> "Exists": + @expose_as_static + def exists(self) -> "BooleanExpr": """Creates an expression that checks if a field exists in the document. Example: @@ -491,7 +753,8 @@ def exists(self) -> "Exists": """ return Exists(self) - def sum(self) -> "Sum": + @expose_as_static + def sum(self) -> "Expr": """Creates an aggregation that calculates the sum of a numeric field across multiple stage inputs. Example: @@ -499,24 +762,26 @@ def sum(self) -> "Sum": >>> Field.of("orderAmount").sum().as_("totalRevenue") Returns: - A new `Accumulator` representing the 'sum' aggregation. + A new `AggregateFunction` representing the 'sum' aggregation. """ return Sum(self) - def avg(self) -> "Avg": + @expose_as_static + def average(self) -> "Expr": """Creates an aggregation that calculates the average (mean) of a numeric field across multiple stage inputs. Example: >>> # Calculate the average age of users - >>> Field.of("age").avg().as_("averageAge") + >>> Field.of("age").average().as_("averageAge") Returns: - A new `Accumulator` representing the 'avg' aggregation. + A new `AggregateFunction` representing the 'avg' aggregation. """ - return Avg(self) + return Average(self) - def count(self) -> "Count": + + def count(self) -> "Expr": """Creates an aggregation that counts the number of stage inputs with valid evaluations of the expression or field. @@ -525,35 +790,38 @@ def count(self) -> "Count": >>> Field.of("productId").count().as_("totalProducts") Returns: - A new `Accumulator` representing the 'count' aggregation. + A new `AggregateFunction` representing the 'count' aggregation. """ return Count(self) - def min(self) -> "Min": + @expose_as_static + def minimum(self) -> "Expr": """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. Example: >>> # Find the lowest price of all products - >>> Field.of("price").min().as_("lowestPrice") + >>> Field.of("price").minimum().as_("lowestPrice") Returns: - A new `Accumulator` representing the 'min' aggregation. + A new `AggregateFunction` representing the 'minimum' aggregation. """ - return Min(self) + return Minimum(self) - def max(self) -> "Max": + @expose_as_static + def maximum(self) -> "Expr": """Creates an aggregation that finds the maximum value of a field across multiple stage inputs. Example: >>> # Find the highest score in a leaderboard - >>> Field.of("score").max().as_("highestScore") + >>> Field.of("score").maximum().as_("highestScore") Returns: - A new `Accumulator` representing the 'max' aggregation. + A new `AggregateFunction` representing the 'maximum' aggregation. """ - return Max(self) + return Maximum(self) - def char_length(self) -> "CharLength": + @expose_as_static + def char_length(self) -> "Expr": """Creates an expression that calculates the character length of a string. Example: @@ -565,7 +833,8 @@ def char_length(self) -> "CharLength": """ return CharLength(self) - def byte_length(self) -> "ByteLength": + @expose_as_static + def byte_length(self) -> "Expr": """Creates an expression that calculates the byte length of a string in its UTF-8 form. Example: @@ -577,7 +846,8 @@ def byte_length(self) -> "ByteLength": """ return ByteLength(self) - def like(self, pattern: Expr | str) -> "Like": + @expose_as_static + def like(self, pattern: Expr | str) -> "BooleanExpr": """Creates an expression that performs a case-sensitive string comparison. Example: @@ -594,7 +864,8 @@ def like(self, pattern: Expr | str) -> "Like": """ return Like(self, self._cast_to_expr_or_convert_to_constant(pattern)) - def regex_contains(self, regex: Expr | str) -> "RegexContains": + @expose_as_static + def regex_contains(self, regex: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string contains a specified regular expression as a substring. @@ -612,7 +883,8 @@ def regex_contains(self, regex: Expr | str) -> "RegexContains": """ return RegexContains(self, self._cast_to_expr_or_convert_to_constant(regex)) - def regex_matches(self, regex: Expr | str) -> "RegexMatch": + @expose_as_static + def regex_matches(self, regex: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string matches a specified regular expression. Example: @@ -629,14 +901,15 @@ def regex_matches(self, regex: Expr | str) -> "RegexMatch": """ return RegexMatch(self, self._cast_to_expr_or_convert_to_constant(regex)) - def str_contains(self, substring: Expr | str) -> "StrContains": + @expose_as_static + def string_contains(self, substring: Expr | str) -> "BooleanExpr": """Creates an expression that checks if this string expression contains a specified substring. Example: >>> # Check if the 'description' field contains "example". - >>> Field.of("description").str_contains("example") + >>> Field.of("description").string_contains("example") >>> # Check if the 'description' field contains the value of the 'keyword' field. - >>> Field.of("description").str_contains(Field.of("keyword")) + >>> Field.of("description").string_contains(Field.of("keyword")) Args: substring: The substring (string or expression) to use for the search. @@ -644,9 +917,12 @@ def str_contains(self, substring: Expr | str) -> "StrContains": Returns: A new `Expr` representing the 'contains' comparison. """ - return StrContains(self, self._cast_to_expr_or_convert_to_constant(substring)) + return StringContains( + self, self._cast_to_expr_or_convert_to_constant(substring) + ) - def starts_with(self, prefix: Expr | str) -> "StartsWith": + @expose_as_static + def starts_with(self, prefix: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string starts with a given prefix. Example: @@ -663,7 +939,8 @@ def starts_with(self, prefix: Expr | str) -> "StartsWith": """ return StartsWith(self, self._cast_to_expr_or_convert_to_constant(prefix)) - def ends_with(self, postfix: Expr | str) -> "EndsWith": + @expose_as_static + def ends_with(self, postfix: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string ends with a given postfix. Example: @@ -680,12 +957,13 @@ def ends_with(self, postfix: Expr | str) -> "EndsWith": """ return EndsWith(self, self._cast_to_expr_or_convert_to_constant(postfix)) - def str_concat(self, *elements: Expr | CONSTANT_TYPE) -> "StrConcat": + @expose_as_static + def string_concat(self, *elements: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that concatenates string expressions, fields or constants together. Example: >>> # Combine the 'firstName', " ", and 'lastName' fields into a single string - >>> Field.of("firstName").str_concat(" ", Field.of("lastName")) + >>> Field.of("firstName").string_concat(" ", Field.of("lastName")) Args: *elements: The expressions or constants (typically strings) to concatenate. @@ -693,16 +971,68 @@ def str_concat(self, *elements: Expr | CONSTANT_TYPE) -> "StrConcat": Returns: A new `Expr` representing the concatenated string. """ - return StrConcat( + return StringConcat( self, *[self._cast_to_expr_or_convert_to_constant(el) for el in elements] ) - def map_get(self, key: str) -> "MapGet": - """Accesses a value from a map (object) field using the provided key. + @expose_as_static + def to_lower(self) -> "Expr": + """Creates an expression that converts a string to lowercase. + + Example: + >>> # Convert the 'name' field to lowercase + >>> Field.of("name").to_lower() + + Returns: + A new `Expr` representing the lowercase string. + """ + return ToLower(self) + + @expose_as_static + def to_upper(self) -> "Expr": + """Creates an expression that converts a string to uppercase. + + Example: + >>> # Convert the 'title' field to uppercase + >>> Field.of("title").to_upper() + + Returns: + A new `Expr` representing the uppercase string. + """ + return ToUpper(self) + + @expose_as_static + def trim(self) -> "Expr": + """Creates an expression that removes leading and trailing whitespace from a string. Example: - >>> # Get the 'city' value from - >>> # the 'address' map field + >>> # Trim whitespace from the 'userInput' field + >>> Field.of("userInput").trim() + + Returns: + A new `Expr` representing the trimmed string. + """ + return Trim(self) + + @expose_as_static + def reverse(self) -> "Expr": + """Creates an expression that reverses a string. + + Example: + >>> # Reverse the 'userInput' field + >>> Field.of("userInput").reverse() + + Returns: + A new `Expr` representing the reversed string. + """ + return Reverse(self) + + @expose_as_static + def map_get(self, key: str) -> "Expr": + """Accesses a value from the map produced by evaluating this expression. + + Example: + >>> Expr.map({"city": "London"}).map_get("city") >>> Field.of("address").map_get("city") Args: @@ -713,7 +1043,98 @@ def map_get(self, key: str) -> "MapGet": """ return MapGet(self, Constant.of(key)) - def vector_length(self) -> "VectorLength": + @expose_as_static + def map_remove(self, key: str) -> "Expr": + """Remove a key from a the map produced by evaluating this expression. + + Example: + >>> Expr.map({"city": "London"}).map_remove("city") + >>> Field.of("address").map_remove("city") + + Args: + key: The key to ewmove in the map. + + Returns: + A new `Expr` representing the map_remove operation. + """ + return MapRemove(self, Constant.of(key)) + + @expose_as_static + def map_merge(self, *other_maps: Expr | dict[str, Expr | CONSTANT_TYPE])-> "Expr": + """Creates an expression that merges one or more dicts into a single map. + + Example: + >>> Field.of("settings").map_merge({"enabled":True}, Function.conditional(Field.of('isAdmin'), {"admin":True}, {}}) + >>> Expr.map({"city": "London"}).map_merge({"country": "UK"}, {"isCapital": True}) + + Args: + *other_maps: Sequence of maps to merge into the resulting map. + + Returns: + A new `Expr` representing the value associated with the given key in the map. + """ + map_list = [] + for map in other_maps: + map_list.append(map if isinstance(map, Expr) else Expr.map(map)) + return MapMerge(self, *map_list) + + + @expose_as_static + def cosine_distance(self, other: Expr | list[float] | Vector) -> "Expr": + """Calculates the cosine distance between two vectors. + + Example: + >>> # Calculate the cosine distance between the 'userVector' field and the 'itemVector' field + >>> Field.of("userVector").cosine_distance(Field.of("itemVector")) + >>> # Calculate the Cosine distance between the 'location' field and a target location + >>> Field.of("location").cosine_distance([37.7749, -122.4194]) + + Args: + other: The other vector (represented as an Expr, list of floats, or Vector) to compare against. + + Returns: + A new `Expr` representing the cosine distance between the two vectors. + """ + return CosineDistance(self, self._cast_to_expr_or_convert_to_constant(other)) + + @expose_as_static + def euclidean_distance(self, other: Expr | list[float] | Vector) -> "Expr": + """Calculates the Euclidean distance between two vectors. + + Example: + >>> # Calculate the Euclidean distance between the 'location' field and a target location + >>> Field.of("location").euclidean_distance([37.7749, -122.4194]) + >>> # Calculate the Euclidean distance between two vector fields: 'pointA' and 'pointB' + >>> Field.of("pointA").euclidean_distance(Field.of("pointB")) + + Args: + other: The other vector (represented as an Expr, list of floats, or Vector) to compare against. + + Returns: + A new `Expr` representing the Euclidean distance between the two vectors. + """ + return EuclideanDistance(self, self._cast_to_expr_or_convert_to_constant(other)) + + @expose_as_static + def dot_product(self, other: Expr | list[float] | Vector) -> "Expr": + """Calculates the dot product between two vectors. + + Example: + >>> # Calculate the dot product between a feature vector and a target vector + >>> Field.of("features").dot_product([0.5, 0.8, 0.2]) + >>> # Calculate the dot product between two document vectors: 'docVector1' and 'docVector2' + >>> Field.of("docVector1").dot_product(Field.of("docVector2")) + + Args: + other: The other vector (represented as an Expr, list of floats, or Vector) to calculate dot product with. + + Returns: + A new `Expr` representing the dot product between the two vectors. + """ + return DotProduct(self, self._cast_to_expr_or_convert_to_constant(other)) + + @expose_as_static + def vector_length(self) -> "Expr": """Creates an expression that calculates the length (dimension) of a Firestore Vector. Example: @@ -725,7 +1146,8 @@ def vector_length(self) -> "VectorLength": """ return VectorLength(self) - def timestamp_to_unix_micros(self) -> "TimestampToUnixMicros": + @expose_as_static + def timestamp_to_unix_micros(self) -> "Expr": """Creates an expression that converts a timestamp to the number of microseconds since the epoch (1970-01-01 00:00:00 UTC). @@ -740,7 +1162,8 @@ def timestamp_to_unix_micros(self) -> "TimestampToUnixMicros": """ return TimestampToUnixMicros(self) - def unix_micros_to_timestamp(self) -> "UnixMicrosToTimestamp": + @expose_as_static + def unix_micros_to_timestamp(self) -> "Expr": """Creates an expression that converts a number of microseconds since the epoch (1970-01-01 00:00:00 UTC) to a timestamp. @@ -753,7 +1176,8 @@ def unix_micros_to_timestamp(self) -> "UnixMicrosToTimestamp": """ return UnixMicrosToTimestamp(self) - def timestamp_to_unix_millis(self) -> "TimestampToUnixMillis": + @expose_as_static + def timestamp_to_unix_millis(self) -> "Expr": """Creates an expression that converts a timestamp to the number of milliseconds since the epoch (1970-01-01 00:00:00 UTC). @@ -768,7 +1192,8 @@ def timestamp_to_unix_millis(self) -> "TimestampToUnixMillis": """ return TimestampToUnixMillis(self) - def unix_millis_to_timestamp(self) -> "UnixMillisToTimestamp": + @expose_as_static + def unix_millis_to_timestamp(self) -> "Expr": """Creates an expression that converts a number of milliseconds since the epoch (1970-01-01 00:00:00 UTC) to a timestamp. @@ -781,7 +1206,8 @@ def unix_millis_to_timestamp(self) -> "UnixMillisToTimestamp": """ return UnixMillisToTimestamp(self) - def timestamp_to_unix_seconds(self) -> "TimestampToUnixSeconds": + @expose_as_static + def timestamp_to_unix_seconds(self) -> "Expr": """Creates an expression that converts a timestamp to the number of seconds since the epoch (1970-01-01 00:00:00 UTC). @@ -796,7 +1222,8 @@ def timestamp_to_unix_seconds(self) -> "TimestampToUnixSeconds": """ return TimestampToUnixSeconds(self) - def unix_seconds_to_timestamp(self) -> "UnixSecondsToTimestamp": + @expose_as_static + def unix_seconds_to_timestamp(self) -> "Expr": """Creates an expression that converts a number of seconds since the epoch (1970-01-01 00:00:00 UTC) to a timestamp. @@ -809,7 +1236,8 @@ def unix_seconds_to_timestamp(self) -> "UnixSecondsToTimestamp": """ return UnixSecondsToTimestamp(self) - def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "TimestampAdd": + @expose_as_static + def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "Expr": """Creates an expression that adds a specified amount of time to this timestamp expression. Example: @@ -832,14 +1260,15 @@ def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "TimestampAdd self._cast_to_expr_or_convert_to_constant(amount), ) - def timestamp_sub(self, unit: Expr | str, amount: Expr | float) -> "TimestampSub": + @expose_as_static + def timestamp_subtract(self, unit: Expr | str, amount: Expr | float) -> "Expr": """Creates an expression that subtracts a specified amount of time from this timestamp expression. Example: >>> # Subtract a duration specified by the 'unit' and 'amount' fields from the 'timestamp' field. - >>> Field.of("timestamp").timestamp_sub(Field.of("unit"), Field.of("amount")) + >>> Field.of("timestamp").timestamp_subtract(Field.of("unit"), Field.of("amount")) >>> # Subtract 2.5 hours from the 'timestamp' field. - >>> Field.of("timestamp").timestamp_sub("hour", 2.5) + >>> Field.of("timestamp").timestamp_subtract("hour", 2.5) Args: unit: The expression or string evaluating to the unit of time to subtract, must be one of @@ -849,7 +1278,7 @@ def timestamp_sub(self, unit: Expr | str, amount: Expr | float) -> "TimestampSub Returns: A new `Expr` representing the resulting timestamp. """ - return TimestampSub( + return TimestampSubtract( self, self._cast_to_expr_or_convert_to_constant(unit), self._cast_to_expr_or_convert_to_constant(amount), @@ -860,7 +1289,7 @@ def ascending(self) -> Ordering: Example: >>> # Sort documents by the 'name' field in ascending order - >>> firestore.pipeline().collection("users").sort(Field.of("name").ascending()) + >>> client.pipeline().collection("users").sort(Field.of("name").ascending()) Returns: A new `Ordering` for ascending sorting. @@ -872,14 +1301,14 @@ def descending(self) -> Ordering: Example: >>> # Sort documents by the 'createdAt' field in descending order - >>> firestore.pipeline().collection("users").sort(Field.of("createdAt").descending()) + >>> client.pipeline().collection("users").sort(Field.of("createdAt").descending()) Returns: A new `Ordering` for descending sorting. """ return Ordering(self, Ordering.Direction.DESCENDING) - def as_(self, alias: str) -> "ExprWithAlias": + def as_(self, alias: str) -> "AliasedExpr": """Assigns an alias to this expression. Aliases are useful for renaming fields in the output of a stage or for giving meaningful @@ -887,7 +1316,7 @@ def as_(self, alias: str) -> "ExprWithAlias": Example: >>> # Calculate the total price and assign it the alias "totalPrice" and add it to the output. - >>> firestore.pipeline().collection("items").add_fields( + >>> client.pipeline().collection("items").add_fields( ... Field.of("price").multiply(Field.of("quantity")).as_("totalPrice") ... ) @@ -895,10 +1324,10 @@ def as_(self, alias: str) -> "ExprWithAlias": alias: The alias to assign to this expression. Returns: - A new `Selectable` (typically an `ExprWithAlias`) that wraps this + A new `Selectable` (typically an `AliasedExpr`) that wraps this expression and associates it with the provided alias. """ - return ExprWithAlias(self, alias) + return AliasedExpr(self, alias) class Constant(Expr, Generic[CONSTANT_TYPE]): @@ -921,6 +1350,9 @@ def of(value: CONSTANT_TYPE) -> Constant[CONSTANT_TYPE]: def __repr__(self): return f"Constant.of({self.value!r})" + def __hash__(self): + return hash(self.value) + def _to_pb(self) -> Value: return encode_value(self.value) @@ -968,846 +1400,202 @@ def _to_pb(self): } ) - def add(left: Expr | str, right: Expr | float) -> "Add": - """Creates an expression that adds two expressions together. - Example: - >>> Function.add("rating", 5) - >>> Function.add(Field.of("quantity"), Field.of("reserve")) +class Divide(Function): + """Represents the division function.""" - Args: - left: The first expression or field path to add. - right: The second expression or constant value to add. + def __init__(self, left: Expr, right: Expr): + super().__init__("divide", [left, right]) - Returns: - A new `Expr` representing the addition operation. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.add(left_expr, right) - def subtract(left: Expr | str, right: Expr | float) -> "Subtract": - """Creates an expression that subtracts another expression or constant from this expression. +class DotProduct(Function): + """Represents the vector dot product function.""" - Example: - >>> Function.subtract("total", 20) - >>> Function.subtract(Field.of("price"), Field.of("discount")) + def __init__(self, vector1: Expr, vector2: Expr): + super().__init__("dot_product", [vector1, vector2]) - Args: - left: The expression or field path to subtract from. - right: The expression or constant value to subtract. - Returns: - A new `Expr` representing the subtraction operation. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.subtract(left_expr, right) +class EuclideanDistance(Function): + """Represents the vector Euclidean distance function.""" - def multiply(left: Expr | str, right: Expr | float) -> "Multiply": - """Creates an expression that multiplies this expression by another expression or constant. + def __init__(self, vector1: Expr, vector2: Expr): + super().__init__("euclidean_distance", [vector1, vector2]) - Example: - >>> Function.multiply("value", 2) - >>> Function.multiply(Field.of("quantity"), Field.of("price")) - Args: - left: The expression or field path to multiply. - right: The expression or constant value to multiply by. +class LogicalMaximum(Function): + """ + Returns the larger value between this expression and another expression or constant, + based on Firestore's value type ordering. + """ - Returns: - A new `Expr` representing the multiplication operation. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.multiply(left_expr, right) + def __init__(self, left: Expr, right: Expr): + super().__init__("max", [left, right]) - def divide(left: Expr | str, right: Expr | float) -> "Divide": - """Creates an expression that divides this expression by another expression or constant. - Example: - >>> Function.divide("value", 10) - >>> Function.divide(Field.of("total"), Field.of("count")) +class LogicalMinimum(Function): + """ + Returns the smaller value between this expression and another expression or constant, + based on Firestore's value type ordering. + """ - Args: - left: The expression or field path to be divided. - right: The expression or constant value to divide by. + def __init__(self, left: Expr, right: Expr): + super().__init__("min", [left, right]) - Returns: - A new `Expr` representing the division operation. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.divide(left_expr, right) - def mod(left: Expr | str, right: Expr | float) -> "Mod": - """Creates an expression that calculates the modulo (remainder) to another expression or constant. +class Map(Function): + """Creates an expression that creates a Firestore map value from an input dict.""" - Example: - >>> Function.mod("value", 5) - >>> Function.mod(Field.of("value"), Field.of("divisor")) + def __init__(self, elements: dict[Constant[str], Expr]): + element_list = [] + for k,v in elements.items(): + element_list.append(k) + element_list.append(v) + super().__init__("map", element_list) - Args: - left: The dividend expression or field path. - right: The divisor expression or constant. + def __repr__(self): + d = {a:b for a, b in zip(self.params[::2], self.params[1::2])} + return f"Map({d})" - Returns: - A new `Expr` representing the modulo operation. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.mod(left_expr, right) - def logical_max(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "LogicalMax": - """Creates an expression that returns the larger value between this expression - and another expression or constant, based on Firestore's value type ordering. +class MapGet(Function): + """Creates an expression that accesses a map value by key.""" - Firestore's value type ordering is described here: - https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering + def __init__(self, map_: Expr, key: Constant[str]): + super().__init__("map_get", [map_, key]) - Example: - >>> Function.logical_max("value", 10) - >>> Function.logical_max(Field.of("discount"), Field.of("cap")) - Args: - left: The expression or field path to compare. - right: The other expression or constant value to compare with. +class MapMerge(Function): + """Creates an expression that merges multiple map values.""" - Returns: - A new `Expr` representing the logical max operation. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.logical_max(left_expr, right) + def __init__(self, *maps: Expr): + super().__init__("map_merge", [*maps]) - def logical_min(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "LogicalMin": - """Creates an expression that returns the smaller value between this expression - and another expression or constant, based on Firestore's value type ordering. - Firestore's value type ordering is described here: - https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering +class MapRemove(Function): + """Creates an expression that removes a key from a map.""" - Example: - >>> Function.logical_min("value", 10) - >>> Function.logical_min(Field.of("discount"), Field.of("floor")) + def __init__(self, map_: Expr, key: Constant[str]): + super().__init__("map_remove", [map_, key]) - Args: - left: The expression or field path to compare. - right: The other expression or constant value to compare with. - Returns: - A new `Expr` representing the logical min operation. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.logical_min(left_expr, right) - def eq(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Eq": - """Creates an expression that checks if this expression is equal to another - expression or constant value. +class Mod(Function): + """Represents the modulo function.""" - Example: - >>> Function.eq("city", "London") - >>> Function.eq(Field.of("age"), 21) + def __init__(self, left: Expr, right: Expr): + super().__init__("mod", [left, right]) - Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for equality. - Returns: - A new `Expr` representing the equality comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.eq(left_expr, right) +class Multiply(Function): + """Represents the multiplication function.""" - def neq(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Neq": - """Creates an expression that checks if this expression is not equal to another - expression or constant value. - - Example: - >>> Function.neq("country", "USA") - >>> Function.neq(Field.of("status"), "completed") - - Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for inequality. - - Returns: - A new `Expr` representing the inequality comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.neq(left_expr, right) - - def gt(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Gt": - """Creates an expression that checks if this expression is greater than another - expression or constant value. - - Example: - >>> Function.gt("price", 100) - >>> Function.gt(Field.of("age"), Field.of("limit")) - - Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for greater than. - - Returns: - A new `Expr` representing the greater than comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.gt(left_expr, right) - - def gte(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Gte": - """Creates an expression that checks if this expression is greater than or equal - to another expression or constant value. - - Example: - >>> Function.gte("score", 80) - >>> Function.gte(Field.of("quantity"), Field.of('requirement').add(1)) - - Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for greater than or equal to. - - Returns: - A new `Expr` representing the greater than or equal to comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.gte(left_expr, right) - - def lt(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Lt": - """Creates an expression that checks if this expression is less than another - expression or constant value. - - Example: - >>> Function.lt("price", 50) - >>> Function.lt(Field.of("age"), Field.of('limit')) - - Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for less than. - - Returns: - A new `Expr` representing the less than comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.lt(left_expr, right) - - def lte(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Lte": - """Creates an expression that checks if this expression is less than or equal to - another expression or constant value. - - Example: - >>> Function.lte("score", 70) - >>> Function.lte(Field.of("quantity"), Constant.of(20)) - - Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for less than or equal to. - - Returns: - A new `Expr` representing the less than or equal to comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.lte(left_expr, right) - - def in_any(left: Expr | str, array: Sequence[Expr | CONSTANT_TYPE]) -> "In": - """Creates an expression that checks if this expression is equal to any of the - provided values or expressions. - - Example: - >>> Function.in_any("category", ["Electronics", "Apparel"]) - >>> Function.in_any(Field.of("category"), ["Electronics", Field.of("primaryType")]) - - Args: - left: The expression or field path to compare. - array: The values or expressions to check against. - - Returns: - A new `Expr` representing the 'IN' comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.in_any(left_expr, array) - - def not_in_any(left: Expr | str, array: Sequence[Expr | CONSTANT_TYPE]) -> "Not": - """Creates an expression that checks if this expression is not equal to any of the - provided values or expressions. - - Example: - >>> Function.not_in_any("status", ["pending", "cancelled"]) - - Args: - left: The expression or field path to compare. - array: The values or expressions to check against. - - Returns: - A new `Expr` representing the 'NOT IN' comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.not_in_any(left_expr, array) - - def array_contains( - array: Expr | str, element: Expr | CONSTANT_TYPE - ) -> "ArrayContains": - """Creates an expression that checks if an array contains a specific element or value. - - Example: - >>> Function.array_contains("colors", "red") - >>> Function.array_contains(Field.of("sizes"), Field.of("selectedSize")) - - Args: - array: The array expression or field path to check. - element: The element (expression or constant) to search for in the array. - - Returns: - A new `Expr` representing the 'array_contains' comparison. - """ - array_expr = Field.of(array) if isinstance(array, str) else array - return Expr.array_contains(array_expr, element) - - def array_contains_all( - array: Expr | str, elements: Sequence[Expr | CONSTANT_TYPE] - ) -> "ArrayContainsAll": - """Creates an expression that checks if an array contains all the specified elements. - - Example: - >>> Function.array_contains_all("tags", ["news", "sports"]) - >>> Function.array_contains_all(Field.of("tags"), [Field.of("tag1"), "tag2"]) - - Args: - array: The array expression or field path to check. - elements: The list of elements (expressions or constants) to check for in the array. - - Returns: - A new `Expr` representing the 'array_contains_all' comparison. - """ - array_expr = Field.of(array) if isinstance(array, str) else array - return Expr.array_contains_all(array_expr, elements) - - def array_contains_any( - array: Expr | str, elements: Sequence[Expr | CONSTANT_TYPE] - ) -> "ArrayContainsAny": - """Creates an expression that checks if an array contains any of the specified elements. - - Example: - >>> Function.array_contains_any("groups", ["admin", "editor"]) - >>> Function.array_contains_any(Field.of("categories"), [Field.of("cate1"), Field.of("cate2")]) - - Args: - array: The array expression or field path to check. - elements: The list of elements (expressions or constants) to check for in the array. - - Returns: - A new `Expr` representing the 'array_contains_any' comparison. - """ - array_expr = Field.of(array) if isinstance(array, str) else array - return Expr.array_contains_any(array_expr, elements) - - def array_length(array: Expr | str) -> "ArrayLength": - """Creates an expression that calculates the length of an array. - - Example: - >>> Function.array_length("cart") - - Returns: - A new `Expr` representing the length of the array. - """ - array_expr = Field.of(array) if isinstance(array, str) else array - return Expr.array_length(array_expr) - - def array_reverse(array: Expr | str) -> "ArrayReverse": - """Creates an expression that returns the reversed content of an array. - - Example: - >>> Function.array_reverse("preferences") - - Returns: - A new `Expr` representing the reversed array. - """ - array_expr = Field.of(array) if isinstance(array, str) else array - return Expr.array_reverse(array_expr) - - def is_nan(expr: Expr | str) -> "IsNaN": - """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number). - - Example: - >>> Function.is_nan("measurement") - - Returns: - A new `Expr` representing the 'isNaN' check. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.is_nan(expr_val) - - def exists(expr: Expr | str) -> "Exists": - """Creates an expression that checks if a field exists in the document. - - Example: - >>> Function.exists("phoneNumber") - - Returns: - A new `Expr` representing the 'exists' check. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.exists(expr_val) - - def sum(expr: Expr | str) -> "Sum": - """Creates an aggregation that calculates the sum of a numeric field across multiple stage inputs. - - Example: - >>> Function.sum("orderAmount") - - Returns: - A new `Accumulator` representing the 'sum' aggregation. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.sum(expr_val) - - def avg(expr: Expr | str) -> "Avg": - """Creates an aggregation that calculates the average (mean) of a numeric field across multiple - stage inputs. - - Example: - >>> Function.avg("age") - - Returns: - A new `Accumulator` representing the 'avg' aggregation. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.avg(expr_val) - - def count(expr: Expr | str | None = None) -> "Count": - """Creates an aggregation that counts the number of stage inputs with valid evaluations of the - expression or field. If no expression is provided, it counts all inputs. - - Example: - >>> Function.count("productId") - >>> Function.count() - - Returns: - A new `Accumulator` representing the 'count' aggregation. - """ - if expr is None: - return Count() - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.count(expr_val) - - def min(expr: Expr | str) -> "Min": - """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. - - Example: - >>> Function.min("price") - - Returns: - A new `Accumulator` representing the 'min' aggregation. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.min(expr_val) - - def max(expr: Expr | str) -> "Max": - """Creates an aggregation that finds the maximum value of a field across multiple stage inputs. - - Example: - >>> Function.max("score") - - Returns: - A new `Accumulator` representing the 'max' aggregation. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.max(expr_val) - - def char_length(expr: Expr | str) -> "CharLength": - """Creates an expression that calculates the character length of a string. - - Example: - >>> Function.char_length("name") - - Returns: - A new `Expr` representing the length of the string. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.char_length(expr_val) - - def byte_length(expr: Expr | str) -> "ByteLength": - """Creates an expression that calculates the byte length of a string in its UTF-8 form. - - Example: - >>> Function.byte_length("name") - - Returns: - A new `Expr` representing the byte length of the string. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.byte_length(expr_val) - - def like(expr: Expr | str, pattern: Expr | str) -> "Like": - """Creates an expression that performs a case-sensitive string comparison. - - Example: - >>> Function.like("title", "%guide%") - >>> Function.like(Field.of("title"), Field.of("pattern")) - - Args: - expr: The expression or field path to perform the comparison on. - pattern: The pattern (string or expression) to search for. You can use "%" as a wildcard character. - - Returns: - A new `Expr` representing the 'like' comparison. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.like(expr_val, pattern) - - def regex_contains(expr: Expr | str, regex: Expr | str) -> "RegexContains": - """Creates an expression that checks if a string contains a specified regular expression as a - substring. - - Example: - >>> Function.regex_contains("description", "(?i)example") - >>> Function.regex_contains(Field.of("description"), Field.of("regex")) - - Args: - expr: The expression or field path to perform the comparison on. - regex: The regular expression (string or expression) to use for the search. - - Returns: - A new `Expr` representing the 'contains' comparison. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.regex_contains(expr_val, regex) - - def regex_matches(expr: Expr | str, regex: Expr | str) -> "RegexMatch": - """Creates an expression that checks if a string matches a specified regular expression. - - Example: - >>> # Check if the 'email' field matches a valid email pattern - >>> Function.regex_matches("email", "[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}") - >>> Function.regex_matches(Field.of("email"), Field.of("regex")) - - Args: - expr: The expression or field path to match against. - regex: The regular expression (string or expression) to use for the match. - - Returns: - A new `Expr` representing the regular expression match. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.regex_matches(expr_val, regex) - - def str_contains(expr: Expr | str, substring: Expr | str) -> "StrContains": - """Creates an expression that checks if this string expression contains a specified substring. - - Example: - >>> Function.str_contains("description", "example") - >>> Function.str_contains(Field.of("description"), Field.of("keyword")) - - Args: - expr: The expression or field path to perform the comparison on. - substring: The substring (string or expression) to use for the search. - - Returns: - A new `Expr` representing the 'contains' comparison. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.str_contains(expr_val, substring) - - def starts_with(expr: Expr | str, prefix: Expr | str) -> "StartsWith": - """Creates an expression that checks if a string starts with a given prefix. - - Example: - >>> Function.starts_with("name", "Mr.") - >>> Function.starts_with(Field.of("fullName"), Field.of("firstName")) - - Args: - expr: The expression or field path to check. - prefix: The prefix (string or expression) to check for. - - Returns: - A new `Expr` representing the 'starts with' comparison. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.starts_with(expr_val, prefix) - - def ends_with(expr: Expr | str, postfix: Expr | str) -> "EndsWith": - """Creates an expression that checks if a string ends with a given postfix. - - Example: - >>> Function.ends_with("filename", ".txt") - >>> Function.ends_with(Field.of("url"), Field.of("extension")) - - Args: - expr: The expression or field path to check. - postfix: The postfix (string or expression) to check for. - - Returns: - A new `Expr` representing the 'ends with' comparison. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.ends_with(expr_val, postfix) - - def str_concat(first: Expr | str, *elements: Expr | CONSTANT_TYPE) -> "StrConcat": - """Creates an expression that concatenates string expressions, fields or constants together. - - Example: - >>> Function.str_concat("firstName", " ", Field.of("lastName")) - - Args: - first: The first expression or field path to concatenate. - *elements: The expressions or constants (typically strings) to concatenate. - - Returns: - A new `Expr` representing the concatenated string. - """ - first_expr = Field.of(first) if isinstance(first, str) else first - return Expr.str_concat(first_expr, *elements) - - def map_get(map_expr: Expr | str, key: str) -> "MapGet": - """Accesses a value from a map (object) field using the provided key. - - Example: - >>> Function.map_get("address", "city") - - Args: - map_expr: The expression or field path of the map. - key: The key to access in the map. - - Returns: - A new `Expr` representing the value associated with the given key in the map. - """ - map_val = Field.of(map_expr) if isinstance(map_expr, str) else map_expr - return Expr.map_get(map_val, key) - - def vector_length(vector_expr: Expr | str) -> "VectorLength": - """Creates an expression that calculates the length (dimension) of a Firestore Vector. - - Example: - >>> Function.vector_length("embedding") - - Returns: - A new `Expr` representing the length of the vector. - """ - vector_val = ( - Field.of(vector_expr) if isinstance(vector_expr, str) else vector_expr - ) - return Expr.vector_length(vector_val) - - def timestamp_to_unix_micros(timestamp_expr: Expr | str) -> "TimestampToUnixMicros": - """Creates an expression that converts a timestamp to the number of microseconds since the epoch - (1970-01-01 00:00:00 UTC). - - Truncates higher levels of precision by rounding down to the beginning of the microsecond. - - Example: - >>> Function.timestamp_to_unix_micros("timestamp") - - Returns: - A new `Expr` representing the number of microseconds since the epoch. - """ - timestamp_val = ( - Field.of(timestamp_expr) - if isinstance(timestamp_expr, str) - else timestamp_expr - ) - return Expr.timestamp_to_unix_micros(timestamp_val) - - def unix_micros_to_timestamp(micros_expr: Expr | str) -> "UnixMicrosToTimestamp": - """Creates an expression that converts a number of microseconds since the epoch (1970-01-01 - 00:00:00 UTC) to a timestamp. - - Example: - >>> Function.unix_micros_to_timestamp("microseconds") - - Returns: - A new `Expr` representing the timestamp. - """ - micros_val = ( - Field.of(micros_expr) if isinstance(micros_expr, str) else micros_expr - ) - return Expr.unix_micros_to_timestamp(micros_val) - - def timestamp_to_unix_millis(timestamp_expr: Expr | str) -> "TimestampToUnixMillis": - """Creates an expression that converts a timestamp to the number of milliseconds since the epoch - (1970-01-01 00:00:00 UTC). - - Truncates higher levels of precision by rounding down to the beginning of the millisecond. - - Example: - >>> Function.timestamp_to_unix_millis("timestamp") - - Returns: - A new `Expr` representing the number of milliseconds since the epoch. - """ - timestamp_val = ( - Field.of(timestamp_expr) - if isinstance(timestamp_expr, str) - else timestamp_expr - ) - return Expr.timestamp_to_unix_millis(timestamp_val) + def __init__(self, left: Expr, right: Expr): + super().__init__("multiply", [left, right]) - def unix_millis_to_timestamp(millis_expr: Expr | str) -> "UnixMillisToTimestamp": - """Creates an expression that converts a number of milliseconds since the epoch (1970-01-01 - 00:00:00 UTC) to a timestamp. - Example: - >>> Function.unix_millis_to_timestamp("milliseconds") +class Parent(Function): + """Represents getting the parent document reference.""" - Returns: - A new `Expr` representing the timestamp. - """ - millis_val = ( - Field.of(millis_expr) if isinstance(millis_expr, str) else millis_expr - ) - return Expr.unix_millis_to_timestamp(millis_val) + def __init__(self, value: Expr): + super().__init__("parent", [value]) - def timestamp_to_unix_seconds( - timestamp_expr: Expr | str, - ) -> "TimestampToUnixSeconds": - """Creates an expression that converts a timestamp to the number of seconds since the epoch - (1970-01-01 00:00:00 UTC). - Truncates higher levels of precision by rounding down to the beginning of the second. +class Reverse(Function): + """Represents reversing a string.""" - Example: - >>> Function.timestamp_to_unix_seconds("timestamp") + def __init__(self, expr: Expr): + super().__init__("reverse", [expr]) - Returns: - A new `Expr` representing the number of seconds since the epoch. - """ - timestamp_val = ( - Field.of(timestamp_expr) - if isinstance(timestamp_expr, str) - else timestamp_expr - ) - return Expr.timestamp_to_unix_seconds(timestamp_val) - def unix_seconds_to_timestamp(seconds_expr: Expr | str) -> "UnixSecondsToTimestamp": - """Creates an expression that converts a number of seconds since the epoch (1970-01-01 00:00:00 - UTC) to a timestamp. +class StringConcat(Function): + """Represents concatenating multiple strings.""" - Example: - >>> Function.unix_seconds_to_timestamp("seconds") + def __init__(self, *exprs: Expr): + super().__init__("string_concat", exprs) - Returns: - A new `Expr` representing the timestamp. - """ - seconds_val = ( - Field.of(seconds_expr) if isinstance(seconds_expr, str) else seconds_expr - ) - return Expr.unix_seconds_to_timestamp(seconds_val) - def timestamp_add( - timestamp: Expr | str, unit: Expr | str, amount: Expr | float - ) -> "TimestampAdd": - """Creates an expression that adds a specified amount of time to this timestamp expression. +class Subtract(Function): + """Represents the subtraction function.""" - Example: - >>> Function.timestamp_add("timestamp", "day", 1.5) - >>> Function.timestamp_add(Field.of("timestamp"), Field.of("unit"), Field.of("amount")) + def __init__(self, left: Expr, right: Expr): + super().__init__("subtract", [left, right]) - Args: - timestamp: The expression or field path of the timestamp. - unit: The expression or string evaluating to the unit of time to add, must be one of - 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'day'. - amount: The expression or float representing the amount of time to add. - Returns: - A new `Expr` representing the resulting timestamp. - """ - timestamp_expr = ( - Field.of(timestamp) if isinstance(timestamp, str) else timestamp - ) - return Expr.timestamp_add(timestamp_expr, unit, amount) +class TimestampAdd(Function): + """Represents adding a duration to a timestamp.""" - def timestamp_sub( - timestamp: Expr | str, unit: Expr | str, amount: Expr | float - ) -> "TimestampSub": - """Creates an expression that subtracts a specified amount of time from this timestamp expression. + def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): + super().__init__("timestamp_add", [timestamp, unit, amount]) - Example: - >>> Function.timestamp_sub("timestamp", "hour", 2.5) - >>> Function.timestamp_sub(Field.of("timestamp"), Field.of("unit"), Field.of("amount")) - Args: - timestamp: The expression or field path of the timestamp. - unit: The expression or string evaluating to the unit of time to subtract, must be one of - 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'day'. - amount: The expression or float representing the amount of time to subtract. +class Abs(Function): + """Represents the absolute value function.""" - Returns: - A new `Expr` representing the resulting timestamp. - """ - timestamp_expr = ( - Field.of(timestamp) if isinstance(timestamp, str) else timestamp - ) - return Expr.timestamp_sub(timestamp_expr, unit, amount) + def __init__(self, value: Expr): + super().__init__("abs", [value]) -class Divide(Function): - """Represents the division function.""" +class Ceil(Function): + """Represents the ceiling function.""" - def __init__(self, left: Expr, right: Expr): - super().__init__("divide", [left, right]) + def __init__(self, value: Expr): + super().__init__("ceil", [value]) -class LogicalMax(Function): - """Represents the logical maximum function based on Firestore type ordering.""" +class Exp(Function): + """Represents the exponential function.""" - def __init__(self, left: Expr, right: Expr): - super().__init__("logical_maximum", [left, right]) + def __init__(self, value: Expr): + super().__init__("exp", [value]) -class LogicalMin(Function): - """Represents the logical minimum function based on Firestore type ordering.""" +class Floor(Function): + """Represents the floor function.""" - def __init__(self, left: Expr, right: Expr): - super().__init__("logical_minimum", [left, right]) + def __init__(self, value: Expr): + super().__init__("floor", [value]) -class MapGet(Function): - """Represents accessing a value within a map by key.""" +class Ln(Function): + """Represents the natural logarithm function.""" - def __init__(self, map_: Expr, key: Constant[str]): - super().__init__("map_get", [map_, key]) + def __init__(self, value: Expr): + super().__init__("ln", [value]) -class Mod(Function): - """Represents the modulo function.""" +class Log(Function): + """Represents the logarithm function.""" - def __init__(self, left: Expr, right: Expr): - super().__init__("mod", [left, right]) + def __init__(self, value: Expr, base: Expr): + super().__init__("log", [value, base]) -class Multiply(Function): - """Represents the multiplication function.""" +class Pow(Function): + """Represents the power function.""" - def __init__(self, left: Expr, right: Expr): - super().__init__("multiply", [left, right]) + def __init__(self, base: Expr, exponent: Expr): + super().__init__("pow", [base, exponent]) -class Parent(Function): - """Represents getting the parent document reference.""" +class Round(Function): + """Represents the round function.""" def __init__(self, value: Expr): - super().__init__("parent", [value]) - - -class StrConcat(Function): - """Represents concatenating multiple strings.""" - - def __init__(self, *exprs: Expr): - super().__init__("str_concat", exprs) - - -class Subtract(Function): - """Represents the subtraction function.""" + super().__init__("round", [value]) - def __init__(self, left: Expr, right: Expr): - super().__init__("subtract", [left, right]) +class Sqrt(Function): + """Represents the square root function.""" -class TimestampAdd(Function): - """Represents adding a duration to a timestamp.""" - - def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): - super().__init__("timestamp_add", [timestamp, unit, amount]) + def __init__(self, value: Expr): + super().__init__("sqrt", [value]) -class TimestampSub(Function): +class TimestampSubtract(Function): """Represents subtracting a duration from a timestamp.""" def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): - super().__init__("timestamp_sub", [timestamp, unit, amount]) + super().__init__("timestamp_subtract", [timestamp, unit, amount]) class TimestampToUnixMicros(Function): @@ -1831,6 +1619,27 @@ def __init__(self, input: Expr): super().__init__("timestamp_to_unix_seconds", [input]) +class ToLower(Function): + """Represents converting a string to lowercase.""" + + def __init__(self, value: Expr): + super().__init__("to_lower", [value]) + + +class ToUpper(Function): + """Represents converting a string to uppercase.""" + + def __init__(self, value: Expr): + super().__init__("to_upper", [value]) + + +class Trim(Function): + """Represents trimming whitespace from a string.""" + + def __init__(self, expr: Expr): + super().__init__("trim", [expr]) + + class UnixMicrosToTimestamp(Function): """Represents converting microseconds since epoch to a timestamp.""" @@ -1866,18 +1675,28 @@ def __init__(self, left: Expr, right: Expr): super().__init__("add", [left, right]) -class ArrayElement(Function): - """Represents accessing an element within an array""" +class Array(Function): + """Creates an expression that creates a Firestore array value from an input list.""" + + def __init__(self, elements: list[Expr]): + super().__init__("array", elements) + + def __repr__(self): + return f"Array({self.params})" - def __init__(self): - super().__init__("array_element", []) +class ArrayGet(Function): + """Creates an expression that indexes into an array from the beginning or end and returns an element.""" -class ArrayFilter(Function): - """Represents filtering elements from an array based on a condition.""" + def __init__(self, array: Expr, index: Expr): + super().__init__("array_get", [array, index]) - def __init__(self, array: Expr, filter: "FilterCondition"): - super().__init__("array_filter", [array, filter]) + +class ArrayConcat(Function): + """Represents concatenating multiple arrays.""" + + def __init__(self, array: Expr, rest: Sequence[Expr]): + super().__init__("array_concat", [array] + rest) class ArrayLength(Function): @@ -1894,13 +1713,6 @@ def __init__(self, array: Expr): super().__init__("array_reverse", [array]) -class ArrayTransform(Function): - """Represents applying a transformation function to each element of an array.""" - - def __init__(self, array: Expr, transform: Function): - super().__init__("array_transform", [array, transform]) - - class ByteLength(Function): """Represents getting the byte length of a string (UTF-8).""" @@ -1922,39 +1734,60 @@ def __init__(self, value: Expr): super().__init__("collection_id", [value]) -class Accumulator(Function): +class CosineDistance(Function): + """Represents the vector cosine distance function.""" + + def __init__(self, vector1: Expr, vector2: Expr): + super().__init__("cosine_distance", [vector1, vector2]) + + +class AggregateFunction(Function): """A base class for aggregation functions that operate across multiple inputs.""" + def as_(self, alias: str) -> "AliasedAggregate": + """Assigns an alias to this expression. + + Aliases are useful for renaming fields in the output of a stage or for giving meaningful + names to calculated values. -class Max(Accumulator): - """Represents the maximum aggregation function.""" + Args: + alias: The alias to assign to this expression. + + Returns: A new AliasedAggregate that wraps this expression and associates it with the + provided alias. + """ + return AliasedAggregate(self, alias) + + +class Maximum(AggregateFunction): + """Finds the maximum value of a field, aggregated across multiple stage inputs.""" def __init__(self, value: Expr): - super().__init__("maximum", [value]) + super().__init__("max", [value]) -class Min(Accumulator): - """Represents the minimum aggregation function.""" +class Minimum(AggregateFunction): + """Finds the maximum value of a field, aggregated across multiple stage inputs.""" def __init__(self, value: Expr): - super().__init__("minimum", [value]) + super().__init__("min", [value]) -class Sum(Accumulator): +class Sum(AggregateFunction): """Represents the sum aggregation function.""" def __init__(self, value: Expr): super().__init__("sum", [value]) -class Avg(Accumulator): +class Average(AggregateFunction): """Represents the average aggregation function.""" def __init__(self, value: Expr): - super().__init__("avg", [value]) + super().__init__("average", [value]) -class Count(Accumulator): +class Count(AggregateFunction): """Represents an aggregation that counts the total number of inputs.""" def __init__(self, value: Expr | None = None): @@ -2000,7 +1833,7 @@ def _to_value(field_list: Sequence[Selectable]) -> Value: T = TypeVar("T", bound=Expr) -class ExprWithAlias(Selectable, Generic[T]): +class AliasedExpr(Selectable, Generic[T]): """Wraps an expression with an alias.""" def __init__(self, expr: T, alias: str): @@ -2017,6 +1850,23 @@ def _to_pb(self): return Value(map_value={"fields": {self.alias: self.expr._to_pb()}}) +class AliasedAggregate: + """Wraps an aggregate with an alias""" + + def __init__(self, expr: AggregateFunction, alias: str): + self.expr = expr + self.alias = alias + + def _to_map(self): + return self.alias, self.expr._to_pb() + + def __repr__(self): + return f"{self.expr}.as_('{self.alias}')" + + def _to_pb(self): + return Value(map_value={"fields": {self.alias: self.expr._to_pb()}}) + + class Field(Selectable): """Represents a reference to a field within a document.""" @@ -2054,7 +1904,7 @@ def _to_pb(self): return Value(field_reference_value=self.path) -class FilterCondition(Function): +class BooleanExpr(Function): """Filters the given data in some way.""" def __init__( @@ -2070,7 +1920,7 @@ def __init__( def __repr__(self): """ - Most FilterConditions can be triggered infix. Eg: Field.of('age').gte(18). + Most BooleanExprs can be triggered infix. Eg: Field.of('age').greater_than(18). Display them this way in the repr string where possible """ @@ -2086,8 +1936,7 @@ def __repr__(self): def _from_query_filter_pb(filter_pb, client): if isinstance(filter_pb, Query_pb.CompositeFilter): sub_filters = [ - FilterCondition._from_query_filter_pb(f, client) - for f in filter_pb.filters + BooleanExpr._from_query_filter_pb(f, client) for f in filter_pb.filters ] if filter_pb.op == Query_pb.CompositeFilter.Operator.OR: return Or(*sub_filters) @@ -2104,34 +1953,34 @@ def _from_query_filter_pb(filter_pb, client): elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NAN: return And(field.exists(), Not(field.is_nan())) elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NULL: - return And(field.exists(), field.eq(None)) + return And(field.exists(), field.equal(None)) elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NULL: - return And(field.exists(), Not(field.eq(None))) + return And(field.exists(), Not(field.equal(None))) else: raise TypeError(f"Unexpected UnaryFilter operator type: {filter_pb.op}") elif isinstance(filter_pb, Query_pb.FieldFilter): field = Field.of(filter_pb.field.field_path) value = decode_value(filter_pb.value, client) if filter_pb.op == Query_pb.FieldFilter.Operator.LESS_THAN: - return And(field.exists(), field.lt(value)) + return And(field.exists(), field.less_than(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.LESS_THAN_OR_EQUAL: - return And(field.exists(), field.lte(value)) + return And(field.exists(), field.less_than_or_equal(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.GREATER_THAN: - return And(field.exists(), field.gt(value)) + return And(field.exists(), field.greater_than(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.GREATER_THAN_OR_EQUAL: - return And(field.exists(), field.gte(value)) + return And(field.exists(), field.greater_than_or_equal(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.EQUAL: - return And(field.exists(), field.eq(value)) + return And(field.exists(), field.equal(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.NOT_EQUAL: - return And(field.exists(), field.neq(value)) + return And(field.exists(), field.not_equal(value)) if filter_pb.op == Query_pb.FieldFilter.Operator.ARRAY_CONTAINS: return And(field.exists(), field.array_contains(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.ARRAY_CONTAINS_ANY: return And(field.exists(), field.array_contains_any(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.IN: - return And(field.exists(), field.in_any(value)) + return And(field.exists(), field.equal_any(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.NOT_IN: - return And(field.exists(), field.not_in_any(value)) + return And(field.exists(), field.not_equal_any(value)) else: raise TypeError(f"Unexpected FieldFilter operator type: {filter_pb.op}") elif isinstance(filter_pb, Query_pb.Filter): @@ -2141,165 +1990,170 @@ def _from_query_filter_pb(filter_pb, client): or filter_pb.field_filter or filter_pb.unary_filter ) - return FilterCondition._from_query_filter_pb(f, client) + return BooleanExpr._from_query_filter_pb(f, client) else: raise TypeError(f"Unexpected filter type: {type(filter_pb)}") -class And(FilterCondition): - def __init__(self, *conditions: "FilterCondition"): +class And(BooleanExpr): + def __init__(self, *conditions: "BooleanExpr"): super().__init__("and", conditions, use_infix_repr=False) -class ArrayContains(FilterCondition): +class ArrayContains(BooleanExpr): def __init__(self, array: Expr, element: Expr): super().__init__("array_contains", [array, element]) -class ArrayContainsAll(FilterCondition): +class ArrayContainsAll(BooleanExpr): """Represents checking if an array contains all specified elements.""" def __init__(self, array: Expr, elements: Sequence[Expr]): super().__init__("array_contains_all", [array, ListOfExprs(elements)]) -class ArrayContainsAny(FilterCondition): +class ArrayContainsAny(BooleanExpr): """Represents checking if an array contains any of the specified elements.""" def __init__(self, array: Expr, elements: Sequence[Expr]): super().__init__("array_contains_any", [array, ListOfExprs(elements)]) -class EndsWith(FilterCondition): +class EndsWith(BooleanExpr): """Represents checking if a string ends with a specific postfix.""" def __init__(self, expr: Expr, postfix: Expr): super().__init__("ends_with", [expr, postfix]) -class Eq(FilterCondition): +class Equal(BooleanExpr): """Represents the equality comparison.""" def __init__(self, left: Expr, right: Expr): - super().__init__("eq", [left, right]) + super().__init__("equal", [left, right]) -class Exists(FilterCondition): +class Exists(BooleanExpr): """Represents checking if a field exists.""" def __init__(self, expr: Expr): super().__init__("exists", [expr]) -class Gt(FilterCondition): +class GreaterThan(BooleanExpr): """Represents the greater than comparison.""" def __init__(self, left: Expr, right: Expr): - super().__init__("gt", [left, right]) + super().__init__("greater_than", [left, right]) -class Gte(FilterCondition): +class GreaterThanOrEqual(BooleanExpr): """Represents the greater than or equal to comparison.""" def __init__(self, left: Expr, right: Expr): - super().__init__("gte", [left, right]) + super().__init__("greater_than_or_equal", [left, right]) -class If(FilterCondition): +class Conditional(BooleanExpr): """Represents a conditional expression (if-then-else).""" - def __init__(self, condition: "FilterCondition", true_expr: Expr, false_expr: Expr): - super().__init__("if", [condition, true_expr, false_expr]) + def __init__(self, condition: "BooleanExpr", then_expr: Expr, else_expr: Expr): + super().__init__("conditional", [condition, then_expr, else_expr]) -class In(FilterCondition): +class EqualAny(BooleanExpr): """Represents checking if an expression's value is within a list of values.""" def __init__(self, left: Expr, others: Sequence[Expr]): - super().__init__( - "in", [left, ListOfExprs(others)], infix_name_override="in_any" - ) + super().__init__("equal_any", [left, ListOfExprs(others)]) + + +class NotEqualAny(BooleanExpr): + """Represents checking if an expression's value is not within a list of values.""" + + def __init__(self, left: Expr, others: Sequence[Expr]): + super().__init__("not_equal_any", [left, ListOfExprs(others)]) -class IsNaN(FilterCondition): +class IsNaN(BooleanExpr): """Represents checking if a numeric value is NaN.""" def __init__(self, value: Expr): super().__init__("is_nan", [value]) -class Like(FilterCondition): +class Like(BooleanExpr): """Represents a case-sensitive wildcard string comparison.""" def __init__(self, expr: Expr, pattern: Expr): super().__init__("like", [expr, pattern]) -class Lt(FilterCondition): +class LessThan(BooleanExpr): """Represents the less than comparison.""" def __init__(self, left: Expr, right: Expr): - super().__init__("lt", [left, right]) + super().__init__("less_than", [left, right]) -class Lte(FilterCondition): +class LessThanOrEqual(BooleanExpr): """Represents the less than or equal to comparison.""" def __init__(self, left: Expr, right: Expr): - super().__init__("lte", [left, right]) + super().__init__("less_than_or_equal", [left, right]) -class Neq(FilterCondition): +class NotEqual(BooleanExpr): """Represents the inequality comparison.""" def __init__(self, left: Expr, right: Expr): - super().__init__("neq", [left, right]) + super().__init__("not_equal", [left, right]) -class Not(FilterCondition): +class Not(BooleanExpr): """Represents the logical NOT of a filter condition.""" def __init__(self, condition: Expr): super().__init__("not", [condition], use_infix_repr=False) -class Or(FilterCondition): +class Or(BooleanExpr): """Represents the logical OR of multiple filter conditions.""" - def __init__(self, *conditions: "FilterCondition"): + def __init__(self, *conditions: "BooleanExpr"): super().__init__("or", conditions) -class RegexContains(FilterCondition): +class RegexContains(BooleanExpr): """Represents checking if a string contains a substring matching a regex.""" def __init__(self, expr: Expr, regex: Expr): super().__init__("regex_contains", [expr, regex]) -class RegexMatch(FilterCondition): +class RegexMatch(BooleanExpr): """Represents checking if a string fully matches a regex.""" def __init__(self, expr: Expr, regex: Expr): super().__init__("regex_match", [expr, regex]) -class StartsWith(FilterCondition): +class StartsWith(BooleanExpr): """Represents checking if a string starts with a specific prefix.""" def __init__(self, expr: Expr, prefix: Expr): super().__init__("starts_with", [expr, prefix]) -class StrContains(FilterCondition): +class StringContains(BooleanExpr): """Represents checking if a string contains a specific substring.""" def __init__(self, expr: Expr, substring: Expr): - super().__init__("str_contains", [expr, substring]) + super().__init__("string_contains", [expr, substring]) -class Xor(FilterCondition): +class Xor(BooleanExpr): """Represents the logical XOR of multiple filter conditions.""" - def __init__(self, conditions: Sequence["FilterCondition"]): + def __init__(self, conditions: Sequence["BooleanExpr"]): super().__init__("xor", conditions, use_infix_repr=False) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index dc262f4a9..59206d90c 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -214,6 +214,43 @@ tests: accumulators: [] groups: [genre] assert_error: ".* requires at least one accumulator" + - description: testDistinct + pipeline: + - Collection: books + - Where: + - Lt: + - Field: published + - Constant: 1900 + - Distinct: + - ExprWithAlias: + - ToLower: + - Field: genre + - "lower_genre" + assert_results: + - lower_genre: romance + - lower_genre: psychological thriller + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + name: lt + name: where + - args: + - mapValue: + fields: + lower_genre: + functionValue: + args: + - fieldReferenceValue: genre + name: to_lower + name: distinct - description: testGroupBysAndAggregate pipeline: - Collection: books @@ -776,6 +813,44 @@ tests: - integerValue: '3' name: eq name: where + - description: testArrayConcat + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + - ArrayConcat: + - Field: tags + - - Constant: newTag1 + - Constant: newTag2 + - "modifiedTags" + - Limit: 1 + assert_results: + - modifiedTags: + - comedy + - space + - adventure + - newTag1 + - newTag2 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + modifiedTags: + functionValue: + args: + - fieldReferenceValue: tags + - stringValue: newTag1 + - stringValue: newTag2 + name: array_concat + name: select + - args: + - integerValue: '1' + name: limit - description: testStrConcat pipeline: - Collection: books @@ -967,6 +1042,122 @@ tests: expression: fieldReferenceValue: title name: sort + - description: testStringFunctions - Reverse + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + - Reverse: + - Field: title + - "reversed_title" + - Where: + - Eq: + - Field: author + - Constant: Douglas Adams + assert_results: + - reversed_title: yxalaG ot ediug s'reknhiHcH ehT + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + reversed_title: + functionValue: + args: + - fieldReferenceValue: title + name: reverse + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: eq + name: where + - description: testStringFunctions - ReplaceFirst + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + - ReplaceFirst: + - Field: title + - Constant: The + - Constant: A + - "replaced_title" + - Where: + - Eq: + - Field: author + - Constant: Douglas Adams + assert_results: + - replaced_title: A Hitchhiker's Guide to the Galaxy + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + replaced_title: + functionValue: + args: + - fieldReferenceValue: title + - stringValue: The + - stringValue: A + name: replace_first + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: eq + name: where + - description: testStringFunctions - ReplaceAll + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + - ReplaceAll: + - Field: title + - Constant: " " + - Constant: "_" + - "replaced_title" + - Where: + - Eq: + - Field: author + - Constant: Douglas Adams + assert_results: + - replaced_title: The_Hitchhiker's_Guide_to_the_Galaxy + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + replaced_title: + functionValue: + args: + - fieldReferenceValue: title + - stringValue: ' ' + - stringValue: _ + name: replace_all + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: eq + name: where - description: testStringFunctions - CharLength pipeline: - Collection: books @@ -1045,6 +1236,115 @@ tests: name: str_concat name: byte_length name: select + - description: testToLowercase + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + - ToLower: + - Field: title + - "lowercaseTitle" + - Limit: 1 + assert_results: + - lowercaseTitle: the hitchhiker's guide to the galaxy + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + lowercaseTitle: + functionValue: + args: + - fieldReferenceValue: title + name: to_lower + name: select + - args: + - integerValue: '1' + name: limit + - description: testToUppercase + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + - ToUpper: + - Field: author + - "uppercaseAuthor" + - Limit: 1 + assert_results: + - uppercaseAuthor: DOUGLAS ADAMS + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + uppercaseAuthor: + functionValue: + args: + - fieldReferenceValue: author + name: to_upper + name: select + - args: + - integerValue: '1' + name: limit + - description: testTrim + pipeline: + - Collection: books + - AddFields: + - ExprWithAlias: + - StrConcat: + - Constant: " " + - Field: title + - Constant: " " + - "spacedTitle" + - Select: + - ExprWithAlias: + - Trim: + - Field: spacedTitle + - "trimmedTitle" + - spacedTitle + - Limit: 1 + assert_results: + - trimmedTitle: The Hitchhiker's Guide to the Galaxy + spacedTitle: " The Hitchhiker's Guide to the Galaxy " + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + spacedTitle: + functionValue: + args: + - stringValue: ' ' + - fieldReferenceValue: title + - stringValue: ' ' + name: str_concat + name: add_fields + - args: + - mapValue: + fields: + spacedTitle: + fieldReferenceValue: spacedTitle + trimmedTitle: + functionValue: + args: + - fieldReferenceValue: spacedTitle + name: trim + name: select + - args: + - integerValue: '1' + name: limit - description: testLike pipeline: - Collection: books @@ -1356,6 +1656,11 @@ tests: - IsNaN: - Field: rating - Select: + - ExprWithAlias: + - Eq: + - Field: rating + - Constant: null + - "ratingIsNull" - ExprWithAlias: - Not: - IsNaN: @@ -1363,7 +1668,8 @@ tests: - "ratingIsNotNaN" - Limit: 1 assert_results: - - ratingIsNotNaN: true + - ratingIsNull: false + ratingIsNotNaN: true assert_proto: pipeline: stages: @@ -1390,6 +1696,12 @@ tests: - fieldReferenceValue: rating name: is_nan name: not + ratingIsNull: + functionValue: + args: + - fieldReferenceValue: rating + - nullValue: null + name: eq name: select - args: - integerValue: '1' @@ -1500,6 +1812,79 @@ tests: - booleanValue: true name: eq name: where + - description: testDistanceFunctions + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + - CosineDistance: + - Constant: [[0.1, 0.1]] + - Constant: [[0.5, 0.8]] + - "cosineDistance" + - ExprWithAlias: + - DotProduct: + - Constant: [[0.1, 0.1]] + - Constant: [[0.5, 0.8]] + - "dotProductDistance" + - ExprWithAlias: + - EuclideanDistance: + - Constant: [[0.1, 0.1]] + - Constant: [[0.5, 0.8]] + - "euclideanDistance" + - Limit: 1 + assert_results: + - cosineDistance: 0.02560880430538015 + dotProductDistance: 0.13 + euclideanDistance: 0.806225774829855 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + cosineDistance: + functionValue: + args: + - arrayValue: + values: + - doubleValue: 0.1 + - doubleValue: 0.1 + - arrayValue: + values: + - doubleValue: 0.5 + - doubleValue: 0.8 + name: cosine_distance + dotProductDistance: + functionValue: + args: + - arrayValue: + values: + - doubleValue: 0.1 + - doubleValue: 0.1 + - arrayValue: + values: + - doubleValue: 0.5 + - doubleValue: 0.8 + name: dot_product + euclideanDistance: + functionValue: + args: + - arrayValue: + values: + - doubleValue: 0.1 + - doubleValue: 0.1 + - arrayValue: + values: + - doubleValue: 0.5 + - doubleValue: 0.8 + name: euclidean_distance + name: select + - args: + - integerValue: '1' + name: limit - description: testNestedFields pipeline: - Collection: books @@ -1548,6 +1933,43 @@ tests: title: fieldReferenceValue: title name: select + - description: testReplace + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Replace: awards + assert_results: + - title: The Hitchhiker's Guide to the Galaxy + author: Douglas Adams + genre: Science Fiction + published: 1979 + rating: 4.2 + tags: + - comedy + - space + - adventure + hugo: true + nebula: false + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: The Hitchhiker's Guide to the Galaxy + name: eq + name: where + - args: + - fieldReferenceValue: awards + - stringValue: full_replace + name: replace - description: testSampleLimit pipeline: - Collection: books @@ -1580,6 +2002,145 @@ tests: - doubleValue: 0.6 - stringValue: percent name: sample + - description: testMathFunctions - Abs + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Limit: 1 + - Select: + - ExprWithAlias: + - Abs: + - Subtract: + - Field: rating + - Constant: 5 + - "absRating" + assert_results: + - absRating: 0.8 + - description: testMathFunctions - Ceil + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Limit: 1 + - Select: + - ExprWithAlias: + - Ceil: + - Field: rating + - "ceilRating" + assert_results: + - ceilRating: 5 + - description: testMathFunctions - Floor + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Limit: 1 + - Select: + - ExprWithAlias: + - Floor: + - Field: rating + - "floorRating" + assert_results: + - floorRating: 4 + - description: testMathFunctions - Exp + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: "The Lord of the Rings" + - Limit: 1 + - Select: + - ExprWithAlias: + - Exp: + - Field: rating + - "expRating" + assert_results: + - expRating: 109.94717245212352 + - description: testMathFunctions - Pow + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Limit: 1 + - Select: + - ExprWithAlias: + - Pow: + - Field: rating + - Constant: 2 + - "powerRating" + assert_results: + - powerRating: 17.64 + - description: testMathFunctions - Round + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Limit: 1 + - Select: + - ExprWithAlias: + - Round: + - Field: rating + - "roundedRating" + assert_results: + - roundedRating: 4 + - description: testMathFunctions - Ln + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Limit: 1 + - Select: + - ExprWithAlias: + - Ln: + - Field: rating + - "lnRating" + assert_results: + - lnRating: 1.4350845252893227 + - description: testMathFunctions - Log + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Limit: 1 + - Select: + - ExprWithAlias: + - Log: + - Field: rating + - Constant: 10 + - "logRating" + assert_results: + - logRating: 0.6232492903979004 + - description: testMathFunctions - Sqrt + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Limit: 1 + - Select: + - ExprWithAlias: + - Sqrt: + - Field: rating + - "sqrtRating" + assert_results: + - sqrtRating: 2.04939015319192 - description: testUnion pipeline: - Collection: books diff --git a/tests/unit/v1/test_aggregation.py b/tests/unit/v1/test_aggregation.py index 5064e87ae..46c2dd4f0 100644 --- a/tests/unit/v1/test_aggregation.py +++ b/tests/unit/v1/test_aggregation.py @@ -127,12 +127,12 @@ def test_avg_aggregation_no_alias_to_pb(): "in_alias,expected_alias", [("total", "total"), (None, "field_1")] ) def test_count_aggregation_to_pipeline_expr(in_alias, expected_alias): - from google.cloud.firestore_v1.pipeline_expressions import ExprWithAlias + from google.cloud.firestore_v1.pipeline_expressions import AliasedAggregate from google.cloud.firestore_v1.pipeline_expressions import Count count_aggregation = CountAggregation(alias=in_alias) got = count_aggregation._to_pipeline_expr(iter([1])) - assert isinstance(got, ExprWithAlias) + assert isinstance(got, AliasedAggregate) assert got.alias == expected_alias assert isinstance(got.expr, Count) assert len(got.expr.params) == 0 @@ -143,12 +143,12 @@ def test_count_aggregation_to_pipeline_expr(in_alias, expected_alias): [("total", "path", "total"), (None, "some_ref", "field_1")], ) def test_sum_aggregation_to_pipeline_expr(in_alias, expected_path, expected_alias): - from google.cloud.firestore_v1.pipeline_expressions import ExprWithAlias + from google.cloud.firestore_v1.pipeline_expressions import AliasedAggregate from google.cloud.firestore_v1.pipeline_expressions import Sum count_aggregation = SumAggregation(expected_path, alias=in_alias) got = count_aggregation._to_pipeline_expr(iter([1])) - assert isinstance(got, ExprWithAlias) + assert isinstance(got, AliasedAggregate) assert got.alias == expected_alias assert isinstance(got.expr, Sum) assert got.expr.params[0].path == expected_path @@ -159,14 +159,14 @@ def test_sum_aggregation_to_pipeline_expr(in_alias, expected_path, expected_alia [("total", "path", "total"), (None, "some_ref", "field_1")], ) def test_avg_aggregation_to_pipeline_expr(in_alias, expected_path, expected_alias): - from google.cloud.firestore_v1.pipeline_expressions import ExprWithAlias - from google.cloud.firestore_v1.pipeline_expressions import Avg + from google.cloud.firestore_v1.pipeline_expressions import AliasedAggregate + from google.cloud.firestore_v1.pipeline_expressions import Average count_aggregation = AvgAggregation(expected_path, alias=in_alias) got = count_aggregation._to_pipeline_expr(iter([1])) - assert isinstance(got, ExprWithAlias) + assert isinstance(got, AliasedAggregate) assert got.alias == expected_alias - assert isinstance(got.expr, Avg) + assert isinstance(got.expr, Average) assert got.expr.params[0].path == expected_path @@ -1068,7 +1068,7 @@ def test_aggreation_to_pipeline_sum(field, in_alias, out_alias): def test_aggreation_to_pipeline_avg(field, in_alias, out_alias): from google.cloud.firestore_v1.pipeline import Pipeline from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate - from google.cloud.firestore_v1.pipeline_expressions import Avg + from google.cloud.firestore_v1.pipeline_expressions import Average client = make_client() parent = client.collection("dee") @@ -1083,7 +1083,7 @@ def test_aggreation_to_pipeline_avg(field, in_alias, out_alias): aggregate_stage = pipeline.stages[1] assert isinstance(aggregate_stage, Aggregate) assert len(aggregate_stage.accumulators) == 1 - assert isinstance(aggregate_stage.accumulators[0].expr, Avg) + assert isinstance(aggregate_stage.accumulators[0].expr, Average) expected_field = field if isinstance(field, str) else field.to_api_repr() assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field assert aggregate_stage.accumulators[0].alias == out_alias @@ -1136,13 +1136,13 @@ def test_aggreation_to_pipeline_count_increment(): assert len(aggregate_stage.accumulators) == n for i in range(n): assert isinstance(aggregate_stage.accumulators[i].expr, Count) - assert aggregate_stage.accumulators[i].alias == f"field_{i+1}" + assert aggregate_stage.accumulators[i].alias == f"field_{i + 1}" def test_aggreation_to_pipeline_complex(): from google.cloud.firestore_v1.pipeline import Pipeline from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate, Select - from google.cloud.firestore_v1.pipeline_expressions import Sum, Avg, Count + from google.cloud.firestore_v1.pipeline_expressions import Sum, Average, Count client = make_client() query = client.collection("my_col").select(["field_a", "field_b.c"]) @@ -1163,7 +1163,7 @@ def test_aggreation_to_pipeline_complex(): assert aggregate_stage.accumulators[0].alias == "alias" assert isinstance(aggregate_stage.accumulators[1].expr, Count) assert aggregate_stage.accumulators[1].alias == "field_1" - assert isinstance(aggregate_stage.accumulators[2].expr, Avg) + assert isinstance(aggregate_stage.accumulators[2].expr, Average) assert aggregate_stage.accumulators[2].alias == "field_2" assert isinstance(aggregate_stage.accumulators[3].expr, Sum) assert aggregate_stage.accumulators[3].alias == "field_3" diff --git a/tests/unit/v1/test_async_aggregation.py b/tests/unit/v1/test_async_aggregation.py index fdd4a1450..c69d44dd8 100644 --- a/tests/unit/v1/test_async_aggregation.py +++ b/tests/unit/v1/test_async_aggregation.py @@ -742,7 +742,7 @@ def test_async_aggreation_to_pipeline_sum(field, in_alias, out_alias): def test_async_aggreation_to_pipeline_avg(field, in_alias, out_alias): from google.cloud.firestore_v1.async_pipeline import AsyncPipeline from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate - from google.cloud.firestore_v1.pipeline_expressions import Avg + from google.cloud.firestore_v1.pipeline_expressions import Average client = make_async_client() parent = client.collection("dee") @@ -757,7 +757,7 @@ def test_async_aggreation_to_pipeline_avg(field, in_alias, out_alias): aggregate_stage = pipeline.stages[1] assert isinstance(aggregate_stage, Aggregate) assert len(aggregate_stage.accumulators) == 1 - assert isinstance(aggregate_stage.accumulators[0].expr, Avg) + assert isinstance(aggregate_stage.accumulators[0].expr, Average) expected_field = field if isinstance(field, str) else field.to_api_repr() assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field assert aggregate_stage.accumulators[0].alias == out_alias @@ -810,13 +810,13 @@ def test_aggreation_to_pipeline_count_increment(): assert len(aggregate_stage.accumulators) == n for i in range(n): assert isinstance(aggregate_stage.accumulators[i].expr, Count) - assert aggregate_stage.accumulators[i].alias == f"field_{i+1}" + assert aggregate_stage.accumulators[i].alias == f"field_{i + 1}" def test_async_aggreation_to_pipeline_complex(): from google.cloud.firestore_v1.async_pipeline import AsyncPipeline from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate, Select - from google.cloud.firestore_v1.pipeline_expressions import Sum, Avg, Count + from google.cloud.firestore_v1.pipeline_expressions import Sum, Average, Count client = make_async_client() query = client.collection("my_col").select(["field_a", "field_b.c"]) @@ -837,7 +837,7 @@ def test_async_aggreation_to_pipeline_complex(): assert aggregate_stage.accumulators[0].alias == "alias" assert isinstance(aggregate_stage.accumulators[1].expr, Count) assert aggregate_stage.accumulators[1].alias == "field_1" - assert isinstance(aggregate_stage.accumulators[2].expr, Avg) + assert isinstance(aggregate_stage.accumulators[2].expr, Average) assert aggregate_stage.accumulators[2].alias == "field_2" assert isinstance(aggregate_stage.accumulators[3].expr, Sum) assert aggregate_stage.accumulators[3].alias == "field_3" diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index 9bb3e61f8..c13efbfa8 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -2040,9 +2040,7 @@ def test__query_pipeline_composite_filter(): client = make_client() in_filter = FieldFilter("field_a", "==", "value_a") query = client.collection("my_col").where(filter=in_filter) - with mock.patch.object( - expr.FilterCondition, "_from_query_filter_pb" - ) as convert_mock: + with mock.patch.object(expr.BooleanExpr, "_from_query_filter_pb") as convert_mock: pipeline = query.pipeline() convert_mock.assert_called_once_with(in_filter._to_pb(), client) assert len(pipeline.stages) == 2 diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py index b237ad5ac..ebd3798c2 100644 --- a/tests/unit/v1/test_pipeline.py +++ b/tests/unit/v1/test_pipeline.py @@ -372,6 +372,8 @@ def test_pipeline_execute_stream_equivalence_mocked(): ), ("sort", (Field.of("n").descending(),), stages.Sort), ("sort", (Field.of("n").descending(), Field.of("m").ascending()), stages.Sort), + ("replace", (Field.of("n"),), stages.Replace), + ("replace", (Field.of("n"), stages.Replace.Mode.FULL_REPLACE), stages.Replace), ("sample", (10,), stages.Sample), ("sample", (stages.SampleOptions.doc_limit(10),), stages.Sample), ("union", (_make_pipeline(),), stages.Union), diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 936c0a0a9..95bee4941 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -22,7 +22,7 @@ from google.cloud.firestore_v1.types.document import Value from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1._helpers import GeoPoint -from google.cloud.firestore_v1.pipeline_expressions import FilterCondition, ListOfExprs +from google.cloud.firestore_v1.pipeline_expressions import BooleanExpr, ListOfExprs import google.cloud.firestore_v1.pipeline_expressions as expr @@ -90,16 +90,27 @@ def test_ctor(self): ("multiply", (2,), expr.Multiply), ("divide", (2,), expr.Divide), ("mod", (2,), expr.Mod), - ("logical_max", (2,), expr.LogicalMax), - ("logical_min", (2,), expr.LogicalMin), - ("eq", (2,), expr.Eq), - ("neq", (2,), expr.Neq), - ("lt", (2,), expr.Lt), - ("lte", (2,), expr.Lte), - ("gt", (2,), expr.Gt), - ("gte", (2,), expr.Gte), - ("in_any", ([None],), expr.In), - ("not_in_any", ([None],), expr.Not), + ("abs", (), expr.Abs), + ("ceil", (), expr.Ceil), + ("exp", (), expr.Exp), + ("floor", (), expr.Floor), + ("ln", (), expr.Ln), + ("log", (10,), expr.Log), + ("pow", (2,), expr.Pow), + ("round", (), expr.Round), + ("sqrt", (), expr.Sqrt), + ("logical_maximum", (2,), expr.LogicalMaximum), + ("logical_minimum", (2,), expr.LogicalMinimum), + ("equal", (2,), expr.Equal), + ("not_equal", (2,), expr.NotEqual), + ("less_than", (2,), expr.LessThan), + ("less_than_or_equal", (2,), expr.LessThanOrEqual), + ("greater_than", (2,), expr.GreaterThan), + ("greater_than_or_equal", (2,), expr.GreaterThanOrEqual), + ("equal_any", ([None],), expr.EqualAny), + ("not_equal_any", ([None],), expr.NotEqualAny), + ("array_get", (1,), expr.ArrayGet), + ("array_concat", ([None],), expr.ArrayConcat), ("array_contains", (None,), expr.ArrayContains), ("array_contains_all", ([None],), expr.ArrayContainsAll), ("array_contains_any", ([None],), expr.ArrayContainsAny), @@ -108,20 +119,29 @@ def test_ctor(self): ("is_nan", (), expr.IsNaN), ("exists", (), expr.Exists), ("sum", (), expr.Sum), - ("avg", (), expr.Avg), + ("average", (), expr.Average), ("count", (), expr.Count), - ("min", (), expr.Min), - ("max", (), expr.Max), + ("minimum", (), expr.Minimum), + ("maximum", (), expr.Maximum), ("char_length", (), expr.CharLength), ("byte_length", (), expr.ByteLength), ("like", ("pattern",), expr.Like), ("regex_contains", ("regex",), expr.RegexContains), ("regex_matches", ("regex",), expr.RegexMatch), - ("str_contains", ("substring",), expr.StrContains), + ("string_contains", ("substring",), expr.StringContains), ("starts_with", ("prefix",), expr.StartsWith), ("ends_with", ("postfix",), expr.EndsWith), - ("str_concat", ("elem1", expr.Constant("elem2")), expr.StrConcat), + ("string_concat", ("elem1", expr.Constant("elem2")), expr.StringConcat), + ("to_lower", (), expr.ToLower), + ("to_upper", (), expr.ToUpper), + ("trim", (), expr.Trim), + ("reverse", (), expr.Reverse), ("map_get", ("key",), expr.MapGet), + ("map_remove", ("key",), expr.MapRemove), + ("map_merge", ({"key": "value"}, ), expr.MapMerge), + ("cosine_distance", [1], expr.CosineDistance), + ("euclidean_distance", [1], expr.EuclideanDistance), + ("dot_product", [1], expr.DotProduct), ("vector_length", (), expr.VectorLength), ("timestamp_to_unix_micros", (), expr.TimestampToUnixMicros), ("unix_micros_to_timestamp", (), expr.UnixMicrosToTimestamp), @@ -130,10 +150,10 @@ def test_ctor(self): ("timestamp_to_unix_seconds", (), expr.TimestampToUnixSeconds), ("unix_seconds_to_timestamp", (), expr.UnixSecondsToTimestamp), ("timestamp_add", ("day", 1), expr.TimestampAdd), - ("timestamp_sub", ("hour", 2.5), expr.TimestampSub), + ("timestamp_subtract", ("hour", 2.5), expr.TimestampSubtract), ("ascending", (), expr.Ordering), ("descending", (), expr.Ordering), - ("as_", ("alias",), expr.ExprWithAlias), + ("as_", ("alias",), expr.AliasedExpr), ], ) @pytest.mark.parametrize( @@ -147,13 +167,15 @@ def test_ctor(self): ) def test_infix_call(self, method, args, result_cls, base_instance): """ - many FilterCondition expressions support infix execution, and are exposed as methods on Expr. Test calling them + many BooleanExpr expressions support infix execution, and are exposed as methods on Expr. Test calling them """ method_ptr = getattr(base_instance, method) result = method_ptr(*args) assert isinstance(result, result_cls) - if isinstance(result, expr.Function) and not method == "not_in_any": + if isinstance(result, (expr.Ordering, expr.AliasedExpr)): + assert result.expr == base_instance + else: assert result.params[0] == base_instance @@ -361,7 +383,7 @@ def test_to_map(self): assert result[0] == "field1" assert result[1] == Value(field_reference_value="field1") - class TestExprWithAlias: + class TestAliasedExpr: def test_repr(self): instance = expr.Field.of("field1").as_("alias1") assert repr(instance) == "Field.of('field1').as_('alias1')" @@ -369,14 +391,14 @@ def test_repr(self): def test_ctor(self): arg = expr.Field.of("field1") alias = "alias1" - instance = expr.ExprWithAlias(arg, alias) + instance = expr.AliasedExpr(arg, alias) assert instance.expr == arg assert instance.alias == alias def test_to_pb(self): arg = expr.Field.of("field1") alias = "alias1" - instance = expr.ExprWithAlias(arg, alias) + instance = expr.AliasedExpr(arg, alias) result = instance._to_pb() assert result.map_value.fields.get("alias1") == arg._to_pb() @@ -386,8 +408,36 @@ def test_to_map(self): assert result[0] == "alias1" assert result[1] == Value(field_reference_value="field1") + class TestAliasedAggregate: -class TestFilterCondition: + def test_repr(self): + instance = expr.Field.of("field1").maximum().as_("alias1") + assert repr(instance) == "Maximum(Field.of('field1')).as_('alias1')" + + def test_ctor(self): + arg = expr.Field.of("field1").minimum() + alias = "alias1" + instance = expr.AliasedAggregate(arg, alias) + assert instance.expr == arg + assert instance.alias == alias + + def test_to_pb(self): + arg = expr.Field.of("field1").average() + alias = "alias1" + instance = expr.AliasedAggregate(arg, alias) + result = instance._to_pb() + assert result.map_value.fields.get("alias1") == arg._to_pb() + + def test_to_map(self): + arg = expr.Field.of("field1").count() + alias = "alias1" + instance = expr.AliasedAggregate(arg, alias) + result = instance._to_map() + assert result[0] == "alias1" + assert result[1] == arg._to_pb() + + +class TestBooleanExpr: def test__from_query_filter_pb_composite_filter_or(self, mock_client): """ test composite OR filters @@ -415,16 +465,16 @@ def test__from_query_filter_pb_composite_filter_or(self, mock_client): composite_filter=composite_pb ) - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) # should include existance checks expected_cond1 = expr.And( expr.Exists(expr.Field.of("field1")), - expr.Eq(expr.Field.of("field1"), expr.Constant("val1")), + expr.Equal(expr.Field.of("field1"), expr.Constant("val1")), ) expected_cond2 = expr.And( expr.Exists(expr.Field.of("field2")), - expr.Eq(expr.Field.of("field2"), expr.Constant(None)), + expr.Equal(expr.Field.of("field2"), expr.Constant(None)), ) expected = expr.Or(expected_cond1, expected_cond2) @@ -458,16 +508,16 @@ def test__from_query_filter_pb_composite_filter_and(self, mock_client): composite_filter=composite_pb ) - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) # should include existance checks expected_cond1 = expr.And( expr.Exists(expr.Field.of("field1")), - expr.Gt(expr.Field.of("field1"), expr.Constant(100)), + expr.GreaterThan(expr.Field.of("field1"), expr.Constant(100)), ) expected_cond2 = expr.And( expr.Exists(expr.Field.of("field2")), - expr.Lt(expr.Field.of("field2"), expr.Constant(200)), + expr.LessThan(expr.Field.of("field2"), expr.Constant(200)), ) expected = expr.And(expected_cond1, expected_cond2) assert repr(result) == repr(expected) @@ -509,19 +559,19 @@ def test__from_query_filter_pb_composite_filter_nested(self, mock_client): composite_filter=outer_or_pb ) - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) expected_cond1 = expr.And( expr.Exists(expr.Field.of("field1")), - expr.Eq(expr.Field.of("field1"), expr.Constant("val1")), + expr.Equal(expr.Field.of("field1"), expr.Constant("val1")), ) expected_cond2 = expr.And( expr.Exists(expr.Field.of("field2")), - expr.Gt(expr.Field.of("field2"), expr.Constant(10)), + expr.GreaterThan(expr.Field.of("field2"), expr.Constant(10)), ) expected_cond3 = expr.And( expr.Exists(expr.Field.of("field3")), - expr.Not(expr.Eq(expr.Field.of("field3"), expr.Constant(None))), + expr.Not(expr.Equal(expr.Field.of("field3"), expr.Constant(None))), ) expected_inner_and = expr.And(expected_cond2, expected_cond3) expected_outer_or = expr.Or(expected_cond1, expected_inner_and) @@ -546,7 +596,7 @@ def test__from_query_filter_pb_composite_filter_unknown_op(self, mock_client): ) with pytest.raises(TypeError, match="Unexpected CompositeFilter operator type"): - FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) @pytest.mark.parametrize( "op_enum, expected_expr_func", @@ -558,11 +608,11 @@ def test__from_query_filter_pb_composite_filter_unknown_op(self, mock_client): ), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NULL, - lambda f: f.eq(None), + lambda f: f.equal(None), ), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL, - lambda f: expr.Not(f.eq(None)), + lambda f: expr.Not(f.equal(None)), ), ], ) @@ -579,7 +629,7 @@ def test__from_query_filter_pb_unary_filter( ) wrapped_filter_pb = query_pb.StructuredQuery.Filter(unary_filter=filter_pb) - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) field_expr_inst = expr.Field.of(field_path) expected_condition = expected_expr_func(field_expr_inst) @@ -600,25 +650,37 @@ def test__from_query_filter_pb_unary_filter_unknown_op(self, mock_client): wrapped_filter_pb = query_pb.StructuredQuery.Filter(unary_filter=filter_pb) with pytest.raises(TypeError, match="Unexpected UnaryFilter operator type"): - FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) @pytest.mark.parametrize( "op_enum, value, expected_expr_func", [ - (query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN, 10, expr.Lt), + ( + query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN, + 10, + expr.LessThan, + ), ( query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN_OR_EQUAL, 10, - expr.Lte, + expr.LessThanOrEqual, + ), + ( + query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, + 10, + expr.GreaterThan, ), - (query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, 10, expr.Gt), ( query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN_OR_EQUAL, 10, - expr.Gte, + expr.GreaterThanOrEqual, + ), + (query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, 10, expr.Equal), + ( + query_pb.StructuredQuery.FieldFilter.Operator.NOT_EQUAL, + 10, + expr.NotEqual, ), - (query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, 10, expr.Eq), - (query_pb.StructuredQuery.FieldFilter.Operator.NOT_EQUAL, 10, expr.Neq), ( query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS, 10, @@ -629,12 +691,8 @@ def test__from_query_filter_pb_unary_filter_unknown_op(self, mock_client): [10, 20], expr.ArrayContainsAny, ), - (query_pb.StructuredQuery.FieldFilter.Operator.IN, [10, 20], expr.In), - ( - query_pb.StructuredQuery.FieldFilter.Operator.NOT_IN, - [10, 20], - lambda f, v: expr.Not(f.in_any(v)), - ), + (query_pb.StructuredQuery.FieldFilter.Operator.IN, [10, 20], expr.EqualAny), + (query_pb.StructuredQuery.FieldFilter.Operator.NOT_IN, [10, 20], expr.NotEqualAny), ], ) def test__from_query_filter_pb_field_filter( @@ -652,7 +710,7 @@ def test__from_query_filter_pb_field_filter( ) wrapped_filter_pb = query_pb.StructuredQuery.Filter(field_filter=filter_pb) - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) field_expr = expr.Field.of(field_path) # convert values into constants @@ -681,7 +739,7 @@ def test__from_query_filter_pb_field_filter_unknown_op(self, mock_client): wrapped_filter_pb = query_pb.StructuredQuery.Filter(field_filter=filter_pb) with pytest.raises(TypeError, match="Unexpected FieldFilter operator type"): - FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) def test__from_query_filter_pb_unknown_filter_type(self, mock_client): """ @@ -689,12 +747,12 @@ def test__from_query_filter_pb_unknown_filter_type(self, mock_client): """ # Test with an unexpected protobuf type with pytest.raises(TypeError, match="Unexpected filter type"): - FilterCondition._from_query_filter_pb(document_pb.Value(), mock_client) + BooleanExpr._from_query_filter_pb(document_pb.Value(), mock_client) -class TestFilterConditionClasses: +class TestBooleanExprClasses: """ - contains test methods for each Expr class that derives from FilterCondition + contains test methods for each Expr class that derives from BooleanExpr """ def _make_arg(self, name="Mock"): @@ -747,64 +805,75 @@ def test_exists(self): assert instance.params == [arg1] assert repr(instance) == "Field.exists()" - def test_eq(self): + def test_equal(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Eq(arg1, arg2) - assert instance.name == "eq" + instance = expr.Equal(arg1, arg2) + assert instance.name == "equal" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.eq(Right)" + assert repr(instance) == "Left.equal(Right)" - def test_gte(self): + def test_greater_than_or_equal(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Gte(arg1, arg2) - assert instance.name == "gte" + instance = expr.GreaterThanOrEqual(arg1, arg2) + assert instance.name == "greater_than_or_equal" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.gte(Right)" + assert repr(instance) == "Left.greater_than_or_equal(Right)" - def test_gt(self): + def test_greater_than(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Gt(arg1, arg2) - assert instance.name == "gt" + instance = expr.GreaterThan(arg1, arg2) + assert instance.name == "greater_than" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.gt(Right)" + assert repr(instance) == "Left.greater_than(Right)" - def test_lte(self): + def test_less_than_or_equal(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Lte(arg1, arg2) - assert instance.name == "lte" + instance = expr.LessThanOrEqual(arg1, arg2) + assert instance.name == "less_than_or_equal" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.lte(Right)" + assert repr(instance) == "Left.less_than_or_equal(Right)" - def test_lt(self): + def test_less_than(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Lt(arg1, arg2) - assert instance.name == "lt" + instance = expr.LessThan(arg1, arg2) + assert instance.name == "less_than" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.lt(Right)" + assert repr(instance) == "Left.less_than(Right)" - def test_neq(self): + def test_not_equal(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Neq(arg1, arg2) - assert instance.name == "neq" + instance = expr.NotEqual(arg1, arg2) + assert instance.name == "not_equal" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.neq(Right)" + assert repr(instance) == "Left.not_equal(Right)" + + def test_equal_any(self): + arg1 = self._make_arg("Field") + arg2 = self._make_arg("Value1") + arg3 = self._make_arg("Value2") + instance = expr.EqualAny(arg1, [arg2, arg3]) + assert instance.name == "equal_any" + assert isinstance(instance.params[1], ListOfExprs) + assert instance.params[0] == arg1 + assert instance.params[1].exprs == [arg2, arg3] + assert repr(instance) == "Field.equal_any(ListOfExprs([Value1, Value2]))" - def test_in(self): + def test_not_equal_any(self): arg1 = self._make_arg("Field") arg2 = self._make_arg("Value1") arg3 = self._make_arg("Value2") - instance = expr.In(arg1, [arg2, arg3]) - assert instance.name == "in" + instance = expr.NotEqualAny(arg1, [arg2, arg3]) + assert instance.name == "not_equal_any" assert isinstance(instance.params[1], ListOfExprs) assert instance.params[0] == arg1 assert instance.params[1].exprs == [arg2, arg3] - assert repr(instance) == "Field.in_any(ListOfExprs([Value1, Value2]))" + assert repr(instance) == "Field.not_equal_any(ListOfExprs([Value1, Value2]))" def test_is_nan(self): arg1 = self._make_arg("Value") @@ -842,14 +911,14 @@ def test_ends_with(self): assert instance.params == [arg1, arg2] assert repr(instance) == "Expr.ends_with(Postfix)" - def test_if(self): + def test_conditional(self): arg1 = self._make_arg("Condition") - arg2 = self._make_arg("TrueExpr") - arg3 = self._make_arg("FalseExpr") - instance = expr.If(arg1, arg2, arg3) - assert instance.name == "if" + arg2 = self._make_arg("ThenExpr") + arg3 = self._make_arg("ElseExpr") + instance = expr.Conditional(arg1, arg2, arg3) + assert instance.name == "conditional" assert instance.params == [arg1, arg2, arg3] - assert repr(instance) == "If(Condition, TrueExpr, FalseExpr)" + assert repr(instance) == "Conditional(Condition, ThenExpr, ElseExpr)" def test_like(self): arg1 = self._make_arg("Expr") @@ -883,13 +952,13 @@ def test_starts_with(self): assert instance.params == [arg1, arg2] assert repr(instance) == "Expr.starts_with(Prefix)" - def test_str_contains(self): + def test_string_contains(self): arg1 = self._make_arg("Expr") arg2 = self._make_arg("Substring") - instance = expr.StrContains(arg1, arg2) - assert instance.name == "str_contains" + instance = expr.StringContains(arg1, arg2) + assert instance.name == "string_contains" assert instance.params == [arg1, arg2] - assert repr(instance) == "Expr.str_contains(Substring)" + assert repr(instance) == "Expr.string_contains(Substring)" def test_xor(self): arg1 = self._make_arg("Condition1") @@ -908,21 +977,34 @@ class TestFunctionClasses: @pytest.mark.parametrize( "method,args,result_cls", [ + ("conditional", ("field", "then", "else"), expr.Conditional), ("add", ("field", 2), expr.Add), ("subtract", ("field", 2), expr.Subtract), ("multiply", ("field", 2), expr.Multiply), ("divide", ("field", 2), expr.Divide), ("mod", ("field", 2), expr.Mod), - ("logical_max", ("field", 2), expr.LogicalMax), - ("logical_min", ("field", 2), expr.LogicalMin), - ("eq", ("field", 2), expr.Eq), - ("neq", ("field", 2), expr.Neq), - ("lt", ("field", 2), expr.Lt), - ("lte", ("field", 2), expr.Lte), - ("gt", ("field", 2), expr.Gt), - ("gte", ("field", 2), expr.Gte), - ("in_any", ("field", [None]), expr.In), - ("not_in_any", ("field", [None]), expr.Not), + ("abs", ("field",), expr.Abs), + ("ceil", ("field",), expr.Ceil), + ("exp", ("field",), expr.Exp), + ("floor", ("field",), expr.Floor), + ("ln", ("field",), expr.Ln), + ("log", ("field", 10), expr.Log), + ("pow", ("field", 2), expr.Pow), + ("round", ("field",), expr.Round), + ("sqrt", ("field",), expr.Sqrt), + ("logical_maximum", ("field", 2), expr.LogicalMaximum), + ("logical_minimum", ("field", 2), expr.LogicalMinimum), + ("equal", ("field", 2), expr.Equal), + ("not_equal", ("field", 2), expr.NotEqual), + ("less_than", ("field", 2), expr.LessThan), + ("less_than_or_equal", ("field", 2), expr.LessThanOrEqual), + ("greater_than", ("field", 2), expr.GreaterThan), + ("greater_than_or_equal", ("field", 2), expr.GreaterThanOrEqual), + ("equal_any", ("field", [None]), expr.EqualAny), + ("not_equal_any", ("field", [None]), expr.NotEqualAny), + ("array", ([1, 2, 3],), expr.Array), + ("map", ({"hello": "world"},), expr.Map), + ("array_get", ("field", 2), expr.ArrayGet), ("array_contains", ("field", None), expr.ArrayContains), ("array_contains_all", ("field", [None]), expr.ArrayContainsAll), ("array_contains_any", ("field", [None]), expr.ArrayContainsAny), @@ -931,21 +1013,22 @@ class TestFunctionClasses: ("is_nan", ("field",), expr.IsNaN), ("exists", ("field",), expr.Exists), ("sum", ("field",), expr.Sum), - ("avg", ("field",), expr.Avg), + ("average", ("field",), expr.Average), ("count", ("field",), expr.Count), - ("count", (), expr.Count), - ("min", ("field",), expr.Min), - ("max", ("field",), expr.Max), + ("minimum", ("field",), expr.Minimum), + ("maximum", ("field",), expr.Maximum), ("char_length", ("field",), expr.CharLength), ("byte_length", ("field",), expr.ByteLength), ("like", ("field", "pattern"), expr.Like), ("regex_contains", ("field", "regex"), expr.RegexContains), ("regex_matches", ("field", "regex"), expr.RegexMatch), - ("str_contains", ("field", "substring"), expr.StrContains), + ("string_contains", ("field", "substring"), expr.StringContains), ("starts_with", ("field", "prefix"), expr.StartsWith), ("ends_with", ("field", "postfix"), expr.EndsWith), - ("str_concat", ("field", "elem1", "elem2"), expr.StrConcat), + ("string_concat", ("field", "elem1", "elem2"), expr.StringConcat), ("map_get", ("field", "key"), expr.MapGet), + ("map_remove", ("field", "key"), expr.MapRemove), + ("map_merge", ("field", {"key": "value"}), expr.MapMerge), ("vector_length", ("field",), expr.VectorLength), ("timestamp_to_unix_micros", ("field",), expr.TimestampToUnixMicros), ("unix_micros_to_timestamp", ("field",), expr.UnixMicrosToTimestamp), @@ -954,7 +1037,7 @@ class TestFunctionClasses: ("timestamp_to_unix_seconds", ("field",), expr.TimestampToUnixSeconds), ("unix_seconds_to_timestamp", ("field",), expr.UnixSecondsToTimestamp), ("timestamp_add", ("field", "day", 1), expr.TimestampAdd), - ("timestamp_sub", ("field", "hour", 2.5), expr.TimestampSub), + ("timestamp_subtract", ("field", "hour", 2.5), expr.TimestampSubtract), ], ) def test_function_builder(self, method, args, result_cls): @@ -969,14 +1052,18 @@ def test_function_builder(self, method, args, result_cls): @pytest.mark.parametrize( "first,second,expected", [ - (expr.ArrayElement(), expr.ArrayElement(), True), - (expr.ArrayElement(), expr.CharLength(1), False), - (expr.ArrayElement(), object(), False), - (expr.ArrayElement(), None, False), - (expr.CharLength(1), expr.ArrayElement(), False), + (expr.Array([]), expr.Array([]), True), + (expr.Array([]), expr.CharLength(1), False), + (expr.Array([]), object(), False), + (expr.Array([]), None, False), + (expr.CharLength(1), expr.Array([]), False), (expr.CharLength(1), expr.CharLength(2), False), (expr.CharLength(1), expr.CharLength(1), True), (expr.CharLength(1), expr.ByteLength(1), False), + (expr.Array([1]), expr.Array([1]), True), + (expr.Array([1]), expr.Array([2]), False), + (expr.Array([1]), expr.Array([]), False), + (expr.Array([1, 2]), expr.Array([1]), False), ], ) def test_equality(self, first, second, expected): @@ -995,21 +1082,29 @@ def test_divide(self): assert instance.params == [arg1, arg2] assert repr(instance) == "Divide(Left, Right)" - def test_logical_max(self): + def test_logical_maximum(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.LogicalMax(arg1, arg2) - assert instance.name == "logical_maximum" + instance = expr.LogicalMaximum(arg1, arg2) + assert instance.name == "max" assert instance.params == [arg1, arg2] - assert repr(instance) == "LogicalMax(Left, Right)" + assert repr(instance) == "LogicalMaximum(Left, Right)" - def test_logical_min(self): + def test_logical_minimum(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.LogicalMin(arg1, arg2) - assert instance.name == "logical_minimum" + instance = expr.LogicalMinimum(arg1, arg2) + assert instance.name == "min" assert instance.params == [arg1, arg2] - assert repr(instance) == "LogicalMin(Left, Right)" + assert repr(instance) == "LogicalMinimum(Left, Right)" + + def test_map(self): + key = expr.Constant.of("key") + value = self._make_arg("value") + instance = expr.Map({"key": value}) + assert instance.name == "map" + assert instance.params == [key, value] + assert repr(instance) == "Map({'key': value})" def test_map_get(self): arg1 = self._make_arg("Map") @@ -1019,6 +1114,23 @@ def test_map_get(self): assert instance.params == [arg1, arg2] assert repr(instance) == "MapGet(Map, Constant.of('Key'))" + def test_map_remove(self): + arg1 = self._make_arg("Map") + arg2 = expr.Constant("Key") + instance = expr.MapRemove(arg1, arg2) + assert instance.name == "map_remove" + assert instance.params == [arg1, arg2] + assert repr(instance) == "MapRemove(Map, Constant.of('Key'))" + + def test_map_merge(self): + arg1 = self._make_arg("Map1") + arg2 = self._make_arg("Map2") + arg3 = self._make_arg("Map3") + instance = expr.MapMerge(arg1, arg2, arg3) + assert instance.name == "map_merge" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "MapMerge(Map1, Map2, Map3)" + def test_mod(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") @@ -1042,13 +1154,13 @@ def test_parent(self): assert instance.params == [arg1] assert repr(instance) == "Parent(Value)" - def test_str_concat(self): + def test_string_concat(self): arg1 = self._make_arg("Str1") arg2 = self._make_arg("Str2") - instance = expr.StrConcat(arg1, arg2) - assert instance.name == "str_concat" + instance = expr.StringConcat(arg1, arg2) + assert instance.name == "string_concat" assert instance.params == [arg1, arg2] - assert repr(instance) == "StrConcat(Str1, Str2)" + assert repr(instance) == "StringConcat(Str1, Str2)" def test_subtract(self): arg1 = self._make_arg("Left") @@ -1067,14 +1179,14 @@ def test_timestamp_add(self): assert instance.params == [arg1, arg2, arg3] assert repr(instance) == "TimestampAdd(Timestamp, Unit, Amount)" - def test_timestamp_sub(self): + def test_timestamp_subtract(self): arg1 = self._make_arg("Timestamp") arg2 = self._make_arg("Unit") arg3 = self._make_arg("Amount") - instance = expr.TimestampSub(arg1, arg2, arg3) - assert instance.name == "timestamp_sub" + instance = expr.TimestampSubtract(arg1, arg2, arg3) + assert instance.name == "timestamp_subtract" assert instance.params == [arg1, arg2, arg3] - assert repr(instance) == "TimestampSub(Timestamp, Unit, Amount)" + assert repr(instance) == "TimestampSubtract(Timestamp, Unit, Amount)" def test_timestamp_to_unix_micros(self): arg1 = self._make_arg("Input") @@ -1133,20 +1245,6 @@ def test_add(self): assert instance.params == [arg1, arg2] assert repr(instance) == "Add(Left, Right)" - def test_array_element(self): - instance = expr.ArrayElement() - assert instance.name == "array_element" - assert instance.params == [] - assert repr(instance) == "ArrayElement()" - - def test_array_filter(self): - arg1 = self._make_arg("Array") - arg2 = self._make_arg("FilterCond") - instance = expr.ArrayFilter(arg1, arg2) - assert instance.name == "array_filter" - assert instance.params == [arg1, arg2] - assert repr(instance) == "ArrayFilter(Array, FilterCond)" - def test_array_length(self): arg1 = self._make_arg("Array") instance = expr.ArrayLength(arg1) @@ -1161,14 +1259,6 @@ def test_array_reverse(self): assert instance.params == [arg1] assert repr(instance) == "ArrayReverse(Array)" - def test_array_transform(self): - arg1 = self._make_arg("Array") - arg2 = self._make_arg("TransformFunc") - instance = expr.ArrayTransform(arg1, arg2) - assert instance.name == "array_transform" - assert instance.params == [arg1, arg2] - assert repr(instance) == "ArrayTransform(Array, TransformFunc)" - def test_byte_length(self): arg1 = self._make_arg("Expr") instance = expr.ByteLength(arg1) @@ -1197,12 +1287,12 @@ def test_sum(self): assert instance.params == [arg1] assert repr(instance) == "Sum(Value)" - def test_avg(self): + def test_average(self): arg1 = self._make_arg("Value") - instance = expr.Avg(arg1) - assert instance.name == "avg" + instance = expr.Average(arg1) + assert instance.name == "average" assert instance.params == [arg1] - assert repr(instance) == "Avg(Value)" + assert repr(instance) == "Average(Value)" def test_count(self): arg1 = self._make_arg("Value") @@ -1216,16 +1306,157 @@ def test_count_empty(self): assert instance.params == [] assert repr(instance) == "Count()" - def test_min(self): + def test_minimum(self): + arg1 = self._make_arg("Value") + instance = expr.Minimum(arg1) + assert instance.name == "min" + assert instance.params == [arg1] + assert repr(instance) == "Minimum(Value)" + + def test_maximum(self): + arg1 = self._make_arg("Value") + instance = expr.Maximum(arg1) + assert instance.name == "max" + assert instance.params == [arg1] + assert repr(instance) == "Maximum(Value)" + + def test_dot_product(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.DotProduct(arg1, arg2) + assert instance.name == "dot_product" + assert instance.params == [arg1, arg2] + assert repr(instance) == "DotProduct(Left, Right)" + + def test_euclidean_distance(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.EuclideanDistance(arg1, arg2) + assert instance.name == "euclidean_distance" + assert instance.params == [arg1, arg2] + assert repr(instance) == "EuclideanDistance(Left, Right)" + + def test_cosine_distance(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.CosineDistance(arg1, arg2) + assert instance.name == "cosine_distance" + assert instance.params == [arg1, arg2] + assert repr(instance) == "CosineDistance(Left, Right)" + + def test_reverse(self): + arg1 = self._make_arg("Expr") + instance = expr.Reverse(arg1) + assert instance.name == "reverse" + assert instance.params == [arg1] + assert repr(instance) == "Reverse(Expr)" + + def test_to_lower(self): + arg1 = self._make_arg("Expr") + instance = expr.ToLower(arg1) + assert instance.name == "to_lower" + assert instance.params == [arg1] + assert repr(instance) == "ToLower(Expr)" + + def test_to_upper(self): + arg1 = self._make_arg("Expr") + instance = expr.ToUpper(arg1) + assert instance.name == "to_upper" + assert instance.params == [arg1] + assert repr(instance) == "ToUpper(Expr)" + + def test_trim(self): + arg1 = self._make_arg("Expr") + instance = expr.Trim(arg1) + assert instance.name == "trim" + assert instance.params == [arg1] + assert repr(instance) == "Trim(Expr)" + + def test_array(self): + arg = self._make_arg("Value") + instance = expr.Array([1, 2, arg]) + assert instance.name == "array" + assert instance.params == [1, 2, arg] + assert repr(instance) == "Array([1, 2, Value])" + + def test_array_get(self): + arg1 = self._make_arg("Array") + arg2 = self._make_arg("Index") + instance = expr.ArrayGet(arg1, arg2) + assert instance.name == "array_get" + assert instance.params == [arg1, arg2] + assert repr(instance) == "ArrayGet(Array, Index)" + + def test_array_concat(self): + arg1 = self._make_arg("1") + arg2 = self._make_arg("2") + arg3 = self._make_arg("3") + instance = expr.ArrayConcat(arg1, [arg2, arg3]) + assert instance.name == "array_concat" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "ArrayConcat(1, 2, 3)" + + def test_abs(self): + arg1 = self._make_arg("Value") + instance = expr.Abs(arg1) + assert instance.name == "abs" + assert instance.params == [arg1] + assert repr(instance) == "Abs(Value)" + + def test_ceil(self): + arg1 = self._make_arg("Value") + instance = expr.Ceil(arg1) + assert instance.name == "ceil" + assert instance.params == [arg1] + assert repr(instance) == "Ceil(Value)" + + def test_exp(self): + arg1 = self._make_arg("Value") + instance = expr.Exp(arg1) + assert instance.name == "exp" + assert instance.params == [arg1] + assert repr(instance) == "Exp(Value)" + + def test_floor(self): + arg1 = self._make_arg("Value") + instance = expr.Floor(arg1) + assert instance.name == "floor" + assert instance.params == [arg1] + assert repr(instance) == "Floor(Value)" + + def test_ln(self): + arg1 = self._make_arg("Value") + instance = expr.Ln(arg1) + assert instance.name == "ln" + assert instance.params == [arg1] + assert repr(instance) == "Ln(Value)" + + def test_log(self): + arg1 = self._make_arg("Value") + arg2 = self._make_arg("Base") + instance = expr.Log(arg1, arg2) + assert instance.name == "log" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Log(Value, Base)" + + def test_pow(self): + arg1 = self._make_arg("Base") + arg2 = self._make_arg("Exponent") + instance = expr.Pow(arg1, arg2) + assert instance.name == "pow" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Pow(Base, Exponent)" + + def test_round(self): arg1 = self._make_arg("Value") - instance = expr.Min(arg1) - assert instance.name == "minimum" + instance = expr.Round(arg1) + assert instance.name == "round" assert instance.params == [arg1] - assert repr(instance) == "Min(Value)" + assert repr(instance) == "Round(Value)" - def test_max(self): + def test_sqrt(self): arg1 = self._make_arg("Value") - instance = expr.Max(arg1) - assert instance.name == "maximum" + instance = expr.Sqrt(arg1) + assert instance.name == "sqrt" assert instance.params == [arg1] - assert repr(instance) == "Max(Value)" + assert repr(instance) == "Sqrt(Value)" diff --git a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py index e67a4ca3a..941c4668a 100644 --- a/tests/unit/v1/test_pipeline_stages.py +++ b/tests/unit/v1/test_pipeline_stages.py @@ -80,7 +80,7 @@ def _make_one(self, *args, **kwargs): def test_ctor_positional(self): """test with only positional arguments""" sum_total = Sum(Field.of("total")).as_("sum_total") - avg_price = Field.of("price").avg().as_("avg_price") + avg_price = Field.of("price").average().as_("avg_price") instance = self._make_one(sum_total, avg_price) assert list(instance.accumulators) == [sum_total, avg_price] assert len(instance.groups) == 0 @@ -89,7 +89,7 @@ def test_ctor_positional(self): def test_ctor_keyword(self): """test with only keyword arguments""" sum_total = Sum(Field.of("total")).as_("sum_total") - avg_price = Field.of("price").avg().as_("avg_price") + avg_price = Field.of("price").average().as_("avg_price") group_category = Field.of("category") instance = self._make_one( accumulators=[avg_price, sum_total], groups=[group_category, "city"] @@ -104,7 +104,7 @@ def test_ctor_keyword(self): def test_ctor_combined(self): """test with a mix of arguments""" sum_total = Sum(Field.of("total")).as_("sum_total") - avg_price = Field.of("price").avg().as_("avg_price") + avg_price = Field.of("price").average().as_("avg_price") count = Count(Field.of("total")).as_("count") with pytest.raises(ValueError): self._make_one(sum_total, accumulators=[avg_price, count]) @@ -552,6 +552,56 @@ def test_to_pb(self): assert len(result.options) == 0 +class TestReplace: + def _make_one(self, *args, **kwargs): + return stages.Replace(*args, **kwargs) + + def test_ctor_default(self): + instance = self._make_one("field") + assert isinstance(instance.field, Field) + assert instance.field.path == "field" + # default mode is FULL_REPLACE + assert instance.mode == stages.Replace.Mode.FULL_REPLACE + + @pytest.mark.parametrize( + "mode_str,expected_mode", + [ + ("full_replace", stages.Replace.Mode.FULL_REPLACE), + ("merge_prefer_next", stages.Replace.Mode.MERGE_PREFER_NEXT), + ("merge_prefer_parent", stages.Replace.Mode.MERGE_PREFER_PARENT), + ], + ) + def test_ctor_str_mode(self, mode_str, expected_mode): + instance = self._make_one("field", mode_str) + assert instance.mode == expected_mode + assert ( + repr(instance) + == f"Replace(field=Field.of('field'), mode=Replace.Mode.{mode_str.upper()})" + ) + + def test_ctor_w_field(self): + field = Field.of("field") + instance = self._make_one(field) + assert isinstance(instance.field, Field) + assert instance.field == field + + def test_repr(self): + instance = self._make_one("field", stages.Replace.Mode.MERGE_PREFER_NEXT) + repr_str = repr(instance) + assert ( + repr_str + == "Replace(field=Field.of('field'), mode=Replace.Mode.MERGE_PREFER_NEXT)" + ) + + def test_to_pb(self): + instance = self._make_one("field", stages.Replace.Mode.MERGE_PREFER_NEXT) + result = instance._to_pb() + assert result.name == "replace" + assert len(result.args) == 2 + assert result.args[0].field_reference_value == "field" + assert result.args[1].string_value == "merge_prefer_next" + + class TestSample: class TestSampleOptions: def test_ctor_percent(self): @@ -790,19 +840,21 @@ def _make_one(self, *args, **kwargs): return stages.Where(*args, **kwargs) def test_repr(self): - condition = Field.of("age").gt(30) + condition = Field.of("age").greater_than(30) instance = self._make_one(condition) repr_str = repr(instance) - assert repr_str == "Where(condition=Field.of('age').gt(Constant.of(30)))" + assert ( + repr_str == "Where(condition=Field.of('age').greater_than(Constant.of(30)))" + ) def test_to_pb(self): - condition = Field.of("city").eq("SF") + condition = Field.of("city").equal("SF") instance = self._make_one(condition) result = instance._to_pb() assert result.name == "where" assert len(result.args) == 1 got_fn = result.args[0].function_value - assert got_fn.name == "eq" + assert got_fn.name == "equal" assert len(got_fn.args) == 2 assert got_fn.args[0].field_reference_value == "city" assert got_fn.args[1].string_value == "SF"