From 2de16209409c9d9ba41d3444400e6a39ee1b2936 Mon Sep 17 00:00:00 2001 From: Jing Date: Wed, 10 Jul 2024 17:07:06 -0700 Subject: [PATCH] feat: support async Vector Search (#901) --- google/cloud/firestore_v1/async_query.py | 31 +++ .../cloud/firestore_v1/async_vector_query.py | 127 +++++++++ google/cloud/firestore_v1/base_query.py | 2 +- .../cloud/firestore_v1/base_vector_query.py | 10 +- tests/system/test_system_async.py | 43 ++++ tests/unit/v1/test_async_vector_query.py | 241 ++++++++++++++++++ 6 files changed, 452 insertions(+), 2 deletions(-) create mode 100644 google/cloud/firestore_v1/async_vector_query.py create mode 100644 tests/unit/v1/test_async_vector_query.py diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index c73e16724e..7a17eee47a 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -35,6 +35,7 @@ from google.cloud.firestore_v1 import async_document from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery +from google.cloud.firestore_v1.async_vector_query import AsyncVectorQuery from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator from google.cloud.firestore_v1 import transaction from typing import AsyncGenerator, List, Optional, Type, TYPE_CHECKING @@ -42,7 +43,9 @@ if TYPE_CHECKING: # pragma: NO COVER # Types needed only for Type Hints from google.cloud.firestore_v1.base_document import DocumentSnapshot + from google.cloud.firestore_v1.base_vector_query import DistanceMeasure from google.cloud.firestore_v1.field_path import FieldPath + from google.cloud.firestore_v1.vector import Vector class AsyncQuery(BaseQuery): @@ -222,6 +225,34 @@ async def get( return result + def find_nearest( + self, + vector_field: str, + query_vector: Vector, + limit: int, + distance_measure: DistanceMeasure, + ) -> AsyncVectorQuery: + """ + Finds the closest vector embeddings to the given query vector. 
+ + Args: + vector_field(str): An indexed vector field to search upon. Only documents which contain + vectors whose dimensionality matches the query_vector can be returned. + query_vector(Vector): The query vector that we are searching on. Must be a vector of no more + than 2048 dimensions. + limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000. + distance_measure(:class:`DistanceMeasure`): The Distance Measure to use. + + Returns: + :class:`~firestore_v1.async_vector_query.AsyncVectorQuery`: the vector query. + """ + return AsyncVectorQuery(self).find_nearest( + vector_field=vector_field, + query_vector=query_vector, + limit=limit, + distance_measure=distance_measure, + ) + def count( self, alias: str | None = None ) -> Type["firestore_v1.async_aggregation.AsyncAggregationQuery"]: diff --git a/google/cloud/firestore_v1/async_vector_query.py b/google/cloud/firestore_v1/async_vector_query.py new file mode 100644 index 0000000000..27de5251ca --- /dev/null +++ b/google/cloud/firestore_v1/async_vector_query.py @@ -0,0 +1,127 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from google.api_core import gapic_v1 +from google.api_core import retry_async as retries +from google.cloud.firestore_v1 import async_document +from google.cloud.firestore_v1.base_document import DocumentSnapshot +from google.cloud.firestore_v1.base_query import ( + BaseQuery, + _query_response_to_snapshot, + _collection_group_query_response_to_snapshot, +) +from google.cloud.firestore_v1.base_vector_query import BaseVectorQuery +from typing import AsyncGenerator, List, Union, Optional, TypeVar + +TAsyncVectorQuery = TypeVar("TAsyncVectorQuery", bound="AsyncVectorQuery") + + +class AsyncVectorQuery(BaseVectorQuery): + """Represents an async vector query to the Firestore API.""" + + def __init__( + self, + nested_query: Union[BaseQuery, TAsyncVectorQuery], + ) -> None: + """Presents the vector query. + Args: + nested_query (BaseQuery | VectorQuery): the base query to apply as the prefilter. + """ + super(AsyncVectorQuery, self).__init__(nested_query) + + async def get( + self, + transaction=None, + retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + ) -> List[DocumentSnapshot]: + """Runs the vector query. + + This sends a ``RunQuery`` RPC and returns a list of document messages. + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. + + Returns: + list: The vector query results. 
+ """ + stream_result = self.stream( + transaction=transaction, retry=retry, timeout=timeout + ) + result = [snapshot async for snapshot in stream_result] + return result # type: ignore + + async def stream( + self, + transaction=None, + retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + ) -> AsyncGenerator[async_document.DocumentSnapshot, None]: + """Reads the documents in the collection that match this query. + + This sends a ``RunQuery`` RPC and then returns an iterator which + consumes each document returned in the stream of ``RunQueryResponse`` + messages. + + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. + + Yields: + :class:`~google.cloud.firestore_v1.document.DocumentSnapshot`: + The next document that fulfills the query. 
+ """ + request, expected_prefix, kwargs = self._prep_stream( + transaction, + retry, + timeout, + ) + + response_iterator = await self._client._firestore_api.run_query( + request=request, + metadata=self._client._rpc_metadata, + **kwargs, + ) + + async for response in response_iterator: + if self._nested_query._all_descendants: + snapshot = _collection_group_query_response_to_snapshot( + response, self._nested_query._parent + ) + else: + snapshot = _query_response_to_snapshot( + response, self._nested_query._parent, expected_prefix + ) + if snapshot is not None: + yield snapshot diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index c8c2f3ceb2..9e75514a56 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -978,7 +978,7 @@ def _to_protobuf(self) -> StructuredQuery: def find_nearest( self, vector_field: str, - queryVector: Vector, + query_vector: Vector, limit: int, distance_measure: DistanceMeasure, ) -> BaseVectorQuery: diff --git a/google/cloud/firestore_v1/base_vector_query.py b/google/cloud/firestore_v1/base_vector_query.py index e41717d2b5..cb9c00b3af 100644 --- a/google/cloud/firestore_v1/base_vector_query.py +++ b/google/cloud/firestore_v1/base_vector_query.py @@ -25,7 +25,7 @@ from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.vector import Vector -from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1 import document, _helpers class DistanceMeasure(Enum): @@ -117,3 +117,11 @@ def find_nearest( self._limit = limit self._distance_measure = distance_measure return self + + def stream( + self, + transaction=None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> Iterable[document.DocumentSnapshot]: + """Reads the documents in the collection that match this query.""" diff --git a/tests/system/test_system_async.py 
b/tests/system/test_system_async.py index 5b681e7b33..4418323534 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -35,6 +35,8 @@ from google.cloud._helpers import _datetime_to_pb_timestamp from google.cloud import firestore_v1 as firestore from google.cloud.firestore_v1.base_query import FieldFilter, And, Or +from google.cloud.firestore_v1.base_vector_query import DistanceMeasure +from google.cloud.firestore_v1.vector import Vector from tests.system.test__helpers import ( FIRESTORE_CREDS, @@ -339,6 +341,47 @@ async def test_document_update_w_int_field(client, cleanup, database): assert snapshot1.to_dict() == expected +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_vector_search_collection(client, database): + collection_id = "vector_search" + collection = client.collection(collection_id) + vector_query = collection.where("color", "==", "red").find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + limit=1, + distance_measure=DistanceMeasure.EUCLIDEAN, + ) + returned = await vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 1 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + } + + +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_vector_search_collection_group(client, database): + collection_id = "vector_search" + collection_group = client.collection_group(collection_id) + + vector_query = collection_group.where("color", "==", "red").find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=1, + ) + returned = await vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 1 
+ assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + } + + @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137867104") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) async def test_update_document(client, cleanup, database): diff --git a/tests/unit/v1/test_async_vector_query.py b/tests/unit/v1/test_async_vector_query.py new file mode 100644 index 0000000000..eae018de30 --- /dev/null +++ b/tests/unit/v1/test_async_vector_query.py @@ -0,0 +1,241 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from google.cloud.firestore_v1.types.query import StructuredQuery +from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1.base_vector_query import DistanceMeasure + +from tests.unit.v1.test__helpers import AsyncIter, AsyncMock +from tests.unit.v1._test_helpers import ( + make_async_query, + make_async_client, + make_query, +) +from tests.unit.v1.test_base_query import _make_query_response +from google.cloud.firestore_v1._helpers import encode_value, make_retry_timeout_kwargs + +_PROJECT = "PROJECT" +_TXN_ID = b"\x00\x00\x01-work-\xf2" + + +def _transaction(client): + transaction = client.transaction() + txn_id = _TXN_ID + transaction._id = txn_id + return transaction + + +def _expected_pb(parent, vector_field, vector, distance_type, limit): + query = make_query(parent) + expected_pb = query._to_protobuf() + expected_pb.find_nearest = StructuredQuery.FindNearest( + vector_field=StructuredQuery.FieldReference(field_path=vector_field), + query_vector=encode_value(vector.to_map_value()), + distance_measure=distance_type, + limit=limit, + ) + return expected_pb + + +@pytest.mark.parametrize( + "distance_measure, expected_distance", + [ + ( + DistanceMeasure.EUCLIDEAN, + StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, + ), + (DistanceMeasure.COSINE, StructuredQuery.FindNearest.DistanceMeasure.COSINE), + ( + DistanceMeasure.DOT_PRODUCT, + StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT, + ), + ], +) +@pytest.mark.asyncio +async def test_async_vector_query_with_filter(distance_measure, expected_distance): + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["run_query"]) + client = make_async_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. 
+ parent = client.collection("dee") + query = make_async_query(parent) + parent_path, expected_prefix = parent._parent_info() + + data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])} + response_pb1 = _make_query_response( + name="{}/test_doc".format(expected_prefix), data=data + ) + response_pb2 = _make_query_response( + name="{}/test_doc".format(expected_prefix), data=data + ) + + kwargs = make_retry_timeout_kwargs(retry=None, timeout=None) + + # Execute the vector query and check the response. + firestore_api.run_query.return_value = AsyncIter([response_pb1, response_pb2]) + + vector_async__query = query.where("snooze", "==", 10).find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=distance_measure, + limit=5, + ) + + returned = await vector_async__query.get(transaction=_transaction(client), **kwargs) + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == data + + expected_pb = _expected_pb( + parent=parent, + vector_field="embedding", + vector=Vector([1.0, 2.0, 3.0]), + distance_type=expected_distance, + limit=5, + ) + expected_pb.where = StructuredQuery.Filter( + field_filter=StructuredQuery.FieldFilter( + field=StructuredQuery.FieldReference(field_path="snooze"), + op=StructuredQuery.FieldFilter.Operator.EQUAL, + value=encode_value(10), + ) + ) + + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": expected_pb, + "transaction": _TXN_ID, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +@pytest.mark.parametrize( + "distance_measure, expected_distance", + [ + ( + DistanceMeasure.EUCLIDEAN, + StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, + ), + (DistanceMeasure.COSINE, StructuredQuery.FindNearest.DistanceMeasure.COSINE), + ( + DistanceMeasure.DOT_PRODUCT, + StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT, + ), + ], +) +@pytest.mark.asyncio +async def 
test_vector_query_collection_group(distance_measure, expected_distance): + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["run_query"]) + client = make_async_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection group reference as parent. + collection_group_ref = client.collection_group("dee") + + data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])} + response_pb = _make_query_response(name="xxx/test_doc", data=data) + + kwargs = make_retry_timeout_kwargs(retry=None, timeout=None) + + # Execute the vector query and check the response. + firestore_api.run_query.return_value = AsyncIter([response_pb]) + + vector_query = collection_group_ref.where("snooze", "==", 10).find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=distance_measure, + limit=5, + ) + + returned = await vector_query.get(transaction=_transaction(client), **kwargs) + assert isinstance(returned, list) + assert len(returned) == 1 + assert returned[0].to_dict() == data + + parent = client.collection("dee") + parent_path, expected_prefix = parent._parent_info() + + expected_pb = _expected_pb( + parent=parent, + vector_field="embedding", + vector=Vector([1.0, 2.0, 3.0]), + distance_type=expected_distance, + limit=5, + ) + expected_pb.where = StructuredQuery.Filter( + field_filter=StructuredQuery.FieldFilter( + field=StructuredQuery.FieldReference(field_path="snooze"), + op=StructuredQuery.FieldFilter.Operator.EQUAL, + value=encode_value(10), + ) + ) + expected_pb.from_ = [ + StructuredQuery.CollectionSelector(collection_id="dee", all_descendants=True) + ] + + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": expected_pb, + "transaction": _TXN_ID, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +@pytest.mark.asyncio +async def test_async_query_stream_multiple_empty_response_in_stream(): + # Create a minimal fake GAPIC with a dummy 
response. + firestore_api = AsyncMock(spec=["run_query"]) + empty_response1 = _make_query_response() + empty_response2 = _make_query_response() + run_query_response = AsyncIter([empty_response1, empty_response2]) + firestore_api.run_query.return_value = run_query_response + + # Attach the fake GAPIC to a real client. + client = make_async_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dah", "dah", "dum") + async_vector_query = parent.where("snooze", "==", 10).find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=5, + ) + + result = [snapshot async for snapshot in async_vector_query.stream()] + + assert list(result) == [] + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": async_vector_query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + )