diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index f175b5394..6b017d88e 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -13,15 +13,20 @@ # limitations under the License. from __future__ import annotations -from typing import AsyncIterable, TYPE_CHECKING +from typing import TYPE_CHECKING from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.base_pipeline import _BasePipeline +from google.cloud.firestore_v1.pipeline_result import AsyncPipelineStream +from google.cloud.firestore_v1.pipeline_result import PipelineSnapshot +from google.cloud.firestore_v1.pipeline_result import PipelineResult if TYPE_CHECKING: # pragma: NO COVER import datetime from google.cloud.firestore_v1.async_client import AsyncClient - from google.cloud.firestore_v1.pipeline_result import PipelineResult from google.cloud.firestore_v1.async_transaction import AsyncTransaction + from google.cloud.firestore_v1.pipeline_expressions import Constant + from google.cloud.firestore_v1.types.document import Value + from google.cloud.firestore_v1.query_profile import PipelineExplainOptions class AsyncPipeline(_BasePipeline): @@ -41,7 +46,7 @@ class AsyncPipeline(_BasePipeline): ... .collection("books") ... .where(Field.of("published").gt(1980)) ... .select("title", "author") - ... async for result in pipeline.execute(): + ... async for result in pipeline.stream(): ... print(result) Use `client.pipeline()` to create instances of this class. 
@@ -59,15 +64,18 @@ def __init__(self, client: AsyncClient, *stages: stages.Stage): async def execute( self, + *, transaction: "AsyncTransaction" | None = None, read_time: datetime.datetime | None = None, - ) -> list[PipelineResult]: + explain_options: PipelineExplainOptions | None = None, + index_mode: str | None = None, + additional_options: dict[str, Value | Constant] = {}, + ) -> PipelineSnapshot[PipelineResult]: """ Executes this pipeline and returns results as a list Args: - transaction - (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. If a ``transaction`` is used and it already has write operations added, this method cannot be used (i.e. read-after-write is not @@ -76,25 +84,33 @@ async def execute( time. This must be a microsecond precision timestamp within the past one hour, or if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp within the past 7 days. For the most accurate results, use UTC timezone. + explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.PipelineExplainOptions`]): + Options to enable query profiling for this query. When set, + explain_metrics will be available on the returned list. + index_mode (Optional[str]): Configures the pipeline to require a certain type of indexes to be present. + Firestore will reject the request if there are not appropriate indexes to serve the query. + additional_options (Optional[dict[str, Value | Constant]]): Additional options to pass to the query. + These options will take precedence over method arguments if there is a conflict (e.g. 
explain_options, index_mode) """ - return [ - result - async for result in self.stream( - transaction=transaction, read_time=read_time - ) - ] + kwargs = {k: v for k, v in locals().items() if k != "self"} + stream = AsyncPipelineStream(PipelineResult, self, **kwargs) + results = [result async for result in stream] + return PipelineSnapshot(results, stream) - async def stream( + def stream( self, - transaction: "AsyncTransaction" | None = None, + *, read_time: datetime.datetime | None = None, - ) -> AsyncIterable[PipelineResult]: + transaction: "AsyncTransaction" | None = None, + explain_options: PipelineExplainOptions | None = None, + index_mode: str | None = None, + additional_options: dict[str, Value | Constant] = {}, + ) -> AsyncPipelineStream[PipelineResult]: """ - Process this pipeline as a stream, providing results through an Iterable + Process this pipeline as a stream, providing results through an AsyncIterable Args: - transaction - (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. If a ``transaction`` is used and it already has write operations added, this method cannot be used (i.e. read-after-write is not @@ -103,10 +119,13 @@ async def stream( time. This must be a microsecond precision timestamp within the past one hour, or if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp within the past 7 days. For the most accurate results, use UTC timezone. + explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.PipelineExplainOptions`]): + Options to enable query profiling for this query. When set, + explain_metrics will be available on the returned generator. + index_mode (Optional[str]): Configures the pipeline to require a certain type of indexes to be present. 
+ Firestore will reject the request if there are not appropriate indexes to serve the query. + additional_options (Optional[dict[str, Value | Constant]]): Additional options to pass to the query. + These options will take precedence over method arguments if there is a conflict (e.g. explain_options, index_mode) """ - request = self._prep_execute_request(transaction, read_time) - async for response in await self._client._firestore_api.execute_pipeline( - request - ): - for result in self._execute_response_helper(response): - yield result + kwargs = {k: v for k, v in locals().items() if k != "self"} + return AsyncPipelineStream(PipelineResult, self, **kwargs) diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 7f52c2021..153564663 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -13,15 +13,13 @@ # limitations under the License. from __future__ import annotations -from typing import Iterable, Sequence, TYPE_CHECKING +from typing import Sequence, TYPE_CHECKING from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.types.pipeline import ( StructuredPipeline as StructuredPipeline_pb, ) from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1.base_vector_query import DistanceMeasure -from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest -from google.cloud.firestore_v1.pipeline_result import PipelineResult from google.cloud.firestore_v1.pipeline_expressions import ( AggregateFunction, AliasedExpression, @@ -30,14 +28,10 @@ BooleanExpression, Selectable, ) -from google.cloud.firestore_v1 import _helpers if TYPE_CHECKING: # pragma: NO COVER - import datetime from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.async_client import AsyncClient - from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse - from 
google.cloud.firestore_v1.transaction import BaseTransaction class _BasePipeline: @@ -88,9 +82,10 @@ def __repr__(self): stages_str = ",\n ".join([repr(s) for s in self.stages]) return f"{cls_str}(\n {stages_str}\n)" - def _to_pb(self) -> StructuredPipeline_pb: + def _to_pb(self, **options) -> StructuredPipeline_pb: return StructuredPipeline_pb( - pipeline={"stages": [s._to_pb() for s in self.stages]} + pipeline={"stages": [s._to_pb() for s in self.stages]}, + options=options, ) def _append(self, new_stage): @@ -99,47 +94,6 @@ def _append(self, new_stage): """ return self.__class__._create_with_stages(self._client, *self.stages, new_stage) - def _prep_execute_request( - self, - transaction: BaseTransaction | None, - read_time: datetime.datetime | None, - ) -> ExecutePipelineRequest: - """ - shared logic for creating an ExecutePipelineRequest - """ - database_name = ( - f"projects/{self._client.project}/databases/{self._client._database}" - ) - transaction_id = ( - _helpers.get_transaction_id(transaction) - if transaction is not None - else None - ) - request = ExecutePipelineRequest( - database=database_name, - transaction=transaction_id, - structured_pipeline=self._to_pb(), - read_time=read_time, - ) - return request - - def _execute_response_helper( - self, response: ExecutePipelineResponse - ) -> Iterable[PipelineResult]: - """ - shared logic for unpacking an ExecutePipelineReponse into PipelineResults - """ - for doc in response.results: - ref = self._client.document(doc.name) if doc.name else None - yield PipelineResult( - self._client, - doc.fields, - ref, - response._pb.execution_time, - doc._pb.create_time if doc.create_time else None, - doc._pb.update_time if doc.update_time else None, - ) - def add_fields(self, *fields: Selectable) -> "_BasePipeline": """ Adds new fields to outputs from previous stages. 
diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index b4567189b..950eb6ffa 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -13,15 +13,20 @@ # limitations under the License. from __future__ import annotations -from typing import Iterable, TYPE_CHECKING +from typing import TYPE_CHECKING from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.base_pipeline import _BasePipeline +from google.cloud.firestore_v1.pipeline_result import PipelineStream +from google.cloud.firestore_v1.pipeline_result import PipelineSnapshot +from google.cloud.firestore_v1.pipeline_result import PipelineResult if TYPE_CHECKING: # pragma: NO COVER import datetime from google.cloud.firestore_v1.client import Client - from google.cloud.firestore_v1.pipeline_result import PipelineResult + from google.cloud.firestore_v1.pipeline_expressions import Constant from google.cloud.firestore_v1.transaction import Transaction + from google.cloud.firestore_v1.types.document import Value + from google.cloud.firestore_v1.query_profile import PipelineExplainOptions class Pipeline(_BasePipeline): @@ -56,15 +61,18 @@ def __init__(self, client: Client, *stages: stages.Stage): def execute( self, + *, transaction: "Transaction" | None = None, read_time: datetime.datetime | None = None, - ) -> list[PipelineResult]: + explain_options: PipelineExplainOptions | None = None, + index_mode: str | None = None, + additional_options: dict[str, Value | Constant] = {}, + ) -> PipelineSnapshot[PipelineResult]: """ Executes this pipeline and returns results as a list Args: - transaction - (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. If a ``transaction`` is used and it already has write operations added, this method cannot be used (i.e. 
read-after-write is not @@ -73,23 +81,33 @@ def execute( time. This must be a microsecond precision timestamp within the past one hour, or if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp within the past 7 days. For the most accurate results, use UTC timezone. + explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.PipelineExplainOptions`]): + Options to enable query profiling for this query. When set, + explain_metrics will be available on the returned list. + index_mode (Optional[str]): Configures the pipeline to require a certain type of indexes to be present. + Firestore will reject the request if there are not appropriate indexes to serve the query. + additional_options (Optional[dict[str, Value | Constant]]): Additional options to pass to the query. + These options will take precedence over method arguments if there is a conflict (e.g. explain_options, index_mode) """ - return [ - result - for result in self.stream(transaction=transaction, read_time=read_time) - ] + kwargs = {k: v for k, v in locals().items() if k != "self"} + stream = PipelineStream(PipelineResult, self, **kwargs) + results = [result for result in stream] + return PipelineSnapshot(results, stream) def stream( self, + *, transaction: "Transaction" | None = None, read_time: datetime.datetime | None = None, - ) -> Iterable[PipelineResult]: + explain_options: PipelineExplainOptions | None = None, + index_mode: str | None = None, + additional_options: dict[str, Value | Constant] = {}, + ) -> PipelineStream[PipelineResult]: """ Process this pipeline as a stream, providing results through an Iterable Args: - transaction - (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. If a ``transaction`` is used and it already has write operations added, this method cannot be used (i.e. 
read-after-write is not @@ -98,7 +116,13 @@ def stream( time. This must be a microsecond precision timestamp within the past one hour, or if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp within the past 7 days. For the most accurate results, use UTC timezone. + explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.PipelineExplainOptions`]): + Options to enable query profiling for this query. When set, + explain_metrics will be available on the returned generator. + index_mode (Optional[str]): Configures the pipeline to require a certain type of indexes to be present. + Firestore will reject the request if there are not appropriate indexes to serve the query. + additional_options (Optional[dict[str, Value | Constant]]): Additional options to pass to the query. + These options will take precedence over method arguments if there is a conflict (e.g. explain_options, index_mode) """ - request = self._prep_execute_request(transaction, read_time) - for response in self._client._firestore_api.execute_pipeline(request): - yield from self._execute_response_helper(response) + kwargs = {k: v for k, v in locals().items() if k != "self"} + return PipelineStream(PipelineResult, self, **kwargs) diff --git a/google/cloud/firestore_v1/pipeline_result.py b/google/cloud/firestore_v1/pipeline_result.py index ada855fea..6be08fa57 100644 --- a/google/cloud/firestore_v1/pipeline_result.py +++ b/google/cloud/firestore_v1/pipeline_result.py @@ -13,17 +13,43 @@ # limitations under the License. 
from __future__ import annotations -from typing import Any, MutableMapping, TYPE_CHECKING +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Iterable, + Iterator, + Generic, + MutableMapping, + Type, + TypeVar, + TYPE_CHECKING, +) from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.field_path import get_nested_value from google.cloud.firestore_v1.field_path import FieldPath +from google.cloud.firestore_v1.query_profile import ExplainStats +from google.cloud.firestore_v1.query_profile import QueryExplainError +from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest +from google.cloud.firestore_v1.types.document import Value if TYPE_CHECKING: # pragma: NO COVER + import datetime + from google.cloud.firestore_v1.async_client import AsyncClient + from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.base_client import BaseClient + from google.cloud.firestore_v1.async_transaction import AsyncTransaction + from google.cloud.firestore_v1.transaction import Transaction from google.cloud.firestore_v1.base_document import BaseDocumentReference from google.protobuf.timestamp_pb2 import Timestamp + from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse from google.cloud.firestore_v1.types.document import Value as ValueProto from google.cloud.firestore_v1.vector import Vector + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1.base_pipeline import _BasePipeline + from google.cloud.firestore_v1.pipeline import Pipeline + from google.cloud.firestore_v1.pipeline_expressions import Constant + from google.cloud.firestore_v1.query_profile import PipelineExplainOptions class PipelineResult: @@ -137,3 +163,138 @@ def get(self, field_path: str | FieldPath) -> Any: ) value = get_nested_value(str_path, self._fields_pb) return _helpers.decode_value(value, self._client) + + +T = TypeVar("T", bound=PipelineResult) + + +class 
_PipelineResultContainer(Generic[T]): + """Base class to hold shared attributes for PipelineSnapshot and PipelineStream""" + + def __init__( + self, + return_type: Type[T], + pipeline: Pipeline | AsyncPipeline, + transaction: Transaction | AsyncTransaction | None, + read_time: datetime.datetime | None, + explain_options: PipelineExplainOptions | None, + index_mode: str | None, + additional_options: dict[str, Constant | Value], + ): + # public + self.transaction = transaction + self.pipeline: _BasePipeline = pipeline + self.execution_time: Timestamp | None = None + # private + self._client: Client | AsyncClient = pipeline._client + self._started: bool = False + self._read_time = read_time + self._explain_stats: ExplainStats | None = None + self._explain_options: PipelineExplainOptions | None = explain_options + self._return_type = return_type + self._index_mode = index_mode + self._additonal_options = { + k: v if isinstance(v, Value) else v._to_pb() + for k, v in additional_options.items() + } + + @property + def explain_stats(self) -> ExplainStats: + if self._explain_stats is not None: + return self._explain_stats + elif self._explain_options is None: + raise QueryExplainError("explain_options not set on query.") + elif not self._started: + raise QueryExplainError( + "explain_stats not available until query is complete" + ) + else: + raise QueryExplainError("explain_stats not found") + + def _build_request(self) -> ExecutePipelineRequest: + """ + shared logic for creating an ExecutePipelineRequest + """ + database_name = ( + f"projects/{self._client.project}/databases/{self._client._database}" + ) + transaction_id = ( + _helpers.get_transaction_id(self.transaction, read_operation=False) + if self.transaction is not None + else None + ) + options = {} + if self._explain_options: + options["explain_options"] = self._explain_options._to_value() + if self._index_mode: + options["index_mode"] = Value(string_value=self._index_mode) + if self._additonal_options: + 
options.update(self._additonal_options) + request = ExecutePipelineRequest( + database=database_name, + transaction=transaction_id, + structured_pipeline=self.pipeline._to_pb(**options), + read_time=self._read_time, + ) + return request + + def _process_response(self, response: ExecutePipelineResponse) -> Iterable[T]: + """Shared logic for processing an individual response from a stream""" + if response.explain_stats: + self._explain_stats = ExplainStats(response.explain_stats) + execution_time = response._pb.execution_time + if execution_time and not self.execution_time: + self.execution_time = execution_time + for doc in response.results: + ref = self._client.document(doc.name) if doc.name else None + yield self._return_type( + self._client, + doc.fields, + ref, + execution_time, + doc._pb.create_time if doc.create_time else None, + doc._pb.update_time if doc.update_time else None, + ) + + +class PipelineSnapshot(_PipelineResultContainer[T], list[T]): + """ + A list type that holds the result of a pipeline.execute() operation, along with related metadata + """ + + def __init__(self, results_list: list[T], source: _PipelineResultContainer[T]): + self.__dict__.update(source.__dict__.copy()) + list.__init__(self, results_list) + # snapshots are always complete + self._started = True + + +class PipelineStream(_PipelineResultContainer[T], Iterable[T]): + """ + An iterable stream representing the result of a pipeline.stream() operation, along with related metadata + """ + + def __iter__(self) -> Iterator[T]: + if self._started: + raise RuntimeError(f"{self.__class__.__name__} can only be iterated once") + self._started = True + request = self._build_request() + stream = self._client._firestore_api.execute_pipeline(request) + for response in stream: + yield from self._process_response(response) + + +class AsyncPipelineStream(_PipelineResultContainer[T], AsyncIterable[T]): + """ + An iterable stream representing the result of an async pipeline.stream() operation, along 
with related metadata + """ + + async def __aiter__(self) -> AsyncIterator[T]: + if self._started: + raise RuntimeError(f"{self.__class__.__name__} can only be iterated once") + self._started = True + request = self._build_request() + stream = await self._client._firestore_api.execute_pipeline(request) + async for response in stream: + for result in self._process_response(response): + yield result diff --git a/google/cloud/firestore_v1/query_profile.py b/google/cloud/firestore_v1/query_profile.py index 6925f83ff..5e8491fc6 100644 --- a/google/cloud/firestore_v1/query_profile.py +++ b/google/cloud/firestore_v1/query_profile.py @@ -19,6 +19,12 @@ from dataclasses import dataclass from google.protobuf.json_format import MessageToDict +from google.cloud.firestore_v1.types.document import MapValue +from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1.types.explain_stats import ( + ExplainStats as ExplainStats_pb, +) +from google.protobuf.wrappers_pb2 import StringValue @dataclass(frozen=True) @@ -42,6 +48,32 @@ def _to_dict(self): return {"analyze": self.analyze} +@dataclass(frozen=True) +class PipelineExplainOptions: + """ + Explain options for pipeline queries. + + Set on a pipeline.execute() or pipeline.stream() call, to provide + explain_stats in the pipeline output + + :type mode: str + :param mode: Optional. The mode of operation for this explain query. + When set to 'analyze', the query will be executed and return the full + query results along with execution statistics. + + :type output_format: str | None + :param output_format: Optional. The format in which to return the explain + stats. 
+ """ + + mode: str = "analyze" + + def _to_value(self): + out_dict = {"mode": Value(string_value=self.mode)} + value_pb = MapValue(fields=out_dict) + return Value(map_value=value_pb) + + @dataclass(frozen=True) class PlanSummary: """ @@ -143,3 +175,54 @@ class QueryExplainError(Exception): """ pass + + +class ExplainStats: + """ + Contains query profiling statistics for a pipeline query. + + This class is not meant to be instantiated directly by the user. Instead, an + instance of `ExplainStats` may be returned by pipeline execution methods + when `explain_options` are provided. + + It provides methods to access the explain statistics in different formats. + """ + + def __init__(self, stats_pb: ExplainStats_pb): + """ + Args: + stats_pb (ExplainStats_pb): The raw protobuf message for explain stats. + """ + self._stats_pb = stats_pb + + def get_text(self) -> str: + """ + Returns the explain stats as a string. + + This method is suitable for explain formats that have a text-based output, + such as 'text' or 'json'. + + Returns: + str: The string representation of the explain stats. + + Raises: + QueryExplainError: If the explain stats payload from the backend is not + a string. This can happen if a non-text output format was requested. + """ + pb_data = self._stats_pb._pb.data + content = StringValue() + if pb_data.Unpack(content): + return content.value + raise QueryExplainError( + "Unable to decode explain stats. Did you request an output format that returns a string value, such as 'text' or 'json'?" + ) + + def get_raw(self) -> ExplainStats_pb: + """ + Returns the explain stats in an encoded proto format, as returned from the Firestore backend. + The caller is responsible for unpacking this proto message. 
+ + Returns: + google.cloud.firestore_v1.types.explain_stats.ExplainStats: the proto from the backend + """ + return self._stats_pb diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 61b1a983c..615ff1226 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -21,6 +21,7 @@ import google.auth import pytest +import mock from google.api_core.exceptions import ( AlreadyExists, FailedPrecondition, @@ -1652,6 +1653,140 @@ def test_query_stream_or_get_w_explain_options_analyze_false( explain_metrics.execution_stats +@pytest.mark.skipif( + FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." +) +@pytest.mark.parametrize("method", ["execute", "stream"]) +@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) +def test_pipeline_explain_options_explain_mode(database, method, query_docs): + """Explain currently not supported by backend. Expect error""" + from google.cloud.firestore_v1.query_profile import ( + PipelineExplainOptions, + ) + + collection, _, _ = query_docs + client = collection._client + query = collection.where(filter=FieldFilter("a", "==", 1)) + pipeline = client.pipeline().create_from(query) + + # Tests either `execute()` or `stream()`. + method_under_test = getattr(pipeline, method) + explain_options = PipelineExplainOptions(mode="explain") + + # for now, expect error on explain mode + with pytest.raises(InvalidArgument) as e: + results = method_under_test(explain_options=explain_options) + list(results) + assert "Explain execution mode is not supported" in str(e) + + +@pytest.mark.skipif( + FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
+) +@pytest.mark.parametrize("method", ["execute", "stream"]) +@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) +def test_pipeline_explain_options_analyze_mode(database, method, query_docs): + from google.cloud.firestore_v1.query_profile import ( + PipelineExplainOptions, + ExplainStats, + QueryExplainError, + ) + from google.cloud.firestore_v1.types.explain_stats import ( + ExplainStats as ExplainStats_pb, + ) + + collection, _, allowed_vals = query_docs + client = collection._client + query = collection.where(filter=FieldFilter("a", "==", 1)) + pipeline = client.pipeline().create_from(query) + + # Tests either `execute()` or `stream()`. + method_under_test = getattr(pipeline, method) + results = method_under_test(explain_options=PipelineExplainOptions()) + + if method == "stream": + # check for error accessing explain stats before iterating + with pytest.raises( + QueryExplainError, + match="explain_stats not available until query is complete", + ): + results.explain_stats + + # Finish iterating results, and explain_stats should be available. + results_list = list(results) + num_results = len(results_list) + assert num_results == len(allowed_vals) + + # Verify explain_stats. + explain_stats = results.explain_stats + assert isinstance(explain_stats, ExplainStats) + + assert isinstance(explain_stats.get_raw(), ExplainStats_pb) + text_stats = explain_stats.get_text() + assert "Execution:" in text_stats + + +@pytest.mark.skipif( + FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." +) +@pytest.mark.parametrize("method", ["execute", "stream"]) +@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) +def test_pipeline_explain_options_using_additional_options( + database, method, query_docs +): + """additional_options field allows passing in arbitrary options. 
Test with explain_options""" + from google.cloud.firestore_v1.query_profile import ( + PipelineExplainOptions, + ExplainStats, + ) + from google.cloud.firestore_v1.types.explain_stats import ( + ExplainStats as ExplainStats_pb, + ) + + collection, _, allowed_vals = query_docs + client = collection._client + query = collection.where(filter=FieldFilter("a", "==", 1)) + pipeline = client.pipeline().create_from(query) + + # Tests either `execute()` or `stream()`. + method_under_test = getattr(pipeline, method) + + encoded_options = {"explain_options": PipelineExplainOptions()._to_value()} + + results = method_under_test( + explain_options=mock.Mock(), additional_options=encoded_options + ) + + # Finish iterating results, and explain_stats should be available. + results_list = list(results) + num_results = len(results_list) + assert num_results == len(allowed_vals) + + # Verify explain_stats. + explain_stats = results.explain_stats + assert isinstance(explain_stats, ExplainStats) + + assert isinstance(explain_stats.get_raw(), ExplainStats_pb) + text_stats = explain_stats.get_text() + assert "Execution:" in text_stats + + +@pytest.mark.skipif( + FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
+) +@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) +def test_pipeline_index_mode(database, query_docs): + """test pipeline query with explicit index mode""" + + collection, _, allowed_vals = query_docs + client = collection._client + query = collection.where(filter=FieldFilter("a", "==", 1)) + pipeline = client.pipeline().create_from(query) + with pytest.raises(InvalidArgument) as e: + pipeline.execute(index_mode="fake_index") + assert "Invalid index_mode: fake_index" in str(e) + + @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_read_time(query_docs, cleanup, database): collection, stored, allowed_vals = query_docs @@ -1703,15 +1838,16 @@ def test_pipeline_w_read_time(query_docs, cleanup, database): new_data = { "a": 9000, "b": 1, - "c": [10000, 1000], - "stats": {"sum": 9001, "product": 9000}, } _, new_ref = collection.add(new_data) # Add to clean-up. cleanup(new_ref.delete) stored[new_ref.id] = new_data - pipeline = collection.where(filter=FieldFilter("b", "==", 1)).pipeline() + client = collection._client + query = collection.where(filter=FieldFilter("b", "==", 1)) + pipeline = client.pipeline().create_from(query) + # new query should have new_data new_results = list(pipeline.stream()) new_values = {result.ref.id: result.data() for result in new_results} diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index 99b9da801..373c40118 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -22,6 +22,7 @@ import google.auth import pytest import pytest_asyncio +import mock from google.api_core import exceptions as core_exceptions from google.api_core import retry_async as retries from google.api_core.exceptions import ( @@ -1573,44 +1574,124 @@ async def test_query_stream_or_get_w_explain_options_analyze_false( _verify_explain_metrics_analyze_false(explain_metrics) -@pytest.mark.parametrize("database", TEST_DATABASES, 
indirect=True) -async def test_query_stream_w_read_time(query_docs, cleanup, database): - collection, stored, allowed_vals = query_docs - num_vals = len(allowed_vals) +@pytest.mark.skipif( + FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." +) +@pytest.mark.parametrize("method", ["execute", "stream"]) +@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) +async def test_pipeline_explain_options_explain_mode(database, method, query_docs): + """Explain currently not supported by backend. Expect error""" + from google.api_core.exceptions import InvalidArgument + from google.cloud.firestore_v1.query_profile import ( + PipelineExplainOptions, + ) - # Find the most recent read_time in collections - read_time = max( - [(await docref.get()).read_time async for docref in collection.list_documents()] + collection, _, _ = query_docs + client = collection._client + query = collection.where(filter=FieldFilter("a", "==", 1)) + pipeline = client.pipeline().create_from(query) + + method_under_test = getattr(pipeline, method) + explain_options = PipelineExplainOptions(mode="explain") + + with pytest.raises(InvalidArgument) as e: + if method == "stream": + results = method_under_test(explain_options=explain_options) + _ = [i async for i in results] + else: + await method_under_test(explain_options=explain_options) + + assert "Explain execution mode is not supported" in str(e.value) + + +@pytest.mark.skipif( + FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
+) +@pytest.mark.parametrize("method", ["execute", "stream"]) +@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) +async def test_pipeline_explain_options_analyze_mode(database, method, query_docs): + from google.cloud.firestore_v1.query_profile import ( + PipelineExplainOptions, + ExplainStats, + QueryExplainError, + ) + from google.cloud.firestore_v1.types.explain_stats import ( + ExplainStats as ExplainStats_pb, ) - new_data = { - "a": 9000, - "b": 1, - "c": [10000, 1000], - "stats": {"sum": 9001, "product": 9000}, - } - _, new_ref = await collection.add(new_data) - # Add to clean-up. - cleanup(new_ref.delete) - stored[new_ref.id] = new_data - # Compare query at read_time to query at current time. - query = collection.where(filter=FieldFilter("b", "==", 1)) - values = { - snapshot.id: snapshot.to_dict() - async for snapshot in query.stream(read_time=read_time) - } - assert len(values) == num_vals - assert new_ref.id not in values - for key, value in values.items(): - assert stored[key] == value - assert value["b"] == 1 - assert value["a"] != 9000 - assert key != new_ref + collection, _, allowed_vals = query_docs + client = collection._client + query = collection.where(filter=FieldFilter("a", "==", 1)) + pipeline = client.pipeline().create_from(query) - new_values = {snapshot.id: snapshot.to_dict() async for snapshot in query.stream()} - assert len(new_values) == num_vals + 1 - assert new_ref.id in new_values - assert new_values[new_ref.id] == new_data + method_under_test = getattr(pipeline, method) + explain_options = PipelineExplainOptions() + + if method == "execute": + results = await method_under_test(explain_options=explain_options) + num_results = len(results) + else: + results = method_under_test(explain_options=explain_options) + with pytest.raises( + QueryExplainError, + match="explain_stats not available until query is complete", + ): + results.explain_stats + + num_results = len([item async for item in results]) + + 
explain_stats = results.explain_stats + + assert num_results == len(allowed_vals) + + assert isinstance(explain_stats, ExplainStats) + assert isinstance(explain_stats.get_raw(), ExplainStats_pb) + text_stats = explain_stats.get_text() + assert "Execution:" in text_stats + + +@pytest.mark.skipif( + FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." +) +@pytest.mark.parametrize("method", ["execute", "stream"]) +@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) +async def test_pipeline_explain_options_using_additional_options( + database, method, query_docs +): + """additional_options field allows passing in arbitrary options. Test with explain_options""" + from google.cloud.firestore_v1.query_profile import ( + PipelineExplainOptions, + ExplainStats, + ) + from google.cloud.firestore_v1.types.explain_stats import ( + ExplainStats as ExplainStats_pb, + ) + + collection, _, allowed_vals = query_docs + client = collection._client + query = collection.where(filter=FieldFilter("a", "==", 1)) + pipeline = client.pipeline().create_from(query) + + method_under_test = getattr(pipeline, method) + encoded_options = {"explain_options": PipelineExplainOptions()._to_value()} + + stub = method_under_test( + explain_options=mock.Mock(), additional_options=encoded_options + ) + if method == "execute": + results = await stub + num_results = len(results) + else: + results = stub + num_results = len([item async for item in results]) + + assert num_results == len(allowed_vals) + + explain_stats = results.explain_stats + assert isinstance(explain_stats, ExplainStats) + assert isinstance(explain_stats.get_raw(), ExplainStats_pb) + text_stats = explain_stats.get_text() + assert "Execution:" in text_stats @pytest.mark.skipif(IS_KOKORO_TEST, reason="skipping pipeline verification on kokoro") @@ -1626,15 +1707,14 @@ async def test_pipeline_w_read_time(query_docs, cleanup, database): new_data = { "a": 9000, "b": 1, - "c": [10000, 1000], - "stats": 
{"sum": 9001, "product": 9000}, } _, new_ref = await collection.add(new_data) # Add to clean-up. cleanup(new_ref.delete) stored[new_ref.id] = new_data - - pipeline = collection.where(filter=FieldFilter("b", "==", 1)).pipeline() + client = collection._client + query = collection.where(filter=FieldFilter("b", "==", 1)) + pipeline = client.pipeline().create_from(query) # new query should have new_data new_results = [result async for result in pipeline.stream()] @@ -1657,6 +1737,46 @@ async def test_pipeline_w_read_time(query_docs, cleanup, database): assert key != new_ref.id +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +async def test_query_stream_w_read_time(query_docs, cleanup, database): + collection, stored, allowed_vals = query_docs + num_vals = len(allowed_vals) + + # Find the most recent read_time in collections + read_time = max( + [(await docref.get()).read_time async for docref in collection.list_documents()] + ) + new_data = { + "a": 9000, + "b": 1, + "c": [10000, 1000], + "stats": {"sum": 9001, "product": 9000}, + } + _, new_ref = await collection.add(new_data) + # Add to clean-up. + cleanup(new_ref.delete) + stored[new_ref.id] = new_data + + # Compare query at read_time to query at current time. 
+ query = collection.where(filter=FieldFilter("b", "==", 1)) + values = { + snapshot.id: snapshot.to_dict() + async for snapshot in query.stream(read_time=read_time) + } + assert len(values) == num_vals + assert new_ref.id not in values + for key, value in values.items(): + assert stored[key] == value + assert value["b"] == 1 + assert value["a"] != 9000 + assert key != new_ref + + new_values = {snapshot.id: snapshot.to_dict() async for snapshot in query.stream()} + assert len(new_values) == num_vals + 1 + assert new_ref.id in new_values + assert new_values[new_ref.id] == new_data + + @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_with_order_dot_key(client, cleanup, database): db = client diff --git a/tests/unit/v1/test_async_pipeline.py b/tests/unit/v1/test_async_pipeline.py index 189b24fba..5a7fb360c 100644 --- a/tests/unit/v1/test_async_pipeline.py +++ b/tests/unit/v1/test_async_pipeline.py @@ -389,30 +389,6 @@ async def test_async_pipeline_stream_stream_equivalence(): assert stream_results[0].data()["key"] == "str_val" -@pytest.mark.asyncio -async def test_async_pipeline_stream_stream_equivalence_mocked(): - """ - pipeline.stream should call pipeline.stream internally - """ - import datetime - - ppl_1 = _make_async_pipeline() - expected_data = [object(), object()] - expected_transaction = object() - expected_read_time = datetime.datetime.now(tz=datetime.timezone.utc) - with mock.patch.object(ppl_1, "stream") as mock_stream: - mock_stream.return_value = _async_it(expected_data) - stream_results = await ppl_1.execute( - transaction=expected_transaction, read_time=expected_read_time - ) - assert mock_stream.call_count == 1 - assert mock_stream.call_args[0] == () - assert len(mock_stream.call_args[1]) == 2 - assert mock_stream.call_args[1]["transaction"] == expected_transaction - assert mock_stream.call_args[1]["read_time"] == expected_read_time - assert stream_results == expected_data - - @pytest.mark.parametrize( 
"method,args,result_cls", [ diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py index 34d3400e8..fc8e90a04 100644 --- a/tests/unit/v1/test_pipeline.py +++ b/tests/unit/v1/test_pipeline.py @@ -96,6 +96,17 @@ def test_pipeline__to_pb(): assert pb.pipeline.stages[1] == stage_2._to_pb() +def test_pipeline__to_pb_with_options(): + from google.cloud.firestore_v1.types.pipeline import StructuredPipeline + from google.cloud.firestore_v1.types.document import Value + + ppl = _make_pipeline() + options = {"option_1": Value(integer_value=1)} + pb = ppl._to_pb(**options) + assert isinstance(pb, StructuredPipeline) + assert pb.options["option_1"].integer_value == 1 + + def test_pipeline_append(): """append should create a new pipeline with the additional stage""" @@ -365,29 +376,6 @@ def test_pipeline_execute_stream_equivalence(): assert execute_results[0].data()["key"] == "str_val" -def test_pipeline_execute_stream_equivalence_mocked(): - """ - pipeline.execute should call pipeline.stream internally - """ - import datetime - - ppl_1 = _make_pipeline() - expected_data = [object(), object()] - expected_transaction = object() - expected_read_time = datetime.datetime.now(tz=datetime.timezone.utc) - with mock.patch.object(ppl_1, "stream") as mock_stream: - mock_stream.return_value = expected_data - stream_results = ppl_1.execute( - transaction=expected_transaction, read_time=expected_read_time - ) - assert mock_stream.call_count == 1 - assert mock_stream.call_args[0] == () - assert len(mock_stream.call_args[1]) == 2 - assert mock_stream.call_args[1]["transaction"] == expected_transaction - assert mock_stream.call_args[1]["read_time"] == expected_read_time - assert stream_results == expected_data - - @pytest.mark.parametrize( "method,args,result_cls", [ diff --git a/tests/unit/v1/test_pipeline_result.py b/tests/unit/v1/test_pipeline_result.py index 2facf7110..579992741 100644 --- a/tests/unit/v1/test_pipeline_result.py +++ 
b/tests/unit/v1/test_pipeline_result.py @@ -15,7 +15,30 @@ import mock import pytest +from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse +from google.cloud.firestore_v1.pipeline_expressions import Constant from google.cloud.firestore_v1.pipeline_result import PipelineResult +from google.cloud.firestore_v1.pipeline_result import PipelineSnapshot +from google.cloud.firestore_v1.pipeline_result import PipelineStream +from google.cloud.firestore_v1.pipeline_result import AsyncPipelineStream +from google.cloud.firestore_v1.query_profile import QueryExplainError +from google.cloud.firestore_v1.query_profile import PipelineExplainOptions +from google.cloud.firestore_v1._helpers import encode_value +from google.cloud.firestore_v1.types.document import Document +from google.protobuf.timestamp_pb2 import Timestamp + + +_mock_stream_responses = [ + ExecutePipelineResponse( + results=[Document(name="projects/p/databases/d/documents/c/d1", fields={})], + execution_time=Timestamp(seconds=1, nanos=2), + explain_stats={"data": {}}, + ), + ExecutePipelineResponse( + results=[Document(name="projects/p/databases/d/documents/c/d2", fields={})], + execution_time=Timestamp(seconds=3, nanos=4), + ), +] class TestPipelineResult: @@ -174,3 +197,314 @@ def test_get_call(self): got = instance.get("key") decode_mock.assert_called_once_with("value", client) assert got == decode_mock.return_value + + +class TestPipelineSnapshot: + def _make_one(self, *args, **kwargs): + if not args: + # use defaults if not passed + args = [[], mock.Mock()] + return PipelineSnapshot(*args, **kwargs) + + def test_ctor(self): + in_arr = [1, 2, 3] + expected_type = object() + expected_pipeline = mock.Mock() + expected_transaction = object() + expected_read_time = 123 + expected_explain_options = object() + expected_index_mode = "mode" + expected_addtl_options = {} + source = PipelineStream( + expected_type, + expected_pipeline, + expected_transaction, + expected_read_time, + 
expected_explain_options, + expected_index_mode, + expected_addtl_options, + ) + instance = self._make_one(in_arr, source) + assert instance._return_type == expected_type + assert instance.pipeline == expected_pipeline + assert instance._client == expected_pipeline._client + assert instance._additonal_options == expected_addtl_options + assert instance._index_mode == expected_index_mode + assert instance._explain_options == expected_explain_options + assert instance._explain_stats is None + assert instance._started is True + assert instance.execution_time is None + assert instance.transaction == expected_transaction + assert instance._read_time == expected_read_time + + def test_list_methods(self): + instance = self._make_one(list(range(10)), mock.Mock()) + assert isinstance(instance, list) + assert len(instance) == 10 + assert instance[0] == 0 + assert instance[-1] == 9 + + def test_explain_stats(self): + instance = self._make_one() + expected_stats = mock.Mock() + instance._explain_stats = expected_stats + assert instance.explain_stats == expected_stats + # test different failure modes + instance._explain_stats = None + instance._explain_options = None + # fail if explain_stats set without explain_options + with pytest.raises(QueryExplainError) as e: + instance.explain_stats + assert "explain_options not set" in str(e) + # fail if explain_stats missing + instance._explain_options = object() + with pytest.raises(QueryExplainError) as e: + instance.explain_stats + assert "explain_stats not found" in str(e) + + +class SharedStreamTests: + """ + Shared test logic for PipelineStream and AsyncPipelineStream + """ + + def _make_one(self, *args, **kwargs): + raise NotImplementedError + + def _mock_init_args(self): + # return default mocks for all init args + from google.cloud.firestore_v1.pipeline import Pipeline + + return { + "return_type": PipelineResult, + "pipeline": Pipeline(mock.Mock()), + "transaction": None, + "read_time": None, + "explain_options": None, + 
"index_mode": None, + "additional_options": {}, + } + + def test_explain_stats(self): + instance = self._make_one() + expected_stats = mock.Mock() + instance._started = True + instance._explain_stats = expected_stats + assert instance.explain_stats == expected_stats + # test different failure modes + instance._explain_stats = None + instance._explain_options = None + # fail if explain_stats set without explain_options + with pytest.raises(QueryExplainError) as e: + instance.explain_stats + assert "explain_options not set" in str(e) + # fail if explain_stats missing + instance._explain_options = object() + with pytest.raises(QueryExplainError) as e: + instance.explain_stats + assert "explain_stats not found" in str(e) + # fail if not started + instance._started = False + with pytest.raises(QueryExplainError) as e: + instance.explain_stats + assert "not available until query is complete" in str(e) + + @pytest.mark.parametrize( + "init_kwargs,expected_options", + [ + ({"index_mode": "mode"}, {"index_mode": encode_value("mode")}), + ( + {"explain_options": PipelineExplainOptions()}, + {"explain_options": encode_value({"mode": "analyze"})}, + ), + ( + {"explain_options": PipelineExplainOptions(mode="explain")}, + {"explain_options": encode_value({"mode": "explain"})}, + ), + ( + {"additional_options": {"explain_options": Constant("custom")}}, + {"explain_options": encode_value("custom")}, + ), + ( + {"additional_options": {"explain_options": encode_value("custom")}}, + {"explain_options": encode_value("custom")}, + ), + ( + { + "explain_options": PipelineExplainOptions(), + "additional_options": {"explain_options": Constant.of("override")}, + }, + {"explain_options": encode_value("override")}, + ), + ( + { + "index_mode": "mode", + "additional_options": {"index_mode": Constant("new")}, + }, + {"index_mode": encode_value("new")}, + ), + ], + ) + def test_build_request_options(self, init_kwargs, expected_options): + """ + Certain Arguments to PipelineStream should be 
passed to `options` field in proto request + """ + instance = self._make_one(**init_kwargs) + request = instance._build_request() + options = dict(request.structured_pipeline.options) + assert options == expected_options + assert len(options) == len(expected_options) + + def test_build_request_transaction(self): + """Ensure transaction is passed down when building request""" + from google.cloud.firestore_v1.transaction import Transaction + + expected_id = b"expected" + transaction = Transaction(mock.Mock()) + transaction._id = expected_id + instance = self._make_one(transaction=transaction) + request = instance._build_request() + assert request.transaction == expected_id + + def test_build_request_read_time(self): + """Ensure readtime is passed down when building request""" + import datetime + + ts = datetime.datetime.now() + instance = self._make_one(read_time=ts) + request = instance._build_request() + assert request.read_time.timestamp() == ts.timestamp() + + +class TestPipelineStream(SharedStreamTests): + def _make_one(self, **kwargs): + init_kwargs = self._mock_init_args() + init_kwargs.update(kwargs) + return PipelineStream(**init_kwargs) + + def test_explain_stats(self): + instance = self._make_one() + expected_stats = mock.Mock() + instance._started = True + instance._explain_stats = expected_stats + assert instance.explain_stats == expected_stats + # test different failure modes + instance._explain_stats = None + instance._explain_options = None + # fail if explain_stats set without explain_options + with pytest.raises(QueryExplainError) as e: + instance.explain_stats + assert "explain_options not set" in str(e) + # fail if explain_stats missing + instance._explain_options = object() + with pytest.raises(QueryExplainError) as e: + instance.explain_stats + assert "explain_stats not found" in str(e) + # fail if not started + instance._started = False + with pytest.raises(QueryExplainError) as e: + instance.explain_stats + assert "not available until query is 
complete" in str(e) + + def test_iter(self): + pipeline = mock.Mock() + pipeline._client.project = "project-id" + pipeline._client._database = "database-id" + pipeline._client.document.side_effect = lambda path: mock.Mock( + id=path.split("/")[-1] + ) + pipeline._to_pb.return_value = {} + + instance = self._make_one(pipeline=pipeline) + + instance._client._firestore_api.execute_pipeline.return_value = ( + _mock_stream_responses + ) + + results = list(instance) + + assert len(results) == 2 + assert isinstance(results[0], PipelineResult) + assert results[0].id == "d1" + assert isinstance(results[1], PipelineResult) + assert results[1].id == "d2" + + assert instance.execution_time.seconds == 1 + assert instance.execution_time.nanos == 2 + + # expect empty stats + got_stats = instance.explain_stats.get_raw().data + assert got_stats.value == b"" + + instance._client._firestore_api.execute_pipeline.assert_called_once() + + def test_double_iterate(self): + instance = self._make_one() + instance._client._firestore_api.execute_pipeline.return_value = [] + # consume the iterator + list(instance) + with pytest.raises(RuntimeError): + list(instance) + + +class TestAsyncPipelineStream(SharedStreamTests): + def _make_one(self, **kwargs): + init_kwargs = self._mock_init_args() + init_kwargs.update(kwargs) + return AsyncPipelineStream(**init_kwargs) + + @pytest.mark.asyncio + async def test_aiter(self): + pipeline = mock.Mock() + pipeline._client.project = "project-id" + pipeline._client._database = "database-id" + pipeline._client.document.side_effect = lambda path: mock.Mock( + id=path.split("/")[-1] + ) + pipeline._to_pb.return_value = {} + + instance = self._make_one(pipeline=pipeline) + + async def async_gen(items): + for item in items: + yield item + + instance._client._firestore_api.execute_pipeline = mock.AsyncMock( + return_value=async_gen(_mock_stream_responses) + ) + + results = [item async for item in instance] + + assert len(results) == 2 + assert 
isinstance(results[0], PipelineResult) + assert results[0].id == "d1" + assert isinstance(results[1], PipelineResult) + assert results[1].id == "d2" + + assert instance.execution_time.seconds == 1 + assert instance.execution_time.nanos == 2 + + # expect empty stats + got_stats = instance.explain_stats.get_raw().data + assert got_stats.value == b"" + + instance._client._firestore_api.execute_pipeline.assert_called_once() + + @pytest.mark.asyncio + async def test_double_iterate(self): + instance = self._make_one() + + async def async_gen(items): + for item in items: + yield item + + # mock the api call to avoid real network requests + instance._client._firestore_api.execute_pipeline = mock.AsyncMock( + return_value=async_gen([]) + ) + + # consume the iterator + [item async for item in instance] + # should fail on second attempt + with pytest.raises(RuntimeError): + [item async for item in instance] diff --git a/tests/unit/v1/test_query_profile.py b/tests/unit/v1/test_query_profile.py index a3b0390c6..5b1e470b8 100644 --- a/tests/unit/v1/test_query_profile.py +++ b/tests/unit/v1/test_query_profile.py @@ -124,3 +124,64 @@ def test_explain_options__to_dict(): assert ExplainOptions(analyze=True)._to_dict() == {"analyze": True} assert ExplainOptions(analyze=False)._to_dict() == {"analyze": False} + + +@pytest.mark.parametrize("mode_str", ["analyze", "explain"]) +def test_pipeline_explain_options__to_value(mode_str): + """ + Should be able to create a Value protobuf representation of ExplainOptions + """ + from google.cloud.firestore_v1.query_profile import PipelineExplainOptions + from google.cloud.firestore_v1.types.document import MapValue + from google.cloud.firestore_v1.types.document import Value + + options = PipelineExplainOptions(mode=mode_str) + expected_value = Value( + map_value=MapValue(fields={"mode": Value(string_value=mode_str)}) + ) + assert options._to_value() == expected_value + + +def test_explain_stats_get_raw(): + """ + Test ExplainStats.get_raw(). 
It should return the wrapped input object unchanged.