Skip to content

Commit

Permalink
feat: support async Vector Search (#901)
Browse files Browse the repository at this point in the history
  • Loading branch information
pl04351820 authored Jul 11, 2024
1 parent 3e5df35 commit 2de1620
Show file tree
Hide file tree
Showing 6 changed files with 452 additions and 2 deletions.
31 changes: 31 additions & 0 deletions google/cloud/firestore_v1/async_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,17 @@

from google.cloud.firestore_v1 import async_document
from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery
from google.cloud.firestore_v1.async_vector_query import AsyncVectorQuery
from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
from google.cloud.firestore_v1 import transaction
from typing import AsyncGenerator, List, Optional, Type, TYPE_CHECKING

if TYPE_CHECKING: # pragma: NO COVER
# Types needed only for Type Hints
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.base_vector_query import DistanceMeasure
from google.cloud.firestore_v1.field_path import FieldPath
from google.cloud.firestore_v1.vector import Vector


class AsyncQuery(BaseQuery):
Expand Down Expand Up @@ -222,6 +225,34 @@ async def get(

return result

def find_nearest(
self,
vector_field: str,
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
) -> AsyncVectorQuery:
"""
Finds the closest vector embeddings to the given query vector.
Args:
vector_field(str): An indexed vector field to search upon. Only documents which contain
vectors whose dimensionality match the query_vector can be returned.
query_vector(Vector): The query vector that we are searching on. Must be a vector of no more
than 2048 dimensions.
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
distance_measure(:class:`DistanceMeasure`): The Distance Measure to use.
Returns:
:class`~firestore_v1.vector_query.VectorQuery`: the vector query.
"""
return AsyncVectorQuery(self).find_nearest(
vector_field=vector_field,
query_vector=query_vector,
limit=limit,
distance_measure=distance_measure,
)

def count(
self, alias: str | None = None
) -> Type["firestore_v1.async_aggregation.AsyncAggregationQuery"]:
Expand Down
127 changes: 127 additions & 0 deletions google/cloud/firestore_v1/async_vector_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from google.api_core import gapic_v1
from google.api_core import retry_async as retries
from google.cloud.firestore_v1 import async_document
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.base_query import (
BaseQuery,
_query_response_to_snapshot,
_collection_group_query_response_to_snapshot,
)
from google.cloud.firestore_v1.base_vector_query import BaseVectorQuery
from typing import AsyncGenerator, List, Union, Optional, TypeVar

TAsyncVectorQuery = TypeVar("TAsyncVectorQuery", bound="AsyncVectorQuery")


class AsyncVectorQuery(BaseVectorQuery):
"""Represents an async vector query to the Firestore API."""

def __init__(
self,
nested_query: Union[BaseQuery, TAsyncVectorQuery],
) -> None:
"""Presents the vector query.
Args:
nested_query (BaseQuery | VectorQuery): the base query to apply as the prefilter.
"""
super(AsyncVectorQuery, self).__init__(nested_query)

async def get(
self,
transaction=None,
retry: retries.AsyncRetry = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
) -> List[DocumentSnapshot]:
"""Runs the vector query.
This sends a ``RunQuery`` RPC and returns a list of document messages.
Args:
transaction
(Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
An existing transaction that this query will run in.
If a ``transaction`` is used and it already has write operations
added, this method cannot be used (i.e. read-after-write is not
allowed).
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried. Defaults to a system-specified policy.
timeout (float): The timeout for this request. Defaults to a
system-specified value.
Returns:
list: The vector query results.
"""
stream_result = self.stream(
transaction=transaction, retry=retry, timeout=timeout
)
result = [snapshot async for snapshot in stream_result]
return result # type: ignore

async def stream(
self,
transaction=None,
retry: retries.AsyncRetry = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
) -> AsyncGenerator[async_document.DocumentSnapshot, None]:
"""Reads the documents in the collection that match this query.
This sends a ``RunQuery`` RPC and then returns an iterator which
consumes each document returned in the stream of ``RunQueryResponse``
messages.
If a ``transaction`` is used and it already has write operations
added, this method cannot be used (i.e. read-after-write is not
allowed).
Args:
transaction
(Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
An existing transaction that this query will run in.
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried. Defaults to a system-specified policy.
timeout (float): The timeout for this request. Defaults to a
system-specified value.
Yields:
:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`:
The next document that fulfills the query.
"""
request, expected_prefix, kwargs = self._prep_stream(
transaction,
retry,
timeout,
)

response_iterator = await self._client._firestore_api.run_query(
request=request,
metadata=self._client._rpc_metadata,
**kwargs,
)

async for response in response_iterator:
if self._nested_query._all_descendants:
snapshot = _collection_group_query_response_to_snapshot(
response, self._nested_query._parent
)
else:
snapshot = _query_response_to_snapshot(
response, self._nested_query._parent, expected_prefix
)
if snapshot is not None:
yield snapshot
2 changes: 1 addition & 1 deletion google/cloud/firestore_v1/base_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,7 +978,7 @@ def _to_protobuf(self) -> StructuredQuery:
def find_nearest(
self,
vector_field: str,
queryVector: Vector,
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
) -> BaseVectorQuery:
Expand Down
10 changes: 9 additions & 1 deletion google/cloud/firestore_v1/base_vector_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.types import query
from google.cloud.firestore_v1.vector import Vector
from google.cloud.firestore_v1 import _helpers
from google.cloud.firestore_v1 import document, _helpers


class DistanceMeasure(Enum):
Expand Down Expand Up @@ -117,3 +117,11 @@ def find_nearest(
self._limit = limit
self._distance_measure = distance_measure
return self

def stream(
self,
transaction=None,
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
) -> Iterable[document.DocumentSnapshot]:
"""Reads the documents in the collection that match this query."""
43 changes: 43 additions & 0 deletions tests/system/test_system_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
from google.cloud._helpers import _datetime_to_pb_timestamp
from google.cloud import firestore_v1 as firestore
from google.cloud.firestore_v1.base_query import FieldFilter, And, Or
from google.cloud.firestore_v1.base_vector_query import DistanceMeasure
from google.cloud.firestore_v1.vector import Vector

from tests.system.test__helpers import (
FIRESTORE_CREDS,
Expand Down Expand Up @@ -339,6 +341,47 @@ async def test_document_update_w_int_field(client, cleanup, database):
assert snapshot1.to_dict() == expected


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
async def test_vector_search_collection(client, database):
collection_id = "vector_search"
collection = client.collection(collection_id)
vector_query = collection.where("color", "==", "red").find_nearest(
vector_field="embedding",
query_vector=Vector([1.0, 2.0, 3.0]),
limit=1,
distance_measure=DistanceMeasure.EUCLIDEAN,
)
returned = await vector_query.get()
assert isinstance(returned, list)
assert len(returned) == 1
assert returned[0].to_dict() == {
"embedding": Vector([1.0, 2.0, 3.0]),
"color": "red",
}


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
async def test_vector_search_collection_group(client, database):
collection_id = "vector_search"
collection_group = client.collection_group(collection_id)

vector_query = collection_group.where("color", "==", "red").find_nearest(
vector_field="embedding",
query_vector=Vector([1.0, 2.0, 3.0]),
distance_measure=DistanceMeasure.EUCLIDEAN,
limit=1,
)
returned = await vector_query.get()
assert isinstance(returned, list)
assert len(returned) == 1
assert returned[0].to_dict() == {
"embedding": Vector([1.0, 2.0, 3.0]),
"color": "red",
}


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137867104")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
async def test_update_document(client, cleanup, database):
Expand Down
Loading

0 comments on commit 2de1620

Please sign in to comment.