From 1c9632e4e3e8a00ad8b6f5fef622a84172aec073 Mon Sep 17 00:00:00 2001 From: Jen Hamon Date: Fri, 18 Oct 2024 07:03:48 -0400 Subject: [PATCH] WIP on grpc asyncio and query_namespaces method This commit squashes several previous commits on this branch to address a GitGuardian security check that continues to fail because an early commit contained a leaked development key. --- .github/actions/create-index/action.yml | 9 + .../test-data-plane-asyncio/action.yaml | 55 ++ .github/workflows/alpha-release.yaml | 27 +- .github/workflows/testing-integration.yaml | 58 ++ .gitignore | 6 +- app.py | 31 - app2.py | 16 - app3.py | 62 -- pinecone/control/pinecone.py | 16 +- pinecone/exceptions/__init__.py | 3 + pinecone/grpc/__init__.py | 1 + pinecone/grpc/grpc_runner.py | 82 ++- pinecone/grpc/index_grpc.py | 60 +- pinecone/grpc/index_grpc_asyncio.py | 604 +++++++++++++++++- pinecone/grpc/pinecone.py | 8 +- pinecone/grpc/query_results.py | 14 + pinecone/grpc/query_results_aggregator.py | 190 ++++++ pinecone/grpc/sparse_vector.py | 6 + pinecone/grpc/utils.py | 27 +- poetry.lock | 117 +++- pyproject.toml | 8 +- scripts/create.py | 12 +- tests/integration/data_asyncio/__init__.py | 0 tests/integration/data_asyncio/conftest.py | 55 ++ tests/integration/data_asyncio/test_upsert.py | 97 +++ .../data_asyncio/test_upsert_errors.py | 234 +++++++ tests/integration/data_asyncio/utils.py | 31 + .../test_query_results_aggregator.py | 553 ++++++++++++++++ 28 files changed, 2141 insertions(+), 241 deletions(-) create mode 100644 .github/actions/test-data-plane-asyncio/action.yaml delete mode 100644 app.py delete mode 100644 app2.py delete mode 100644 app3.py create mode 100644 pinecone/grpc/query_results.py create mode 100644 pinecone/grpc/query_results_aggregator.py create mode 100644 pinecone/grpc/sparse_vector.py create mode 100644 tests/integration/data_asyncio/__init__.py create mode 100644 tests/integration/data_asyncio/conftest.py create mode 100644 tests/integration/data_asyncio/test_upsert.py create mode 100644 tests/integration/data_asyncio/test_upsert_errors.py create mode 100644 tests/integration/data_asyncio/utils.py create mode 100644 tests/unit_grpc/test_query_results_aggregator.py diff --git a/.github/actions/create-index/action.yml b/.github/actions/create-index/action.yml index b81dc1b9..f3140948 100644 --- a/.github/actions/create-index/action.yml +++ b/.github/actions/create-index/action.yml @@ -30,6 +30,15 @@ outputs: index_name: description: 'The name of the index, including randomized suffix' value: ${{ steps.create-index.outputs.index_name }} + index_host: + description: 'The host of the index' + value: ${{ steps.create-index.outputs.index_host }} + index_dimension: + description: 'The dimension of the index' + value: ${{ steps.create-index.outputs.index_dimension }} + index_metric: + description: 'The metric of the index' + value: ${{ steps.create-index.outputs.index_metric }} runs: using: 'composite' diff --git a/.github/actions/test-data-plane-asyncio/action.yaml b/.github/actions/test-data-plane-asyncio/action.yaml new file mode 100644 index 00000000..7761aa29 --- /dev/null +++ b/.github/actions/test-data-plane-asyncio/action.yaml @@ -0,0 +1,55 @@ +name: 'Test Data Plane' +description: 'Runs tests on the Pinecone data plane' + +inputs: + metric: + description: 'The metric of the index' + required: true + dimension: + description: 'The dimension of the index' + required: true + host: + description: 'The host of the index' + required: true + use_grpc: + description: 'Whether to use gRPC or REST' + required: true + freshness_timeout_seconds: + description: 'The number of seconds to wait for the index to become fresh' + required: false + default: '60' + PINECONE_API_KEY: + description: 'The Pinecone API key' + required: true + +outputs: + index_name: + description: 'The name of the index, including randomized suffix' + value: ${{ steps.create-index.outputs.index_name }} + +runs: + using: 'composite' + steps: + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python_version }} + + - name: Setup Poetry + uses: ./.github/actions/setup-poetry + with: + include_grpc: ${{ inputs.use_grpc }} + include_dev: 'true' + + - name: Run data plane tests + id: data-plane-tests + shell: bash + run: poetry run pytest tests/integration/data_asyncio + env: + PINECONE_API_KEY: ${{ inputs.PINECONE_API_KEY }} + USE_GRPC: ${{ inputs.use_grpc }} + METRIC: ${{ inputs.metric }} + INDEX_HOST: ${{ inputs.host }} + DIMENSION: ${{ inputs.dimension }} + SPEC: ${{ inputs.spec }} + FRESHNESS_TIMEOUT_SECONDS: ${{ inputs.freshness_timeout_seconds }} diff --git a/.github/workflows/alpha-release.yaml b/.github/workflows/alpha-release.yaml index 81497630..fe27a4fd 100644 --- a/.github/workflows/alpha-release.yaml +++ b/.github/workflows/alpha-release.yaml @@ -24,22 +24,22 @@ on: default: 'rc1' jobs: - unit-tests: - uses: './.github/workflows/testing-unit.yaml' - secrets: inherit - integration-tests: - uses: './.github/workflows/testing-integration.yaml' - secrets: inherit - dependency-tests: - uses: './.github/workflows/testing-dependency.yaml' - secrets: inherit + # unit-tests: + # uses: './.github/workflows/testing-unit.yaml' + # secrets: inherit + # integration-tests: + # uses: './.github/workflows/testing-integration.yaml' + # secrets: inherit + # dependency-tests: + # uses: './.github/workflows/testing-dependency.yaml' + # secrets: inherit pypi: uses: './.github/workflows/publish-to-pypi.yaml' - needs: - - unit-tests - - integration-tests - - dependency-tests + # needs: + # - unit-tests + # - integration-tests + # - dependency-tests with: isPrerelease: true ref: ${{ inputs.ref }} @@ -49,4 +49,3 @@ jobs: secrets: PYPI_USERNAME: __token__ PYPI_PASSWORD: ${{ secrets.PROD_PYPI_PUBLISH_TOKEN }} - diff --git a/.github/workflows/testing-integration.yaml b/.github/workflows/testing-integration.yaml index 38812e88..ef1ce7f9 100644 --- a/.github/workflows/testing-integration.yaml +++ b/.github/workflows/testing-integration.yaml @@ -32,6 +32,64 @@ jobs: PINECONE_DEBUG_CURL: 'true' PINECONE_API_KEY: '${{ secrets.PINECONE_API_KEY }}' + data-plane-setup: + name: Create index + runs-on: ubuntu-latest + outputs: + index_name: ${{ steps.setup-index.outputs.index_name }} + index_host: ${{ steps.setup-index.outputs.index_host }} + index_dimension: ${{ steps.setup-index.outputs.index_dimension }} + index_metric: ${{ steps.setup-index.outputs.index_metric }} + steps: + - uses: actions/checkout@v4 + - name: Create index + id: setup-index + uses: ./.github/actions/create-index + timeout-minutes: 5 + with: + dimension: 100 + metric: 'cosine' + PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY }} + + + test-data-plane-asyncio: + name: Data plane asyncio integration tests + runs-on: ubuntu-latest + needs: + - data-plane-setup + outputs: + index_name: ${{ needs.data-plane-setup.outputs.index_name }} + strategy: + fail-fast: false + matrix: + python_version: [3.8, 3.12] + use_grpc: [true] + spec: + - '{ "asyncio": { "environment": "us-east1-gcp" }}' + steps: + - uses: actions/checkout@v4 + - uses: ./.github/actions/test-data-plane-asyncio + with: + python_version: '${{ matrix.python_version }}' + use_grpc: '${{ matrix.use_grpc }}' + metric: '${{ needs.data-plane-setup.outputs.index_metric }}' + dimension: '${{ needs.data-plane-setup.outputs.index_dimension }}' + host: '${{ needs.data-plane-setup.outputs.index_host }}' + PINECONE_API_KEY: '${{ secrets.PINECONE_API_KEY }}' + freshness_timeout_seconds: 600 + + data-plane-asyncio-cleanup: + name: Deps cleanup + runs-on: ubuntu-latest + needs: + - test-data-plane-asyncio + steps: + - uses: actions/checkout@v4 + - uses: ./.github/actions/delete-index + with: + index_name: '${{ needs.test-data-plane-asyncio.outputs.index_name }}' + PINECONE_API_KEY: '${{ secrets.PINECONE_API_KEY }}' + data-plane-serverless: name: Data plane serverless integration tests runs-on: ubuntu-latest diff --git a/.gitignore b/.gitignore index 4200d51d..4aaafac2 100644 --- a/.gitignore +++ b/.gitignore @@ -137,7 +137,7 @@ venv.bak/ .ropeproject # pdocs documentation -# We want to exclude any locally generated artifacts, but we rely on +# We want to exclude any locally generated artifacts, but we rely on # keeping documentation assets in the docs/ folder. docs/* !docs/pinecone-python-client-fork.png @@ -155,4 +155,6 @@ dmypy.json *.hdf5 *~ -tests/integration/proxy_config/logs \ No newline at end of file +tests/integration/proxy_config/logs +*.parquet +app*.py diff --git a/app.py b/app.py deleted file mode 100644 index 562241c8..00000000 --- a/app.py +++ /dev/null @@ -1,31 +0,0 @@ -from pinecone.grpc import PineconeGRPC, GRPCClientConfig - -# Initialize a client. An API key must be passed, but the -# value does not matter. -pc = PineconeGRPC(api_key="test_api_key") - -# Target the indexes. Use the host and port number along with disabling tls. -index1 = pc.Index(host="localhost:5081", grpc_config=GRPCClientConfig(secure=False)) -index2 = pc.Index(host="localhost:5082", grpc_config=GRPCClientConfig(secure=False)) - -# You can now perform data plane operations with index1 and index2 - -dimension = 3 - - -def upserts(): - vectors = [] - for i in range(0, 100): - vectors.append((f"vec{i}", [i] * dimension)) - - print(len(vectors)) - - index1.upsert(vectors=vectors, namespace="ns2") - index2.upsert(vectors=vectors, namespace="ns2") - - -upserts() -print(index1.describe_index_stats()) - -print(index1.query(id="vec1", top_k=2, namespace="ns2", include_values=True)) -print(index1.query(id="vec1", top_k=10, namespace="", include_values=True)) diff --git a/app2.py b/app2.py deleted file mode 100644 index c8349e70..00000000 --- a/app2.py +++ /dev/null @@ -1,16 +0,0 @@ -from pinecone.grpc import PineconeGRPC -from pinecone import Pinecone - -pc = Pinecone(api_key="b1cb8ba4-b3d1-458f-9c32-8dd10813459a") -pcg = PineconeGRPC(api_key="b1cb8ba4-b3d1-458f-9c32-8dd10813459a") - -index = pc.Index("jen2") -indexg = pcg.Index(name="jen2", use_asyncio=True) - -# Rest call fails -# print(index.upsert(vectors=[("vec1", [1, 2])])) - -# GRPC succeeds -print(indexg.upsert(vectors=[("vec1", [1, 2])])) - -# print(index.fetch(ids=['vec1'])) diff --git a/app3.py b/app3.py deleted file mode 100644 index 5e49daff..00000000 --- a/app3.py +++ /dev/null @@ -1,62 +0,0 @@ -import asyncio -from pinecone.grpc import PineconeGRPC as Pinecone, Vector - -import time -import random -import pandas as pd - - -# Enable gRPC tracing and verbosity for more detailed logs -# os.environ["GRPC_VERBOSITY"] = "DEBUG" -# os.environ["GRPC_TRACE"] = "all" - - -# Generate a large set of vectors (as an example) -def generate_vectors(num_vectors, dimension): - return [ - Vector(id=f"vector_{i}", values=[random.random()] * dimension) for i in range(num_vectors) - ] - - -def load_vectors(): - df = pd.read_parquet("test_records_100k_dim1024.parquet") - df["values"] = df["values"].apply(lambda x: [float(v) for v in x]) - - vectors = [Vector(id=row.id, values=list(row.values)) for row in df.itertuples()] - return vectors - - -async def main(): - # Create a semaphore to limit concurrency (e.g., max 5 concurrent requests) - s = time.time() - # all_vectors = load_vectors() - all_vectors = generate_vectors(1000, 1024) - f = time.time() - print(f"Loaded {len(all_vectors)} vectors in {f-s:.2f} seconds") - - start_time = time.time() - - # Same setup as before... - pc = Pinecone(api_key="b1cb8ba4-b3d1-458f-9c32-8dd10813459a") - index = pc.Index( - # index_host="jen2-dojoi3u.svc.aped-4627-b74a.pinecone.io" - host="jen1024-dojoi3u.svc.apw5-4e34-81fa.pinecone.io", - use_asyncio=True, - ) - - batch_size = 150 - namespace = "asyncio-py7" - res = await index.upsert( - vectors=all_vectors, batch_size=batch_size, namespace=namespace, show_progress=True - ) - - print(res) - - end_time = time.time() - - total_time = end_time - start_time - print(f"All tasks completed in {total_time:.2f} seconds") - print(f"Namespace: {namespace}") - - -asyncio.run(main()) diff --git a/pinecone/control/pinecone.py b/pinecone/control/pinecone.py index cd49d87f..b2182ec9 100644 --- a/pinecone/control/pinecone.py +++ b/pinecone/control/pinecone.py @@ -1,6 +1,6 @@ import time import logging -from typing import Optional, Dict, Any, Union, List, Tuple, Literal +from typing import Optional, Dict, Any, Union, Literal from .index_host_store import IndexHostStore @@ -10,7 +10,12 @@ from pinecone.core.openapi.shared.api_client import ApiClient -from pinecone.utils import normalize_host, setup_openapi_client, build_plugin_setup_client +from pinecone.utils import ( + normalize_host, + setup_openapi_client, + build_plugin_setup_client, + parse_non_empty_args, +) from pinecone.core.openapi.control.models import ( CreateCollectionRequest, CreateIndexRequest, @@ -317,9 +322,6 @@ def create_index( api_instance = self.index_api - def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]: - return {arg_name: val for arg_name, val in args if val is not None} - if deletion_protection in ["enabled", "disabled"]: dp = DeletionProtection(deletion_protection) else: @@ -329,7 +331,7 @@ def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]: if "serverless" in spec: index_spec = IndexSpec(serverless=ServerlessSpecModel(**spec["serverless"])) elif "pod" in spec: - args_dict = _parse_non_empty_args( + args_dict = parse_non_empty_args( [ ("environment", spec["pod"].get("environment")), ("metadata_config", spec["pod"].get("metadata_config")), @@ -351,7 +353,7 @@ def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]: serverless=ServerlessSpecModel(cloud=spec.cloud, region=spec.region) ) elif isinstance(spec, PodSpec): - args_dict = _parse_non_empty_args( + args_dict = parse_non_empty_args( [ ("replicas", spec.replicas), ("shards", spec.shards), diff --git a/pinecone/exceptions/__init__.py b/pinecone/exceptions/__init__.py index eb0e10fa..5964af37 100644 --- a/pinecone/exceptions/__init__.py +++ b/pinecone/exceptions/__init__.py @@ -12,6 +12,8 @@ ) from .exceptions import PineconeConfigurationError, PineconeProtocolError, ListConversionException +PineconeNotFoundException = NotFoundException + __all__ = [ "PineconeConfigurationError", "PineconeProtocolError", @@ -22,6 +24,7 @@ "PineconeApiKeyError", "PineconeApiException", "NotFoundException", + "PineconeNotFoundException", "UnauthorizedException", "ForbiddenException", "ServiceException", diff --git a/pinecone/grpc/__init__.py b/pinecone/grpc/__init__.py index a027e897..df05cbfe 100644 --- a/pinecone/grpc/__init__.py +++ b/pinecone/grpc/__init__.py @@ -45,6 +45,7 @@ """ from .index_grpc import GRPCIndex +from .index_grpc_asyncio import GRPCIndexAsyncio from .pinecone import PineconeGRPC from .config import GRPCClientConfig from .future import PineconeGrpcFuture diff --git a/pinecone/grpc/grpc_runner.py b/pinecone/grpc/grpc_runner.py index 253a6b33..f70ec36a 100644 --- a/pinecone/grpc/grpc_runner.py +++ b/pinecone/grpc/grpc_runner.py @@ -1,3 +1,4 @@ +import asyncio from functools import wraps from typing import Dict, Tuple, Optional @@ -7,10 +8,19 @@ from .utils import _generate_request_id from .config import GRPCClientConfig from pinecone.utils.constants import REQUEST_ID, CLIENT_VERSION -from pinecone.exceptions.exceptions import PineconeException -from grpc import CallCredentials, Compression +from grpc import CallCredentials, Compression, StatusCode +from grpc.aio import AioRpcError from google.protobuf.message import Message +from pinecone.exceptions import ( + PineconeException, + PineconeApiValueError, + PineconeApiException, + UnauthorizedException, + PineconeNotFoundException, + ServiceException, +) + class GrpcRunner: def __init__(self, index_name: str, config: Config, grpc_config: GRPCClientConfig): @@ -49,7 +59,7 @@ def wrapped(): compression=compression, ) except _InactiveRpcError as e: - raise PineconeException(e._state.debug_error_string) from e + self._map_exception(e, e._state.code, e._state.details) return wrapped() @@ -62,22 +72,34 @@ async def run_asyncio( credentials: Optional[CallCredentials] = None, wait_for_ready: Optional[bool] = None, compression: Optional[Compression] = None, + semaphore: Optional[asyncio.Semaphore] = None, ): @wraps(func) async def wrapped(): user_provided_metadata = metadata or {} _metadata = self._prepare_metadata(user_provided_metadata) try: - return await func( - request, - timeout=timeout, - metadata=_metadata, - credentials=credentials, - wait_for_ready=wait_for_ready, - compression=compression, - ) - except _InactiveRpcError as e: - raise PineconeException(e._state.debug_error_string) from e + if semaphore is not None: + async with semaphore: + return await func( + request, + timeout=timeout, + metadata=_metadata, + credentials=credentials, + wait_for_ready=wait_for_ready, + compression=compression, + ) + else: + return await func( + request, + timeout=timeout, + metadata=_metadata, + credentials=credentials, + wait_for_ready=wait_for_ready, + compression=compression, + ) + except AioRpcError as e: + self._map_exception(e, e.code(), e.details()) return await wrapped() @@ -95,3 +117,37 @@ def _prepare_metadata( def _request_metadata(self) -> Dict[str, str]: return {REQUEST_ID: _generate_request_id()} + + def _map_exception(self, e: Exception, code: Optional[StatusCode], details: Optional[str]): + # Client / connection issues + details = details or "" + + if code in [StatusCode.DEADLINE_EXCEEDED]: + raise TimeoutError(details) from e + + # Permissions stuff + if code in [StatusCode.PERMISSION_DENIED, StatusCode.UNAUTHENTICATED]: + raise UnauthorizedException(status=code, reason=details) from e + + # 400ish stuff + if code in [StatusCode.NOT_FOUND]: + raise PineconeNotFoundException(status=code, reason=details) from e + if code in [StatusCode.INVALID_ARGUMENT, StatusCode.OUT_OF_RANGE]: + raise PineconeApiValueError(details) from e + if code in [ + StatusCode.ALREADY_EXISTS, + StatusCode.FAILED_PRECONDITION, + StatusCode.UNIMPLEMENTED, + StatusCode.RESOURCE_EXHAUSTED, + ]: + raise PineconeApiException(status=code, reason=details) from e + + # 500ish stuff + if code in [StatusCode.INTERNAL, StatusCode.UNAVAILABLE]: + raise ServiceException(status=code, reason=details) from e + if code in [StatusCode.UNKNOWN, StatusCode.DATA_LOSS, StatusCode.ABORTED]: + # abandon hope, all ye who enter here + raise PineconeException(code, details) from e + + # If you get here, you're in a bad place + raise PineconeException(code, details) from e diff --git a/pinecone/grpc/index_grpc.py b/pinecone/grpc/index_grpc.py index 95d5846a..3d3eeca8 100644 --- a/pinecone/grpc/index_grpc.py +++ b/pinecone/grpc/index_grpc.py @@ -7,13 +7,19 @@ from concurrent.futures import as_completed, Future +from pinecone.utils import parse_non_empty_args from .utils import ( dict_to_proto_struct, parse_fetch_response, parse_query_response, parse_stats_response, + parse_sparse_values_arg, ) from .vector_factory_grpc import VectorFactoryGRPC +from .base import GRPCIndexBase +from .future import PineconeGrpcFuture +from .sparse_vector import SparseVectorTypedDict +from .config import GRPCClientConfig from pinecone.core.openapi.data.models import ( FetchResponse, @@ -39,10 +45,7 @@ from pinecone import Vector as NonGRPCVector from pinecone.data.query_results_aggregator import QueryNamespacesResults, QueryResultsAggregator from pinecone.core.grpc.protos.vector_service_pb2_grpc import VectorServiceStub -from .base import GRPCIndexBase -from .future import PineconeGrpcFuture -from .config import GRPCClientConfig from pinecone.config import Config from grpc._channel import Channel @@ -52,11 +55,6 @@ _logger = logging.getLogger(__name__) -class SparseVectorTypedDict(TypedDict): - indices: List[int] - values: List[float] - - class GRPCIndex(GRPCIndexBase): """A client for interacting with a Pinecone index via GRPC API.""" @@ -155,7 +153,7 @@ def upsert( vectors = list(map(VectorFactoryGRPC.build, vectors)) if async_req: - args_dict = self._parse_non_empty_args([("namespace", namespace)]) + args_dict = parse_non_empty_args([("namespace", namespace)]) request = UpsertRequest(vectors=vectors, **args_dict, **kwargs) future = self.runner.run(self.stub.Upsert.future, request, timeout=timeout) return PineconeGrpcFuture(future) @@ -181,7 +179,7 @@ def upsert( def _upsert_batch( self, vectors: List[GRPCVector], namespace: Optional[str], timeout: Optional[int], **kwargs ) -> UpsertResponse: - args_dict = self._parse_non_empty_args([("namespace", namespace)]) + args_dict = parse_non_empty_args([("namespace", namespace)]) request = UpsertRequest(vectors=vectors, **args_dict) return self.runner.run(self.stub.Upsert, request, timeout=timeout, **kwargs) @@ -288,7 +286,7 @@ def delete( else: filter_struct = None - args_dict = self._parse_non_empty_args( + args_dict = parse_non_empty_args( [ ("ids", ids), ("delete_all", delete_all), @@ -329,7 +327,7 @@ def fetch( """ timeout = kwargs.pop("timeout", None) - args_dict = self._parse_non_empty_args([("namespace", namespace)]) + args_dict = parse_non_empty_args([("namespace", namespace)]) request = FetchRequest(ids=ids, **args_dict, **kwargs) @@ -400,8 +398,8 @@ def query( else: filter_struct = None - sparse_vector = self._parse_sparse_values_arg(sparse_vector) - args_dict = self._parse_non_empty_args( + sparse_vector = parse_sparse_values_arg(sparse_vector) + args_dict = parse_non_empty_args( [ ("vector", vector), ("id", id), @@ -516,8 +514,8 @@ def update( set_metadata_struct = None timeout = kwargs.pop("timeout", None) - sparse_values = self._parse_sparse_values_arg(sparse_values) - args_dict = self._parse_non_empty_args( + sparse_values = parse_sparse_values_arg(sparse_values) + args_dict = parse_non_empty_args( [ ("values", values), ("set_metadata", set_metadata_struct), @@ -566,7 +564,7 @@ def list_paginated( Returns: SimpleListResponse object which contains the list of ids, the namespace name, pagination information, and usage showing the number of read_units consumed. """ - args_dict = self._parse_non_empty_args( + args_dict = parse_non_empty_args( [ ("prefix", prefix), ("limit", limit), @@ -645,36 +643,10 @@ def describe_index_stats( filter_struct = dict_to_proto_struct(filter) else: filter_struct = None - args_dict = self._parse_non_empty_args([("filter", filter_struct)]) + args_dict = parse_non_empty_args([("filter", filter_struct)]) timeout = kwargs.pop("timeout", None) request = DescribeIndexStatsRequest(**args_dict) response = self.runner.run(self.stub.DescribeIndexStats, request, timeout=timeout) json_response = json_format.MessageToDict(response) return parse_stats_response(json_response) - - @staticmethod - def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]: - return {arg_name: val for arg_name, val in args if val is not None} - - @staticmethod - def _parse_sparse_values_arg( - sparse_values: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]], - ) -> Optional[GRPCSparseValues]: - if sparse_values is None: - return None - - if isinstance(sparse_values, GRPCSparseValues): - return sparse_values - - if ( - not isinstance(sparse_values, dict) - or "indices" not in sparse_values - or "values" not in sparse_values - ): - raise ValueError( - "Invalid sparse values argument. Expected a dict of: {'indices': List[int], 'values': List[float]}." - f"Received: {sparse_values}" - ) - - return GRPCSparseValues(indices=sparse_values["indices"], values=sparse_values["values"]) diff --git a/pinecone/grpc/index_grpc_asyncio.py b/pinecone/grpc/index_grpc_asyncio.py index 2d141803..2dbdda77 100644 --- a/pinecone/grpc/index_grpc_asyncio.py +++ b/pinecone/grpc/index_grpc_asyncio.py @@ -1,27 +1,48 @@ -from typing import Optional, Union, List, Awaitable +from typing import Optional, Union, List, Dict, Awaitable, Any from tqdm.asyncio import tqdm -from asyncio import Semaphore - -from .vector_factory_grpc import VectorFactoryGRPC +import asyncio +from google.protobuf import json_format +from pinecone.core.openapi.data.models import ( + FetchResponse, + QueryResponse, + DescribeIndexStatsResponse, +) +from pinecone.models.list_response import ListResponse as SimpleListResponse from pinecone.core.grpc.protos.vector_service_pb2 import ( Vector as GRPCVector, - QueryVector as GRPCQueryVector, UpsertRequest, UpsertResponse, + DeleteRequest, + QueryRequest, + FetchRequest, + UpdateRequest, + DescribeIndexStatsRequest, + DeleteResponse, + UpdateResponse, SparseValues as GRPCSparseValues, ) -from .base import GRPCIndexBase + from pinecone import Vector as NonGRPCVector from pinecone.core.grpc.protos.vector_service_pb2_grpc import VectorServiceStub from pinecone.utils import parse_non_empty_args -from .config import GRPCClientConfig from pinecone.config import Config from grpc._channel import Channel -__all__ = ["GRPCIndexAsyncio", "GRPCVector", "GRPCQueryVector", "GRPCSparseValues"] +from .base import GRPCIndexBase +from .config import GRPCClientConfig +from .sparse_vector import SparseVectorTypedDict +from .utils import ( + dict_to_proto_struct, + parse_fetch_response, + parse_query_response, + parse_stats_response, + parse_sparse_values_arg, +) +from .vector_factory_grpc import VectorFactoryGRPC +from .query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults class GRPCIndexAsyncio(GRPCIndexBase): @@ -48,53 +69,570 @@ def __init__( def stub_class(self): return VectorServiceStub + def _get_semaphore( + self, + max_concurrent_requests: Optional[int] = None, + semaphore: Optional[asyncio.Semaphore] = None, + ) -> asyncio.Semaphore: + if semaphore is not None and max_concurrent_requests is not None: + raise ValueError("Cannot specify both `max_concurrent_requests` and `semaphore`") + if semaphore is not None: + return semaphore + if max_concurrent_requests is None: + return asyncio.Semaphore(25) + return asyncio.Semaphore(max_concurrent_requests) + async def upsert( self, vectors: Union[List[GRPCVector], List[NonGRPCVector], List[tuple], List[dict]], namespace: Optional[str] = None, batch_size: Optional[int] = None, show_progress: bool = True, + max_concurrent_requests: Optional[int] = None, + semaphore: Optional[asyncio.Semaphore] = None, **kwargs, - ) -> Awaitable[UpsertResponse]: + ) -> UpsertResponse: timeout = kwargs.pop("timeout", None) vectors = list(map(VectorFactoryGRPC.build, vectors)) + semaphore = self._get_semaphore(max_concurrent_requests, semaphore) if batch_size is None: - return await self._upsert_batch(vectors, namespace, timeout=timeout, **kwargs) + return await self._upsert_batch( + vectors=vectors, namespace=namespace, timeout=timeout, semaphore=semaphore, **kwargs + ) - else: - if not isinstance(batch_size, int) or batch_size <= 0: - raise ValueError("batch_size must be a positive integer") + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError("batch_size must be a positive integer") - semaphore = Semaphore(25) - vector_batches = [ - vectors[i : i + batch_size] for i in range(0, len(vectors), batch_size) - ] - tasks = [ - self._upsert_batch( - vectors=batch, namespace=namespace, timeout=100, semaphore=semaphore - ) - for batch in vector_batches - ] + vector_batches = [vectors[i : i + batch_size] for i in range(0, len(vectors), batch_size)] + tasks = [ + self._upsert_batch( + vectors=batch, semaphore=semaphore, namespace=namespace, timeout=100, **kwargs + ) + for batch in vector_batches + ] + + if namespace is not None: + pbar_desc = f"Upserted vectors in namespace '{namespace}'" + else: + pbar_desc = "Upserted vectors in namespace ''" - return await tqdm.gather(*tasks, disable=not show_progress, desc="Upserted batches") + upserted_count = 0 + with tqdm(total=len(vectors), disable=not show_progress, desc=pbar_desc) as pbar: + for task in asyncio.as_completed(tasks): + res = await task + pbar.update(res.upserted_count) + upserted_count += res.upserted_count + return UpsertResponse(upserted_count=upserted_count) async def _upsert_batch( self, vectors: List[GRPCVector], + semaphore: asyncio.Semaphore, namespace: Optional[str], timeout: Optional[int] = None, - semaphore: Optional[Semaphore] = None, **kwargs, - ) -> Awaitable[UpsertResponse]: + ) -> UpsertResponse: args_dict = parse_non_empty_args([("namespace", namespace)]) request = UpsertRequest(vectors=vectors, **args_dict) - if semaphore is not None: - async with semaphore: - return await self.runner.run_asyncio( - self.stub.Upsert, request, timeout=timeout, **kwargs - ) + return await self.runner.run_asyncio( + self.stub.Upsert, request, timeout=timeout, semaphore=semaphore, **kwargs + ) + + async def _query( + self, + vector: Optional[List[float]] = None, + id: Optional[str] = None, + namespace: Optional[str] = None, + top_k: Optional[int] = None, + filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, + include_values: Optional[bool] = None, + include_metadata: Optional[bool] = None, + sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None, + semaphore: Optional[asyncio.Semaphore] = None, + **kwargs, + ) -> Dict[str, Any]: + if vector is not None and id is not None: + raise ValueError("Cannot specify both `id` and `vector`") + + if filter is not None: + filter_struct = dict_to_proto_struct(filter) else: - return await self.runner.run_asyncio( - self.stub.Upsert, request, timeout=timeout, **kwargs + filter_struct = None + + sparse_vector = parse_sparse_values_arg(sparse_vector) + args_dict = parse_non_empty_args( + [ + ("vector", vector), + ("id", id), + ("namespace", namespace), + ("top_k", top_k), + ("filter", filter_struct), + ("include_values", include_values), + ("include_metadata", include_metadata), + ("sparse_vector", sparse_vector), + ] + ) + + request = QueryRequest(**args_dict) + + timeout = kwargs.pop("timeout", None) + semaphore = self._get_semaphore(None, semaphore) + + response = await self.runner.run_asyncio( + self.stub.Query, request, timeout=timeout, semaphore=semaphore + ) + parsed = json_format.MessageToDict(response) + return parsed + + async def query( + self, + vector: Optional[List[float]] = None, + id: Optional[str] = None, + namespace: Optional[str] = None, + top_k: Optional[int] = 10, + filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, + include_values: Optional[bool] = None, + include_metadata: Optional[bool] = None, + sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None, + semaphore: Optional[asyncio.Semaphore] = None, + **kwargs, + ) -> QueryResponse: + """ + The Query operation searches a namespace, using a query vector. + It retrieves the ids of the most similar items in a namespace, along with their similarity scores. + + Examples: + >>> await index.query(vector=[1, 2, 3], top_k=10, namespace='my_namespace') + >>> await index.query(id='id1', top_k=10, namespace='my_namespace') + >>> await index.query(vector=[1, 2, 3], top_k=10, namespace='my_namespace', filter={'key': 'value'}) + >>> await index.query(id='id1', top_k=10, namespace='my_namespace', include_metadata=True, include_values=True) + >>> await index.query(vector=[1, 2, 3], sparse_vector={'indices': [1, 2], 'values': [0.2, 0.4]}, + >>> top_k=10, namespace='my_namespace') + >>> await index.query(vector=[1, 2, 3], sparse_vector=GRPCSparseValues([1, 2], [0.2, 0.4]), + >>> top_k=10, namespace='my_namespace') + + Args: + vector (List[float]): The query vector. This should be the same length as the dimension of the index + being queried. Each `query()` request can contain only one of the parameters + `id` or `vector`.. [optional] + id (str): The unique ID of the vector to be used as a query vector. + Each `query()` request can contain only one of the parameters + `vector` or `id`.. [optional] + top_k (int): The number of results to return for each query. Must be an integer greater than 1. + namespace (str): The namespace to fetch vectors from. + If not specified, the default namespace is used. [optional] + filter (Dict[str, Union[str, float, int, bool, List, dict]]): + The filter to apply. You can use vector metadata to limit your search. + See https://www.pinecone.io/docs/metadata-filtering/.. [optional] + include_values (bool): Indicates whether vector values are included in the response. + If omitted the server will use the default value of False [optional] + include_metadata (bool): Indicates whether metadata is included in the response as well as the ids. + If omitted the server will use the default value of False [optional] + sparse_vector: (Union[SparseValues, Dict[str, Union[List[float], List[int]]]]): sparse values of the query vector. + Expected to be either a GRPCSparseValues object or a dict of the form: + {'indices': List[int], 'values': List[float]}, where the lists each have the same length. + + Returns: QueryResponse object which contains the list of the closest vectors as ScoredVector objects, + and namespace name. + """ + # We put everything but the response parsing into the private _query method so + # that we can reuse it when querying over multiple namespaces. Since we need to do + # some work to aggregate and present the results from multiple namespaces in that + # case, we don't want to create a bunch of intermediate openapi QueryResponse + # objects that will just be thrown out in favor of a different presentation of those + # aggregate results. + json_response = await self._query( + vector=vector, + id=id, + namespace=namespace, + top_k=top_k, + filter=filter, + include_values=include_values, + include_metadata=include_metadata, + sparse_vector=sparse_vector, + semaphore=semaphore, + **kwargs, + ) + return parse_query_response(json_response, _check_type=False) + + async def query_namespaces( + self, + vector: List[float], + namespaces: List[str], + top_k: Optional[int] = None, + filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, + include_values: Optional[bool] = None, + include_metadata: Optional[bool] = None, + sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None, + show_progress: Optional[bool] = True, + max_concurrent_requests: Optional[int] = None, + semaphore: Optional[asyncio.Semaphore] = None, + **kwargs, + ) -> QueryNamespacesResults: + aggregator_lock = asyncio.Lock() + semaphore = self._get_semaphore(max_concurrent_requests, semaphore) + + if len(namespaces) == 0: + raise ValueError("At least one namespace must be specified") + if len(vector) == 0: + raise ValueError("Query vector must not be empty") + + # The caller may only want the top_k=1 result across all queries, + # but we need to get at least 2 results from each query in order to + # aggregate them correctly. So we'll temporarily set topK to 2 for the + # subqueries, and then we'll take the topK=1 results from the aggregated + # results. + overall_topk = top_k if top_k is not None else 10 + aggregator = QueryResultsAggregator(top_k=overall_topk) + subquery_topk = overall_topk if overall_topk > 2 else 2 + + target_namespaces = set(namespaces) # dedup namespaces + query_tasks = [ + self._query( + vector=vector, + namespace=ns, + top_k=subquery_topk, + filter=filter, + include_values=include_values, + include_metadata=include_metadata, + sparse_vector=sparse_vector, + semaphore=semaphore, + **kwargs, ) + for ns in target_namespaces + ] + + with tqdm( + total=len(query_tasks), disable=not show_progress, desc="Querying namespaces" + ) as pbar: + for query_task in asyncio.as_completed(query_tasks): + response = await query_task + pbar.update(1) + async with aggregator_lock: + aggregator.add_results(response) + + final_results = aggregator.get_results() + return final_results + + async def upsert_from_dataframe( + self, + df, + namespace: str = "", + batch_size: int = 500, + use_async_requests: bool = True, + show_progress: bool = True, + ) -> Awaitable[UpsertResponse]: + """Upserts a dataframe into the index. + + Args: + df: A pandas dataframe with the following columns: id, values, sparse_values, and metadata. + namespace: The namespace to upsert into. + batch_size: The number of rows to upsert in a single batch. + use_async_requests: Whether to upsert multiple requests at the same time using asynchronous request mechanism. + Set to `False` + show_progress: Whether to show a progress bar. + """ + # try: + # import pandas as pd + # except ImportError: + # raise RuntimeError( + # "The `pandas` package is not installed. Please install pandas to use `upsert_from_dataframe()`" + # ) + + # if not isinstance(df, pd.DataFrame): + # raise ValueError(f"Only pandas dataframes are supported. Found: {type(df)}") + + # pbar = tqdm(total=len(df), disable=not show_progress, desc="sending upsert requests") + # results = [] + # for chunk in self._iter_dataframe(df, batch_size=batch_size): + # res = self.upsert(vectors=chunk, namespace=namespace, async_req=use_async_requests) + # pbar.update(len(chunk)) + # results.append(res) + + # if use_async_requests: + # cast_results = cast(List[PineconeGrpcFuture], results) + # results = [ + # async_result.result() + # for async_result in tqdm( + # cast_results, disable=not show_progress, desc="collecting async responses" + # ) + # ] + + # upserted_count = 0 + # for res in results: + # if hasattr(res, "upserted_count") and isinstance(res.upserted_count, int): + # upserted_count += res.upserted_count + + # return UpsertResponse(upserted_count=upserted_count) + raise NotImplementedError( + "upsert_from_dataframe is not yet implemented for GRPCIndexAsyncio" + ) + + # @staticmethod + # def _iter_dataframe(df, batch_size): + # for i in range(0, len(df), batch_size): + # batch = df.iloc[i : i + batch_size].to_dict(orient="records") + # yield batch + + async def delete( + self, + ids: Optional[List[str]] = None, + delete_all: Optional[bool] = None, + namespace: Optional[str] = None, + filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, + **kwargs, + ) -> Awaitable[DeleteResponse]: + """ + The Delete operation deletes vectors from the index, from a single namespace. + No error raised if the vector id does not exist. + Note: for any delete call, if namespace is not specified, the default namespace is used. + + Delete can occur in the following mutual exclusive ways: + 1. Delete by ids from a single namespace + 2. Delete all vectors from a single namespace by setting delete_all to True + 3. Delete all vectors from a single namespace by specifying a metadata filter + (note that for this option delete all must be set to False) + + Examples: + >>> await index.delete(ids=['id1', 'id2'], namespace='my_namespace') + >>> await index.delete(delete_all=True, namespace='my_namespace') + >>> await index.delete(filter={'key': 'value'}, namespace='my_namespace', async_req=True) + + Args: + ids (List[str]): Vector ids to delete [optional] + delete_all (bool): This indicates that all vectors in the index namespace should be deleted.. [optional] + Default is False. + namespace (str): The namespace to delete vectors from [optional] + If not specified, the default namespace is used. + filter (Dict[str, Union[str, float, int, bool, List, dict]]): + If specified, the metadata filter here will be used to select the vectors to delete. + This is mutually exclusive with specifying ids to delete in the ids param or using delete_all=True. + See https://www.pinecone.io/docs/metadata-filtering/.. [optional] + + Returns: DeleteResponse (contains no data) or a PineconeGrpcFuture object if async_req is True. + """ + if filter is not None: + filter_struct = dict_to_proto_struct(filter) + else: + filter_struct = None + + args_dict = parse_non_empty_args( + [ + ("ids", ids), + ("delete_all", delete_all), + ("namespace", namespace), + ("filter", filter_struct), + ] + ) + timeout = kwargs.pop("timeout", None) + + request = DeleteRequest(**args_dict, **kwargs) + return await self.runner.run_asyncio(self.stub.Delete, request, timeout=timeout) + + async def fetch( + self, ids: Optional[List[str]], namespace: Optional[str] = None, **kwargs + ) -> Awaitable[FetchResponse]: + """ + The fetch operation looks up and returns vectors, by ID, from a single namespace. + The returned vectors include the vector data and/or metadata. + + Examples: + >>> await index.fetch(ids=['id1', 'id2'], namespace='my_namespace') + >>> await index.fetch(ids=['id1', 'id2']) + + Args: + ids (List[str]): The vector IDs to fetch. + namespace (str): The namespace to fetch vectors from. + If not specified, the default namespace is used. [optional] + + Returns: FetchResponse object which contains the list of Vector objects, and namespace name. + """ + timeout = kwargs.pop("timeout", None) + + args_dict = parse_non_empty_args([("namespace", namespace)]) + + request = FetchRequest(ids=ids, **args_dict, **kwargs) + response = await self.runner.run_asyncio(self.stub.Fetch, request, timeout=timeout) + json_response = json_format.MessageToDict(response) + return parse_fetch_response(json_response) + + async def update( + self, + id: str, + values: Optional[List[float]] = None, + set_metadata: Optional[ + Dict[str, Union[str, float, int, bool, List[int], List[float], List[str]]] + ] = None, + namespace: Optional[str] = None, + sparse_values: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None, + **kwargs, + ) -> Awaitable[UpdateResponse]: + """ + The Update operation updates vector in a namespace. + If a value is included, it will overwrite the previous value. + If a set_metadata is included, the values of the fields specified in it will be added or overwrite the previous value. + + Examples: + >>> await index.update(id='id1', values=[1, 2, 3], namespace='my_namespace') + >>> await index.update(id='id1', set_metadata={'key': 'value'}, namespace='my_namespace', async_req=True) + >>> await index.update(id='id1', values=[1, 2, 3], sparse_values={'indices': [1, 2], 'values': [0.2, 0.4]}, + >>> namespace='my_namespace') + >>> await index.update(id='id1', values=[1, 2, 3], sparse_values=GRPCSparseValues(indices=[1, 2], values=[0.2, 0.4]), + >>> namespace='my_namespace') + + Args: + id (str): Vector's unique id. + values (List[float]): vector values to set. [optional] + set_metadata (Dict[str, Union[str, float, int, bool, List[int], List[float], List[str]]]]): + metadata to set for vector. [optional] + namespace (str): Namespace name where to update the vector.. [optional] + sparse_values: (Dict[str, Union[List[float], List[int]]]): sparse values to update for the vector. + Expected to be either a GRPCSparseValues object or a dict of the form: + {'indices': List[int], 'values': List[float]} where the lists each have the same length. + + + Returns: UpdateResponse (contains no data) or a PineconeGrpcFuture object if async_req is True. + """ + if set_metadata is not None: + set_metadata_struct = dict_to_proto_struct(set_metadata) + else: + set_metadata_struct = None + + timeout = kwargs.pop("timeout", None) + sparse_values = parse_sparse_values_arg(sparse_values) + args_dict = parse_non_empty_args( + [ + ("values", values), + ("set_metadata", set_metadata_struct), + ("namespace", namespace), + ("sparse_values", sparse_values), + ] + ) + + request = UpdateRequest(id=id, **args_dict) + return await self.runner.run_asyncio(self.stub.Update, request, timeout=timeout) + + async def list_paginated( + self, + prefix: Optional[str] = None, + limit: Optional[int] = None, + pagination_token: Optional[str] = None, + namespace: Optional[str] = None, + **kwargs, + ) -> Awaitable[SimpleListResponse]: + """ + The list_paginated operation finds vectors based on an id prefix within a single namespace. + It returns matching ids in a paginated form, with a pagination token to fetch the next page of results. + This id list can then be passed to fetch or delete operations, depending on your use case. + + Consider using the `list` method to avoid having to handle pagination tokens manually. + + Examples: + >>> results = index.list_paginated(prefix='99', limit=5, namespace='my_namespace') + >>> [v.id for v in results.vectors] + ['99', '990', '991', '992', '993'] + >>> results.pagination.next + eyJza2lwX3Bhc3QiOiI5OTMiLCJwcmVmaXgiOiI5OSJ9 + >>> next_results = index.list_paginated(prefix='99', limit=5, namespace='my_namespace', pagination_token=results.pagination.next) + + Args: + prefix (Optional[str]): The id prefix to match. If unspecified, an empty string prefix will + be used with the effect of listing all ids in a namespace [optional] + limit (Optional[int]): The maximum number of ids to return. If unspecified, the server will use a default value. [optional] + pagination_token (Optional[str]): A token needed to fetch the next page of results. This token is returned + in the response if additional results are available. [optional] + namespace (Optional[str]): The namespace to fetch vectors from. If not specified, the default namespace is used. [optional] + + Returns: SimpleListResponse object which contains the list of ids, the namespace name, pagination information, and usage showing the number of read_units consumed. + """ + # args_dict = parse_non_empty_args( + # [ + # ("prefix", prefix), + # ("limit", limit), + # ("namespace", namespace), + # ("pagination_token", pagination_token), + # ] + # ) + # request = ListRequest(**args_dict, **kwargs) + # timeout = kwargs.pop("timeout", None) + # response = self.runner.run(self.stub.List, request, timeout=timeout) + + # if response.pagination and response.pagination.next != "": + # pagination = Pagination(next=response.pagination.next) + # else: + # pagination = None + + # return SimpleListResponse( + # namespace=response.namespace, vectors=response.vectors, pagination=pagination + # ) + raise NotImplementedError("list_paginated is not yet implemented for GRPCIndexAsyncio") + + async def list(self, **kwargs): + """ + The list operation accepts all of the same arguments as list_paginated, and returns a generator that yields + a list of the matching vector ids in each page of results. It automatically handles pagination tokens on your + behalf. + + Examples: + >>> for ids in index.list(prefix='99', limit=5, namespace='my_namespace'): + >>> print(ids) + ['99', '990', '991', '992', '993'] + ['994', '995', '996', '997', '998'] + ['999'] + + Args: + prefix (Optional[str]): The id prefix to match. If unspecified, an empty string prefix will + be used with the effect of listing all ids in a namespace [optional] + limit (Optional[int]): The maximum number of ids to return. If unspecified, the server will use a default value. [optional] + pagination_token (Optional[str]): A token needed to fetch the next page of results. This token is returned + in the response if additional results are available. [optional] + namespace (Optional[str]): The namespace to fetch vectors from. If not specified, the default namespace is used. [optional] + """ + # done = False + # while not done: + # try: + # results = self.list_paginated(**kwargs) + # except Exception as e: + # raise e + + # if len(results.vectors) > 0: + # yield [v.id for v in results.vectors] + + # if results.pagination and results.pagination.next: + # kwargs.update({"pagination_token": results.pagination.next}) + # else: + # done = True + raise NotImplementedError("list is not yet implemented for GRPCIndexAsyncio") + + async def describe_index_stats( + self, filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, **kwargs + ) -> Awaitable[DescribeIndexStatsResponse]: + """ + The DescribeIndexStats operation returns statistics about the index's contents. + For example: The vector count per namespace and the number of dimensions. + + Examples: + >>> await index.describe_index_stats() + >>> await index.describe_index_stats(filter={'key': 'value'}) + + Args: + filter (Dict[str, Union[str, float, int, bool, List, dict]]): + If this parameter is present, the operation only returns statistics for vectors that satisfy the filter. + See https://www.pinecone.io/docs/metadata-filtering/.. [optional] + + Returns: DescribeIndexStatsResponse object which contains stats about the index. + """ + if filter is not None: + filter_struct = dict_to_proto_struct(filter) + else: + filter_struct = None + args_dict = parse_non_empty_args([("filter", filter_struct)]) + timeout = kwargs.pop("timeout", None) + + request = DescribeIndexStatsRequest(**args_dict) + response = await self.runner.run_asyncio( + self.stub.DescribeIndexStats, request, timeout=timeout + ) + json_response = json_format.MessageToDict(response) + return parse_stats_response(json_response) diff --git a/pinecone/grpc/pinecone.py b/pinecone/grpc/pinecone.py index 224a9167..0be0fd23 100644 --- a/pinecone/grpc/pinecone.py +++ b/pinecone/grpc/pinecone.py @@ -48,7 +48,7 @@ class PineconeGRPC(Pinecone): """ - def Index(self, name: str = "", host: str = "", use_asyncio=False, **kwargs): + def Index(self, name: str = "", host: str = "", **kwargs): """ Target an index for data operations. @@ -119,6 +119,12 @@ def Index(self, name: str = "", host: str = "", use_asyncio=False, **kwargs): index.query(vector=[...], top_k=10) ``` """ + return self._init_index(name=name, host=host, use_asyncio=False, **kwargs) + + def AsyncioIndex(self, name: str = "", host: str = "", **kwargs): + return self._init_index(name=name, host=host, use_asyncio=True, **kwargs) + + def _init_index(self, name: str, host: str, use_asyncio=False, **kwargs): if name == "" and host == "": raise ValueError("Either name or host must be specified") diff --git a/pinecone/grpc/query_results.py b/pinecone/grpc/query_results.py new file mode 100644 index 00000000..b2201b50 --- /dev/null +++ b/pinecone/grpc/query_results.py @@ -0,0 +1,14 @@ +from typing import TypedDict, List, Dict, Any + + +class ScoredVectorTypedDict(TypedDict): + id: str + score: float + values: List[float] + metadata: dict + + +class QueryResultsTypedDict(TypedDict): + matches: List[ScoredVectorTypedDict] + namespace: str + usage: Dict[str, Any] diff --git a/pinecone/grpc/query_results_aggregator.py b/pinecone/grpc/query_results_aggregator.py new file mode 100644 index 00000000..e1608c1f --- /dev/null +++ b/pinecone/grpc/query_results_aggregator.py @@ -0,0 +1,190 @@ +from typing import List, Tuple, Optional, Any, Dict +import json +import heapq +from pinecone.core.openapi.data.models import Usage + +from dataclasses import dataclass, asdict + + +@dataclass +class ScoredVectorWithNamespace: + namespace: str + score: float + id: str + values: List[float] + sparse_values: dict + metadata: dict + + def __init__(self, aggregate_results_heap_tuple: Tuple[float, int, object, str]): + json_vector = aggregate_results_heap_tuple[2] + self.namespace = aggregate_results_heap_tuple[3] + self.id = json_vector.get("id") # type: ignore + self.score = json_vector.get("score") # type: ignore + self.values = json_vector.get("values") # type: ignore + self.sparse_values = json_vector.get("sparse_values", None) # type: ignore + self.metadata = json_vector.get("metadata", None) # type: ignore + + def __getitem__(self, key): + if hasattr(self, key): + return getattr(self, key) + else: + raise KeyError(f"'{key}' not found in ScoredVectorWithNamespace") + + def __repr__(self): + return json.dumps(self._truncate(asdict(self)), indent=4) + + def __json__(self): + return self._truncate(asdict(self)) + + def _truncate(self, obj, max_items=2): + """ + Recursively traverse and truncate lists that exceed max_items length. + Only display the "... X more" message if at least 2 elements are hidden. + """ + if obj is None: + return None # Skip None values + elif isinstance(obj, list): + filtered_list = [self._truncate(i, max_items) for i in obj if i is not None] + if len(filtered_list) > max_items: + # Show the truncation message only if more than 1 item is hidden + remaining_items = len(filtered_list) - max_items + if remaining_items > 1: + return filtered_list[:max_items] + [f"... {remaining_items} more"] + else: + # If only 1 item remains, show it + return filtered_list + return filtered_list + elif isinstance(obj, dict): + # Recursively process dictionaries, omitting None values + return {k: self._truncate(v, max_items) for k, v in obj.items() if v is not None} + return obj + + +@dataclass +class QueryNamespacesResults: + usage: Usage + matches: List[ScoredVectorWithNamespace] + + def __getitem__(self, key): + if hasattr(self, key): + return getattr(self, key) + else: + raise KeyError(f"'{key}' not found in QueryNamespacesResults") + + def __repr__(self): + return json.dumps( + { + "usage": self.usage.to_dict(), + "matches": [match.__json__() for match in self.matches], + }, + indent=4, + ) + + +class QueryResultsAggregationEmptyResultsError(Exception): + def __init__(self, namespace: str): + super().__init__( + f"Query results for namespace '{namespace}' were empty. Check that you have upserted vectors into this namespace (see describe_index_stats) and that the namespace name is spelled correctly." + ) + + +class QueryResultsAggregregatorNotEnoughResultsError(Exception): + def __init__(self, num_results: int): + super().__init__( + "Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores." + ) + + +class QueryResultsAggregatorInvalidTopKError(Exception): + def __init__(self, top_k: int): + super().__init__( + f"Invalid top_k value {top_k}. To aggregate results from multiple queries the top_k must be at least 2." + ) + + +class QueryResultsAggregator: + def __init__(self, top_k: int): + if top_k < 2: + raise QueryResultsAggregatorInvalidTopKError(top_k) + self.top_k = top_k + self.usage_read_units = 0 + self.heap: List[Tuple[float, int, object, str]] = [] + self.insertion_counter = 0 + self.is_dotproduct = None + self.read = False + self.final_results: Optional[QueryNamespacesResults] = None + + def _is_dotproduct_index(self, matches): + # The interpretation of the score depends on the similar metric used. + # Unlike other index types, in indexes configured for dotproduct, + # a higher score is better. We have to infer this is the case by inspecting + # the order of the scores in the results. + for i in range(1, len(matches)): + if matches[i].get("score") > matches[i - 1].get("score"): # Found an increase + return False + return True + + def _dotproduct_heap_item(self, match, ns): + return (match.get("score"), -self.insertion_counter, match, ns) + + def _non_dotproduct_heap_item(self, match, ns): + return (-match.get("score"), -self.insertion_counter, match, ns) + + def _process_matches(self, matches, ns, heap_item_fn): + for match in matches: + self.insertion_counter += 1 + if len(self.heap) < self.top_k: + heapq.heappush(self.heap, heap_item_fn(match, ns)) + else: + # Assume we have dotproduct scores sorted in descending order + if self.is_dotproduct and match["score"] < self.heap[0][0]: + # No further matches can improve the top-K heap + break + elif not self.is_dotproduct and match["score"] > -self.heap[0][0]: + # No further matches can improve the top-K heap + break + heapq.heappushpop(self.heap, heap_item_fn(match, ns)) + + def add_results(self, results: Dict[str, Any]): + if self.read: + # This is mainly just to sanity check in test cases which get quite confusing + # if you read results twice due to the heap being emptied when constructing + # the ordered results. + raise ValueError("Results have already been read. Cannot add more results.") + + matches = results.get("matches", []) + ns: str = results.get("namespace", "") + self.usage_read_units += results.get("usage", {}).get("readUnits", 0) + + if len(matches) == 0: + return + + if self.is_dotproduct is None: + if len(matches) == 1: + # This condition should match the second time we add results containing + # only one match. We need at least two matches in a single response in order + # to infer the similarity metric + raise QueryResultsAggregregatorNotEnoughResultsError(len(matches)) + self.is_dotproduct = self._is_dotproduct_index(matches) + + if self.is_dotproduct: + self._process_matches(matches, ns, self._dotproduct_heap_item) + else: + self._process_matches(matches, ns, self._non_dotproduct_heap_item) + + def get_results(self) -> QueryNamespacesResults: + if self.read: + if self.final_results is not None: + return self.final_results + else: + # I don't think this branch can ever actually be reached, but the type checker disagrees + raise ValueError("Results have already been read. Cannot get results again.") + self.read = True + + self.final_results = QueryNamespacesResults( + usage=Usage(read_units=self.usage_read_units), + matches=[ + ScoredVectorWithNamespace(heapq.heappop(self.heap)) for _ in range(len(self.heap)) + ][::-1], + ) + return self.final_results diff --git a/pinecone/grpc/sparse_vector.py b/pinecone/grpc/sparse_vector.py new file mode 100644 index 00000000..03faf9be --- /dev/null +++ b/pinecone/grpc/sparse_vector.py @@ -0,0 +1,6 @@ +from typing import TypedDict, List + + +class SparseVectorTypedDict(TypedDict): + indices: List[int] + values: List[float] diff --git a/pinecone/grpc/utils.py b/pinecone/grpc/utils.py index 452573c1..c506c7dd 100644 --- a/pinecone/grpc/utils.py +++ b/pinecone/grpc/utils.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union from google.protobuf import json_format from google.protobuf.message import Message @@ -14,6 +14,8 @@ DescribeIndexStatsResponse, NamespaceSummary, ) +from pinecone.core.grpc.protos.vector_service_pb2 import SparseValues as GRPCSparseValues +from .sparse_vector import SparseVectorTypedDict from google.protobuf.struct_pb2 import Struct @@ -26,6 +28,7 @@ def normalize_endpoint(endpoint: str) -> str: grpc_host = endpoint.replace("https://", "") if ":" not in grpc_host: grpc_host = f"{grpc_host}:443" + return grpc_host def dict_to_proto_struct(d: Optional[dict]) -> "Struct": @@ -115,3 +118,25 @@ def parse_stats_response(response: dict): total_vector_count=total_vector_count, _check_type=False, ) + + +def parse_sparse_values_arg( + sparse_values: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]], +) -> Optional[GRPCSparseValues]: + if sparse_values is None: + return None + + if isinstance(sparse_values, GRPCSparseValues): + return sparse_values + + if ( + not isinstance(sparse_values, dict) + or "indices" not in sparse_values + or "values" not in sparse_values + ): + raise ValueError( + "Invalid sparse values argument. Expected a dict of: {'indices': List[int], 'values': List[float]}." + f"Received: {sparse_values}" + ) + + return GRPCSparseValues(indices=sparse_values["indices"], values=sparse_values["values"]) diff --git a/poetry.lock b/poetry.lock index c96def98..6ae08494 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,27 @@ # This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +[[package]] +name = "anyio" +version = "4.5.2" +description = "High level compatibility layer for multiple asynchronous event loop implementations" +optional = false +python-versions = ">=3.8" +files = [ + {file = "anyio-4.5.2-py3-none-any.whl", hash = "sha256:c011ee36bc1e8ba40e5a81cb9df91925c218fe9b778554e0b56a21e1b5d4716f"}, + {file = "anyio-4.5.2.tar.gz", hash = "sha256:23009af4ed04ce05991845451e11ef02fc7c5ed29179ac9a420e5ad0ac7ddc5b"}, +] + +[package.dependencies] +exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} +idna = ">=2.8" +sniffio = ">=1.1" +typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} + +[package.extras] +doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21.0b1)"] +trio = ["trio (>=0.26.1)"] + [[package]] name = "astunparse" version = "1.6.3" @@ -516,6 +538,63 @@ files = [ [package.extras] protobuf = ["grpcio-tools (>=1.67.1)"] +[[package]] +name = "h11" +version = "0.14.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +optional = false +python-versions = ">=3.7" +files = [ + {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, +] + +[[package]] +name = "httpcore" +version = "1.0.6" +description = "A minimal low-level HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpcore-1.0.6-py3-none-any.whl", hash = "sha256:27b59625743b85577a8c0e10e55b50b5368a4f2cfe8cc7bcfa9cf00829c2682f"}, + {file = "httpcore-1.0.6.tar.gz", hash = "sha256:73f6dbd6eb8c21bbf7ef8efad555481853f5f6acdeaff1edb0694289269ee17f"}, +] + +[package.dependencies] +certifi = "*" +h11 = ">=0.13,<0.15" + +[package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<1.0)"] + +[[package]] +name = "httpx" +version = "0.27.2" +description = "The next generation HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"}, + {file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"}, +] + +[package.dependencies] +anyio = "*" +certifi = "*" +httpcore = "==1.*" +idna = "*" +sniffio = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +zstd = ["zstandard (>=0.18.0)"] + [[package]] name = "identify" version = "2.6.1" @@ -1158,13 +1237,13 @@ windows-terminal = ["colorama (>=0.4.6)"] [[package]] name = "pytest" -version = "8.0.0" +version = "8.3.3" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.8" files = [ - {file = "pytest-8.0.0-py3-none-any.whl", hash = "sha256:50fb9cbe836c3f20f0dfa99c565201fb75dc54c8d76373cd1bde06b06657bdb6"}, - {file = "pytest-8.0.0.tar.gz", hash = "sha256:249b1b0864530ba251b7438274c4d251c58d868edaaec8762893ad4a0d71c36c"}, + {file = "pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2"}, + {file = "pytest-8.3.3.tar.gz", hash = "sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181"}, ] [package.dependencies] @@ -1172,28 +1251,29 @@ colorama = {version = "*", markers = "sys_platform == \"win32\""} exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} iniconfig = "*" packaging = "*" -pluggy = ">=1.3.0,<2.0" -tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} +pluggy = ">=1.5,<2" +tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] -testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] [[package]] name = "pytest-asyncio" -version = "0.15.1" -description = "Pytest support for asyncio." +version = "0.24.0" +description = "Pytest support for asyncio" optional = false -python-versions = ">= 3.6" +python-versions = ">=3.8" files = [ - {file = "pytest-asyncio-0.15.1.tar.gz", hash = "sha256:2564ceb9612bbd560d19ca4b41347b54e7835c2f792c504f698e05395ed63f6f"}, - {file = "pytest_asyncio-0.15.1-py3-none-any.whl", hash = "sha256:3042bcdf1c5d978f6b74d96a151c4cfb9dcece65006198389ccd7e6c60eb1eea"}, + {file = "pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b"}, + {file = "pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276"}, ] [package.dependencies] -pytest = ">=5.4.0" +pytest = ">=8.2,<9" [package.extras] -testing = ["coverage", "hypothesis (>=5.7.1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] name = "pytest-benchmark" @@ -1402,6 +1482,17 @@ files = [ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] +[[package]] +name = "sniffio" +version = "1.3.1" +description = "Sniff out which async library your code is running under" +optional = false +python-versions = ">=3.7" +files = [ + {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, + {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, +] + [[package]] name = "tomli" version = "2.1.0" diff --git a/pyproject.toml b/pyproject.toml index 7036205e..c0875105 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,8 +77,8 @@ numpy = [ ] pandas = ">=1.3.5" pdoc = "^14.1.0" -pytest = "8.0.0" -pytest-asyncio = "0.15.1" +pytest = "8.3.3" +pytest-asyncio = "0.24.0" pytest-cov = "2.10.1" pytest-mock = "3.6.1" pytest-timeout = "2.2.0" @@ -96,6 +96,10 @@ grpc = ["grpcio", "googleapis-common-protos", "lz4", "protobuf", "protoc-gen-ope requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" +[tool.pytest.ini_options] +asyncio_mode = 'auto' +# asyncio_default_fixture_loop_scope = 'session' + [tool.ruff] exclude = [ ".eggs", diff --git a/scripts/create.py b/scripts/create.py index 05a12c36..35538201 100644 --- a/scripts/create.py +++ b/scripts/create.py @@ -59,14 +59,22 @@ def generate_index_name(test_name: str) -> str: def main(): pc = Pinecone(api_key=read_env_var("PINECONE_API_KEY")) + index_name = generate_index_name(read_env_var("NAME_PREFIX") + random_string(20)) + dimension = int(read_env_var("DIMENSION")) + metric = read_env_var("METRIC") + pc.create_index( name=index_name, - metric=read_env_var("METRIC"), - dimension=int(read_env_var("DIMENSION")), + metric=metric, + dimension=dimension, spec={"serverless": {"cloud": read_env_var("CLOUD"), "region": read_env_var("REGION")}}, ) + desc = pc.describe_index(index_name) write_gh_output("index_name", index_name) + write_gh_output("index_host", desc.host) + write_gh_output("index_metric", metric) + write_gh_output("index_dimension", dimension) if __name__ == "__main__": diff --git a/tests/integration/data_asyncio/__init__.py b/tests/integration/data_asyncio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/data_asyncio/conftest.py b/tests/integration/data_asyncio/conftest.py new file mode 100644 index 00000000..dd318f2d --- /dev/null +++ b/tests/integration/data_asyncio/conftest.py @@ -0,0 +1,55 @@ +import pytest +import os +from ..helpers import get_environment_var, random_string + + +@pytest.fixture(scope="session") +def api_key(): + return get_environment_var("PINECONE_API_KEY") + + +@pytest.fixture(scope="session") +def host(): + return get_environment_var("INDEX_HOST") + + +@pytest.fixture(scope="session") +def dimension(): + return int(get_environment_var("DIMENSION")) + + +def use_grpc(): + return os.environ.get("USE_GRPC", "false") == "true" + + +def build_client(api_key): + if use_grpc(): + from pinecone.grpc import PineconeGRPC + + return PineconeGRPC(api_key=api_key) + else: + from pinecone import Pinecone + + return Pinecone( + api_key=api_key, additional_headers={"sdk-test-suite": "pinecone-python-client"} + ) + + +@pytest.fixture(scope="session") +async def pc(api_key): + return build_client(api_key=api_key) + + +@pytest.fixture(scope="session") +async def asyncio_idx(pc, host): + return pc.AsyncioIndex(host=host) + + +@pytest.fixture(scope="session") +async def namespace(): + return random_string(10) + + +@pytest.fixture(scope="session") +async def list_namespace(): + return random_string(10) diff --git a/tests/integration/data_asyncio/test_upsert.py b/tests/integration/data_asyncio/test_upsert.py new file mode 100644 index 00000000..a85ab2d4 --- /dev/null +++ b/tests/integration/data_asyncio/test_upsert.py @@ -0,0 +1,97 @@ +import pytest +from pinecone import Vector +from .conftest import use_grpc +from ..helpers import random_string +from .utils import build_asyncio_idx, embedding_values, poll_for_freshness + + +@pytest.mark.parametrize("target_namespace", ["", random_string(20)]) +@pytest.mark.skipif(use_grpc() == False, reason="Currently only GRPC supports asyncio") +async def test_upsert_to_default_namespace(host, dimension, target_namespace): + asyncio_idx = build_asyncio_idx(host) + + def emb(): + return embedding_values(dimension) + + # Upsert with tuples + await asyncio_idx.upsert( + vectors=[("1", emb()), ("2", emb()), ("3", emb())], namespace=target_namespace + ) + + # Upsert with objects + await asyncio_idx.upsert( + vectors=[ + Vector(id="4", values=emb()), + Vector(id="5", values=emb()), + Vector(id="6", values=emb()), + ], + namespace=target_namespace, + ) + + # Upsert with dict + await asyncio_idx.upsert( + vectors=[ + {"id": "7", "values": emb()}, + {"id": "8", "values": emb()}, + {"id": "9", "values": emb()}, + ], + namespace=target_namespace, + ) + + await poll_for_freshness(asyncio_idx, target_namespace, 9) + + # # Check the vector count reflects some data has been upserted + stats = await asyncio_idx.describe_index_stats() + assert stats.total_vector_count >= 9 + # default namespace could have other stuff from other tests + if target_namespace != "": + assert stats.namespaces[target_namespace].vector_count == 9 + + +# @pytest.mark.parametrize("target_namespace", [ +# "", +# random_string(20), +# ]) +# @pytest.mark.skipif( +# os.getenv("METRIC") != "dotproduct", reason="Only metric=dotprodouct indexes support hybrid" +# ) +# async def test_upsert_to_namespace_with_sparse_embedding_values(pc, host, dimension, target_namespace): +# asyncio_idx = pc.AsyncioIndex(host=host) + +# # Upsert with sparse values object +# await asyncio_idx.upsert( +# vectors=[ +# Vector( +# id="1", +# values=embedding_values(dimension), +# sparse_values=SparseValues(indices=[0, 1], values=embedding_values()), +# ) +# ], +# namespace=target_namespace, +# ) + +# # Upsert with sparse values dict +# await asyncio_idx.upsert( +# vectors=[ +# { +# "id": "2", +# "values": embedding_values(dimension), +# "sparse_values": {"indices": [0, 1], "values": embedding_values()}, +# }, +# { +# "id": "3", +# "values": embedding_values(dimension), +# "sparse_values": {"indices": [0, 1], "values": embedding_values()}, +# }, +# ], +# namespace=target_namespace, +# ) + +# await poll_for_freshness(asyncio_idx, target_namespace, 9) + +# # Check the vector count reflects some data has been upserted +# stats = await asyncio_idx.describe_index_stats() +# assert stats.total_vector_count >= 9 + +# if (target_namespace != ""): +# assert stats.namespaces[target_namespace].vector_count == 9 diff --git a/tests/integration/data_asyncio/test_upsert_errors.py b/tests/integration/data_asyncio/test_upsert_errors.py new file mode 100644 index 00000000..dcf70a26 --- /dev/null +++ b/tests/integration/data_asyncio/test_upsert_errors.py @@ -0,0 +1,234 @@ +import os +import pytest +from pinecone.grpc import Vector, SparseValues +from ..helpers import fake_api_key +from .utils import build_asyncio_idx, embedding_values +from pinecone import PineconeException, PineconeApiValueError +from pinecone.grpc import PineconeGRPC as Pinecone + + +class TestUpsertApiKeyMissing: + async def test_upsert_fails_when_api_key_invalid(self, host): + with pytest.raises(PineconeException): + pc = Pinecone( + api_key=fake_api_key(), + additional_headers={"sdk-test-suite": "pinecone-python-client"}, + ) + asyncio_idx = pc.AsyncioIndex(host=host) + await asyncio_idx.upsert( + vectors=[ + Vector(id="1", values=embedding_values()), + Vector(id="2", values=embedding_values()), + ] + ) + + @pytest.mark.skipif( + os.getenv("USE_GRPC") != "true", reason="Only test grpc client when grpc extras" + ) + async def test_upsert_fails_when_api_key_invalid_grpc(self, host): + with pytest.raises(PineconeException): + from pinecone.grpc import PineconeGRPC + + pc = PineconeGRPC(api_key=fake_api_key()) + asyncio_idx = pc.AsyncioIndex(host=host) + await asyncio_idx.upsert( + vectors=[ + Vector(id="1", values=embedding_values()), + Vector(id="2", values=embedding_values()), + ] + ) + + +class TestUpsertFailsWhenDimensionMismatch: + async def test_upsert_fails_when_dimension_mismatch_objects(self, host): + with pytest.raises(PineconeApiValueError): + asyncio_idx = build_asyncio_idx(host) + await asyncio_idx.upsert( + vectors=[ + Vector(id="1", values=embedding_values(2)), + Vector(id="2", values=embedding_values(3)), + ] + ) + + async def test_upsert_fails_when_dimension_mismatch_tuples(self, host): + asyncio_idx = build_asyncio_idx(host) + with pytest.raises(PineconeException): + await asyncio_idx.upsert( + vectors=[("1", embedding_values(2)), ("2", embedding_values(3))] + ) + + async def test_upsert_fails_when_dimension_mismatch_dicts(self, host): + asyncio_idx = build_asyncio_idx(host) + with pytest.raises(PineconeException): + await asyncio_idx.upsert( + vectors=[ + {"id": "1", "values": embedding_values(2)}, + {"id": "2", "values": embedding_values(3)}, + ] + ) + + +@pytest.mark.skipif( + os.getenv("METRIC") != "dotproduct", reason="Only metric=dotprodouct indexes support hybrid" +) +class TestUpsertFailsSparseValuesDimensionMismatch: + async def test_upsert_fails_when_sparse_values_indices_values_mismatch_objects(self, host): + asyncio_idx = build_asyncio_idx(host) + with pytest.raises(PineconeException): + await asyncio_idx.upsert( + vectors=[ + Vector( + id="1", + values=[0.1, 0.1], + sparse_values=SparseValues(indices=[0], values=[0.5, 0.5]), + ) + ] + ) + with pytest.raises(PineconeException): + await asyncio_idx.upsert( + vectors=[ + Vector( + id="1", + values=[0.1, 0.1], + sparse_values=SparseValues(indices=[0, 1], values=[0.5]), + ) + ] + ) + + async def test_upsert_fails_when_sparse_values_in_tuples(self, host): + asyncio_idx = build_asyncio_idx(host) + with pytest.raises(ValueError): + await asyncio_idx.upsert( + vectors=[ + ("1", SparseValues(indices=[0], values=[0.5])), + ("2", SparseValues(indices=[0, 1, 2], values=[0.5, 0.5, 0.5])), + ] + ) + + async def test_upsert_fails_when_sparse_values_indices_values_mismatch_dicts(self, host): + asyncio_idx = build_asyncio_idx(host) + with pytest.raises(PineconeException): + await asyncio_idx.upsert( + vectors=[ + { + "id": "1", + "values": [0.2, 0.2], + "sparse_values": SparseValues(indices=[0], values=[0.5, 0.5]), + } + ] + ) + with pytest.raises(PineconeException): + await asyncio_idx.upsert( + vectors=[ + { + "id": "1", + "values": [0.1, 0.2], + "sparse_values": SparseValues(indices=[0, 1], values=[0.5]), + } + ] + ) + + +class TestUpsertFailsWhenValuesMissing: + async def test_upsert_fails_when_values_missing_objects(self, host): + asyncio_idx = build_asyncio_idx(host) + with pytest.raises(PineconeApiValueError): + await asyncio_idx.upsert(vectors=[Vector(id="1"), Vector(id="2")]) + + async def test_upsert_fails_when_values_missing_tuples(self, host): + asyncio_idx = build_asyncio_idx(host) + with pytest.raises(ValueError): + await asyncio_idx.upsert(vectors=[("1",), ("2",)]) + + async def test_upsert_fails_when_values_missing_dicts(self, host): + asyncio_idx = build_asyncio_idx(host) + with pytest.raises(ValueError): + await asyncio_idx.upsert(vectors=[{"id": "1"}, {"id": "2"}]) + + +class TestUpsertFailsWhenValuesWrongType: + async def test_upsert_fails_when_values_wrong_type_objects(self, host): + asyncio_idx = build_asyncio_idx(host) + with pytest.raises(TypeError): + await asyncio_idx.upsert( + vectors=[Vector(id="1", values="abc"), Vector(id="2", values="def")] + ) + + async def test_upsert_fails_when_values_wrong_type_tuples(self, host): + asyncio_idx = build_asyncio_idx(host) + if os.environ.get("USE_GRPC", "false") == "true": + expected_exception = TypeError + else: + expected_exception = PineconeException + + with pytest.raises(expected_exception): + await asyncio_idx.upsert(vectors=[("1", "abc"), ("2", "def")]) + + async def test_upsert_fails_when_values_wrong_type_dicts(self, host): + asyncio_idx = build_asyncio_idx(host) + with pytest.raises(TypeError): + await asyncio_idx.upsert( + vectors=[{"id": "1", "values": "abc"}, {"id": "2", "values": "def"}] + ) + + +class TestUpsertFailsWhenVectorsMissing: + async def test_upsert_fails_when_vectors_empty(self, host): + asyncio_idx = build_asyncio_idx(host) + with pytest.raises(PineconeException): + await asyncio_idx.upsert(vectors=[]) + + async def test_upsert_fails_when_vectors_wrong_type(self, host): + asyncio_idx = build_asyncio_idx(host) + with pytest.raises(ValueError): + await asyncio_idx.upsert(vectors="abc") + + async def test_upsert_fails_when_vectors_missing(self, host): + asyncio_idx = build_asyncio_idx(host) + with pytest.raises(TypeError): + await asyncio_idx.upsert() + + +# class TestUpsertIdMissing: +# async def test_upsert_fails_when_id_is_missing_objects(self, host): +# with pytest.raises(TypeError): +# idx.upsert( +# vectors=[ +# Vector(id="1", values=embedding_values()), +# Vector(values=embedding_values()), +# ] +# ) + +# async def test_upsert_fails_when_id_is_missing_tuples(self, host): +# with pytest.raises(ValueError): +# idx.upsert(vectors=[("1", embedding_values()), (embedding_values())]) + +# async def test_upsert_fails_when_id_is_missing_dicts(self, host): +# with pytest.raises(ValueError): +# idx.upsert( +# vectors=[{"id": "1", "values": embedding_values()}, {"values": embedding_values()}] +# ) + + +# class TestUpsertIdWrongType: +# async def test_upsert_fails_when_id_wrong_type_objects(self, host): +# with pytest.raises(Exception): +# idx.upsert( +# vectors=[ +# Vector(id="1", values=embedding_values()), +# Vector(id=2, values=embedding_values()), +# ] +# ) + +# async def test_upsert_fails_when_id_wrong_type_tuples(self, host): +# with pytest.raises(Exception): +# idx.upsert(vectors=[("1", embedding_values()), (2, embedding_values())]) + +# async def test_upsert_fails_when_id_wrong_type_dicts(self, host): +# with pytest.raises(Exception): +# idx.upsert( +# vectors=[ +# {"id": "1", "values": embedding_values()}, +# {"id": 2, "values": embedding_values()}, +# ] +# ) diff --git a/tests/integration/data_asyncio/utils.py b/tests/integration/data_asyncio/utils.py new file mode 100644 index 00000000..8a7b5090 --- /dev/null +++ b/tests/integration/data_asyncio/utils.py @@ -0,0 +1,31 @@ +import random +import asyncio +from pinecone.grpc import PineconeGRPC as Pinecone + + +def build_asyncio_idx(host): + return Pinecone().AsyncioIndex(host=host) + + +def embedding_values(dimension=2): + return [random.random() for _ in range(dimension)] + + +async def poll_for_freshness(asyncio_idx, namespace, expected_count): + total_wait = 0 + delta = 2 + while True: + stats = await asyncio_idx.describe_index_stats() + if stats.namespaces.get(namespace, None) is not None: + if stats.namespaces[namespace].vector_count >= expected_count: + print( + f"Found {stats.namespaces[namespace].vector_count} vectors in namespace '{namespace}' after {total_wait} seconds" + ) + break + await asyncio.sleep(delta) + total_wait += delta + + if total_wait > 60: + raise TimeoutError( + f"Timed out waiting for vectors to appear in namespace '{namespace}'" + ) diff --git a/tests/unit_grpc/test_query_results_aggregator.py b/tests/unit_grpc/test_query_results_aggregator.py new file mode 100644 index 00000000..b4c78802 --- /dev/null +++ b/tests/unit_grpc/test_query_results_aggregator.py @@ -0,0 +1,553 @@ +from pinecone.grpc.query_results_aggregator import ( + QueryResultsAggregator, + QueryResultsAggregatorInvalidTopKError, + QueryResultsAggregregatorNotEnoughResultsError, +) +import random +import pytest + + +class TestQueryResultsAggregator: + def test_keeps_running_usage_total(self): + aggregator = QueryResultsAggregator(top_k=3) + + results1 = { + "matches": [ + {"id": "1", "score": 0.1, "values": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]}, + {"id": "2", "score": 0.11, "values": []}, + {"id": "3", "score": 0.12, "values": []}, + {"id": "4", "score": 0.13, "values": []}, + {"id": "5", "score": 0.14, "values": []}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + + results2 = { + "matches": [ + {"id": "7", "score": 0.101, "values": []}, + {"id": "8", "score": 0.111, "values": []}, + {"id": "9", "score": 0.12, "values": []}, + {"id": "10", "score": 0.13, "values": []}, + {"id": "11", "score": 0.14, "values": []}, + ], + "usage": {"readUnits": 7}, + "namespace": "ns2", + } + aggregator.add_results(results2) + + results = aggregator.get_results() + assert results.usage.read_units == 12 + assert len(results.matches) == 3 + assert results.matches[0].id == "1" # 0.1 + assert results.matches[1].id == "7" # 0.101 + assert results.matches[2].id == "2" # 0.11 + + def test_inserting_duplicate_scores_stable_ordering(self): + aggregator = QueryResultsAggregator(top_k=5) + + results1 = { + "matches": [ + {"id": "1", "score": 0.11, "values": []}, + {"id": "3", "score": 0.11, "values": []}, + {"id": "2", "score": 0.11, "values": []}, + {"id": "4", "score": 0.22, "values": []}, + {"id": "5", "score": 0.22, "values": []}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + + results2 = { + "matches": [ + {"id": "6", "score": 0.11, "values": []}, + {"id": "7", "score": 0.22, "values": []}, + {"id": "8", "score": 0.22, "values": []}, + {"id": "9", "score": 0.22, "values": []}, + {"id": "10", "score": 0.22, "values": []}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns2", + } + aggregator.add_results(results2) + + results = aggregator.get_results() + assert results.usage.read_units == 10 + assert len(results.matches) == 5 + assert results.matches[0].id == "1" # 0.11 + assert results.matches[0].namespace == "ns1" + assert results.matches[1].id == "3" # 0.11 + assert results.matches[1].namespace == "ns1" + assert results.matches[2].id == "2" # 0.11 + assert results.matches[2].namespace == "ns1" + assert results.matches[3].id == "6" # 0.11 + assert results.matches[3].namespace == "ns2" + assert results.matches[4].id == "4" # 0.22 + assert results.matches[4].namespace == "ns1" + + def test_correctly_handles_dotproduct_metric(self): + # For this index metric, the higher the score, the more similar the vectors are. + # We have to infer that we have this type of index by seeing whether scores are + # sorted in descending or ascending order. + aggregator = QueryResultsAggregator(top_k=3) + + desc_results1 = { + "matches": [ + {"id": "1", "score": 0.9, "values": []}, + {"id": "2", "score": 0.8, "values": []}, + {"id": "3", "score": 0.7, "values": []}, + {"id": "4", "score": 0.6, "values": []}, + {"id": "5", "score": 0.5, "values": []}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(desc_results1) + + results2 = { + "matches": [ + {"id": "7", "score": 0.77, "values": []}, + {"id": "8", "score": 0.88, "values": []}, + {"id": "9", "score": 0.99, "values": []}, + {"id": "10", "score": 0.1010, "values": []}, + {"id": "11", "score": 0.1111, "values": []}, + ], + "usage": {"readUnits": 7}, + "namespace": "ns2", + } + aggregator.add_results(results2) + + results = aggregator.get_results() + assert results.usage.read_units == 12 + assert len(results.matches) == 3 + assert results.matches[0].id == "9" # 0.99 + assert results.matches[1].id == "1" # 0.9 + assert results.matches[2].id == "8" # 0.88 + + def test_still_correct_with_early_return(self): + aggregator = QueryResultsAggregator(top_k=5) + + results1 = { + "matches": [ + {"id": "1", "score": 0.1, "values": []}, + {"id": "2", "score": 0.11, "values": []}, + {"id": "3", "score": 0.12, "values": []}, + {"id": "4", "score": 0.13, "values": []}, + {"id": "5", "score": 0.14, "values": []}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + + results2 = { + "matches": [ + {"id": "6", "score": 0.10, "values": []}, + {"id": "7", "score": 0.101, "values": []}, + {"id": "8", "score": 0.12, "values": []}, + {"id": "9", "score": 0.13, "values": []}, + {"id": "10", "score": 0.14, "values": []}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns2", + } + aggregator.add_results(results2) + + results = aggregator.get_results() + assert results.usage.read_units == 10 + assert len(results.matches) == 5 + assert results.matches[0].id == "1" + assert results.matches[1].id == "6" + assert results.matches[2].id == "7" + assert results.matches[3].id == "2" + assert results.matches[4].id == "3" + + def test_still_correct_with_early_return_generated_nont_dotproduct(self): + aggregator = QueryResultsAggregator(top_k=1000) + matches1 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(1, 1000) + ] + matches1.sort(key=lambda x: x["score"], reverse=False) + + matches2 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(1001, 2000) + ] + matches2.sort(key=lambda x: x["score"], reverse=False) + + matches3 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(2001, 3000) + ] + matches3.sort(key=lambda x: x["score"], reverse=False) + + matches4 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(3001, 4000) + ] + matches4.sort(key=lambda x: x["score"], reverse=False) + + matches5 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(4001, 5000) + ] + matches5.sort(key=lambda x: x["score"], reverse=False) + + results1 = {"matches": matches1, "namespace": "ns1", "usage": {"readUnits": 5}} + results2 = {"matches": matches2, "namespace": "ns2", "usage": {"readUnits": 5}} + results3 = {"matches": matches3, "namespace": "ns3", "usage": {"readUnits": 5}} + results4 = {"matches": matches4, "namespace": "ns4", "usage": {"readUnits": 5}} + results5 = {"matches": matches5, "namespace": "ns5", "usage": {"readUnits": 5}} + + aggregator.add_results(results1) + aggregator.add_results(results2) + aggregator.add_results(results3) + aggregator.add_results(results4) + aggregator.add_results(results5) + + merged_matches = matches1 + matches2 + matches3 + matches4 + matches5 + merged_matches.sort(key=lambda x: x["score"], reverse=False) + + results = aggregator.get_results() + assert results.usage.read_units == 25 + assert len(results.matches) == 1000 + assert results.matches[0].score == merged_matches[0]["score"] + assert results.matches[1].score == merged_matches[1]["score"] + assert results.matches[2].score == merged_matches[2]["score"] + assert results.matches[3].score == merged_matches[3]["score"] + assert results.matches[4].score == merged_matches[4]["score"] + + def test_still_correct_with_early_return_generated_dotproduct(self): + aggregator = QueryResultsAggregator(top_k=1000) + matches1 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(1, 1000) + ] + matches1.sort(key=lambda x: x["score"], reverse=True) + + matches2 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(1001, 2000) + ] + matches2.sort(key=lambda x: x["score"], reverse=True) + + matches3 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(2001, 3000) + ] + matches3.sort(key=lambda x: x["score"], reverse=True) + + matches4 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(3001, 4000) + ] + matches4.sort(key=lambda x: x["score"], reverse=True) + + matches5 = [ + {"id": f"id{i}", "score": random.random(), "values": []} for i in range(4001, 5000) + ] + matches5.sort(key=lambda x: x["score"], reverse=True) + + results1 = {"matches": matches1, "namespace": "ns1", "usage": {"readUnits": 5}} + results2 = {"matches": matches2, "namespace": "ns2", "usage": {"readUnits": 5}} + results3 = {"matches": matches3, "namespace": "ns3", "usage": {"readUnits": 5}} + results4 = {"matches": matches4, "namespace": "ns4", "usage": {"readUnits": 5}} + results5 = {"matches": matches5, "namespace": "ns5", "usage": {"readUnits": 5}} + + aggregator.add_results(results1) + aggregator.add_results(results2) + aggregator.add_results(results3) + aggregator.add_results(results4) + aggregator.add_results(results5) + + merged_matches = matches1 + matches2 + matches3 + matches4 + matches5 + merged_matches.sort(key=lambda x: x["score"], reverse=True) + + results = aggregator.get_results() + assert results.usage.read_units == 25 + assert len(results.matches) == 1000 + assert results.matches[0].score == merged_matches[0]["score"] + assert results.matches[1].score == merged_matches[1]["score"] + assert results.matches[2].score == merged_matches[2]["score"] + assert results.matches[3].score == merged_matches[3]["score"] + assert results.matches[4].score == merged_matches[4]["score"] + + +class TestQueryResultsAggregatorOutputUX: + def test_can_interact_with_attributes(self): + aggregator = QueryResultsAggregator(top_k=2) + results1 = { + "matches": [ + { + "id": "1", + "score": 0.3, + "values": [0.31, 0.32, 0.33, 0.34, 0.35, 0.36], + "sparse_values": {"indices": [1, 2], "values": [0.2, 0.4]}, + "metadata": { + "hello": "world", + "number": 42, + "list": [1, 2, 3], + "list2": ["foo", "bar"], + }, + }, + {"id": "2", "score": 0.4}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + results = aggregator.get_results() + assert results.usage.read_units == 5 + assert results.matches[0].id == "1" + assert results.matches[0].namespace == "ns1" + assert results.matches[0].score == 0.3 + assert results.matches[0].values == [0.31, 0.32, 0.33, 0.34, 0.35, 0.36] + + def test_can_interact_like_dict(self): + aggregator = QueryResultsAggregator(top_k=3) + results1 = { + "matches": [ + { + "id": "1", + "score": 0.3, + "values": [0.31, 0.32, 0.33, 0.34, 0.35, 0.36], + "sparse_values": {"indices": [1, 2], "values": [0.2, 0.4]}, + "metadata": { + "hello": "world", + "number": 42, + "list": [1, 2, 3], + "list2": ["foo", "bar"], + }, + }, + { + "id": "2", + "score": 0.4, + "values": [0.31, 0.32, 0.33, 0.34, 0.35, 0.36], + "sparse_values": {"indices": [1, 2], "values": [0.2, 0.4]}, + "metadata": { + "hello": "world", + "number": 42, + "list": [1, 2, 3], + "list2": ["foo", "bar"], + }, + }, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + results = aggregator.get_results() + assert results["usage"]["read_units"] == 5 + assert results["matches"][0]["id"] == "1" + assert results["matches"][0]["namespace"] == "ns1" + assert results["matches"][0]["score"] == 0.3 + + def test_can_print_empty_results_without_error(self, capsys): + aggregator = QueryResultsAggregator(top_k=3) + results = aggregator.get_results() + print(results) + capsys.readouterr() + + def test_can_print_results_containing_None_without_error(self, capsys): + aggregator = QueryResultsAggregator(top_k=3) + results1 = { + "matches": [ + {"id": "1", "score": 0.1}, + {"id": "2", "score": 0.11}, + {"id": "3", "score": 0.12}, + {"id": "4", "score": 0.13}, + {"id": "5", "score": 0.14}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + results = aggregator.get_results() + print(results) + capsys.readouterr() + + def test_can_print_results_containing_short_vectors(self, capsys): + aggregator = QueryResultsAggregator(top_k=4) + results1 = { + "matches": [ + {"id": "1", "score": 0.1, "values": [0.31]}, + {"id": "2", "score": 0.11, "values": [0.31, 0.32]}, + {"id": "3", "score": 0.12, "values": [0.31, 0.32, 0.33]}, + {"id": "3", "score": 0.12, "values": [0.31, 0.32, 0.33, 0.34]}, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + results = aggregator.get_results() + print(results) + capsys.readouterr() + + def test_can_print_complete_results_without_error(self, capsys): + aggregator = QueryResultsAggregator(top_k=2) + results1 = { + "matches": [ + { + "id": "1", + "score": 0.3, + "values": [0.31, 0.32, 0.33, 0.34, 0.35, 0.36], + "sparse_values": {"indices": [1, 2], "values": [0.2, 0.4]}, + "metadata": { + "hello": "world", + "number": 42, + "list": [1, 2, 3], + "list2": ["foo", "bar"], + }, + }, + { + "id": "2", + "score": 0.4, + "values": [0.31, 0.32, 0.33, 0.34, 0.35, 0.36], + "sparse_values": {"indices": [1, 2], "values": [0.2, 0.4]}, + "metadata": {"boolean": True, "nullish": None}, + }, + ], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + results = aggregator.get_results() + print(results) + capsys.readouterr() + + +class TestQueryAggregatorEdgeCases: + def test_topK_too_small(self): + with pytest.raises(QueryResultsAggregatorInvalidTopKError): + QueryResultsAggregator(top_k=0) + with pytest.raises(QueryResultsAggregatorInvalidTopKError): + QueryResultsAggregator(top_k=1) + + def test_matches_too_small(self): + aggregator = QueryResultsAggregator(top_k=3) + results1 = { + "matches": [{"id": "1", "score": 0.1}], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + with pytest.raises(QueryResultsAggregregatorNotEnoughResultsError): + aggregator.add_results(results1) + + def test_empty_results(self): + aggregator = QueryResultsAggregator(top_k=3) + results = aggregator.get_results() + assert results is not None + assert results.usage.read_units == 0 + assert len(results.matches) == 0 + + def test_empty_results_with_usage(self): + aggregator = QueryResultsAggregator(top_k=3) + + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns1"}) + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"}) + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns3"}) + + results = aggregator.get_results() + assert results is not None + assert results.usage.read_units == 15 + assert len(results.matches) == 0 + + def test_exactly_one_result(self): + aggregator = QueryResultsAggregator(top_k=3) + results1 = { + "matches": [{"id": "2", "score": 0.01}, {"id": "3", "score": 0.2}], + "usage": {"readUnits": 5}, + "namespace": "ns2", + } + aggregator.add_results(results1) + + results2 = { + "matches": [{"id": "1", "score": 0.1}], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results2) + results = aggregator.get_results() + assert results.usage.read_units == 10 + assert len(results.matches) == 3 + assert results.matches[0].id == "2" + assert results.matches[0].namespace == "ns2" + assert results.matches[0].score == 0.01 + assert results.matches[1].id == "1" + assert results.matches[1].namespace == "ns1" + assert results.matches[1].score == 0.1 + assert results.matches[2].id == "3" + assert results.matches[2].namespace == "ns2" + assert results.matches[2].score == 0.2 + + def test_two_result_sets_with_single_result_errors(self): + with pytest.raises(QueryResultsAggregregatorNotEnoughResultsError): + aggregator = QueryResultsAggregator(top_k=3) + results1 = { + "matches": [{"id": "1", "score": 0.1}], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + results2 = { + "matches": [{"id": "2", "score": 0.01}], + "usage": {"readUnits": 5}, + "namespace": "ns2", + } + aggregator.add_results(results2) + + def test_single_result_after_index_type_known_no_error(self): + aggregator = QueryResultsAggregator(top_k=3) + + results3 = { + "matches": [{"id": "2", "score": 0.01}, {"id": "3", "score": 0.2}], + "usage": {"readUnits": 5}, + "namespace": "ns3", + } + aggregator.add_results(results3) + + results1 = { + "matches": [{"id": "1", "score": 0.1}], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + results2 = {"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"} + aggregator.add_results(results2) + + results = aggregator.get_results() + assert results.usage.read_units == 15 + assert len(results.matches) == 3 + assert results.matches[0].id == "2" + assert results.matches[0].namespace == "ns3" + assert results.matches[0].score == 0.01 + assert results.matches[1].id == "1" + assert results.matches[1].namespace == "ns1" + assert results.matches[1].score == 0.1 + assert results.matches[2].id == "3" + assert results.matches[2].namespace == "ns3" + assert results.matches[2].score == 0.2 + + def test_all_empty_results(self): + aggregator = QueryResultsAggregator(top_k=10) + + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns1"}) + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"}) + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns3"}) + + results = aggregator.get_results() + + assert results.usage.read_units == 15 + assert len(results.matches) == 0 + + def test_some_empty_results(self): + aggregator = QueryResultsAggregator(top_k=10) + results2 = { + "matches": [{"id": "2", "score": 0.01}, {"id": "3", "score": 0.2}], + "usage": {"readUnits": 5}, + "namespace": "ns0", + } + aggregator.add_results(results2) + + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns1"}) + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"}) + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns3"}) + + results = aggregator.get_results() + + assert results.usage.read_units == 20 + assert len(results.matches) == 2