Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP on threadpool impl of query_namespaces #405

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 44 additions & 19 deletions pinecone/core/openapi/shared/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,24 @@
import typing
from urllib.parse import quote
from urllib3.fields import RequestField
import time
import random

def retry_api_call(
    func, args=(), kwargs=None, retries=3, backoff=1, jitter=0.5
):
    """Call ``func`` and retry on any exception with exponential backoff.

    Intended to wrap ``__call_api`` when it is dispatched through a thread
    pool, so transient failures are retried before the pool result fails.

    :param func: the callable to invoke.
    :param args: positional arguments forwarded to ``func``.
    :param kwargs: keyword arguments forwarded to ``func``. ``None`` (the
        default) means no keyword arguments. A ``None`` sentinel is used
        instead of ``{}`` because a mutable default would be shared across
        all calls to this function.
    :param retries: total number of attempts before giving up.
    :param backoff: base delay in seconds; doubles after each failed attempt.
    :param jitter: upper bound of the random delay added to each sleep, to
        avoid thundering-herd retries.
    :return: whatever ``func`` returns on the first successful attempt.
    :raises Exception: re-raises the last exception once retries are exhausted.
    """
    if kwargs is None:
        kwargs = {}
    attempts = 0
    while attempts < retries:
        try:
            return func(*args, **kwargs)  # Attempt to call the wrapped API function
        except Exception as e:
            attempts += 1
            if attempts >= retries:
                print(f"API call failed after {attempts} attempts: {e}")
                raise  # Re-raise exception if retries are exhausted
            # Exponential backoff (backoff * 2^(attempt-1)) plus random jitter.
            sleep_time = backoff * (2 ** (attempts - 1)) + random.uniform(0, jitter)
            time.sleep(sleep_time)


from pinecone.core.openapi.shared import rest
Expand Down Expand Up @@ -397,25 +415,32 @@ def call_api(
)

return self.pool.apply_async(
self.__call_api,
(
resource_path,
method,
path_params,
query_params,
header_params,
body,
post_params,
files,
response_type,
auth_settings,
_return_http_data_only,
collection_formats,
_preload_content,
_request_timeout,
_host,
_check_type,
),
retry_api_call,
args=(
self.__call_api, # Pass the API call function as the first argument
(
resource_path,
method,
path_params,
query_params,
header_params,
body,
post_params,
files,
response_type,
auth_settings,
_return_http_data_only,
collection_formats,
_preload_content,
_request_timeout,
_host,
_check_type,
),
{}, # empty kwargs dictionary
3, # retries
1, # backoff time
0.5 # jitter
)
)

def request(
Expand Down
90 changes: 88 additions & 2 deletions pinecone/data/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from pinecone.core.openapi.data.api.data_plane_api import DataPlaneApi
from ..utils import setup_openapi_client, parse_non_empty_args
from .vector_factory import VectorFactory
from .query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults
from multiprocessing.pool import ApplyResult

__all__ = [
"Index",
Expand Down Expand Up @@ -361,7 +363,7 @@ def query(
Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
] = None,
**kwargs,
) -> QueryResponse:
) -> Union[QueryResponse, ApplyResult]:
"""
The Query operation searches a namespace, using a query vector.
It retrieves the ids of the most similar items in a namespace, along with their similarity scores.
Expand Down Expand Up @@ -403,6 +405,39 @@ def query(
and namespace name.
"""

response = self._query(
*args,
top_k=top_k,
vector=vector,
id=id,
namespace=namespace,
filter=filter,
include_values=include_values,
include_metadata=include_metadata,
sparse_vector=sparse_vector,
**kwargs,
)

if kwargs.get("async_req", False):
return response
else:
return parse_query_response(response)

def _query(
self,
*args,
top_k: int,
vector: Optional[List[float]] = None,
id: Optional[str] = None,
namespace: Optional[str] = None,
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
include_values: Optional[bool] = None,
include_metadata: Optional[bool] = None,
sparse_vector: Optional[
Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
] = None,
**kwargs,
) -> QueryResponse:
if len(args) > 0:
raise ValueError(
"The argument order for `query()` has changed; please use keyword arguments instead of positional arguments. Example: index.query(vector=[0.1, 0.2, 0.3], top_k=10, namespace='my_namespace')"
Expand Down Expand Up @@ -435,7 +470,58 @@ def query(
),
**{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS},
)
return parse_query_response(response)
return response

@validate_and_convert_errors
def query_namespaces(
self,
vector: List[float],
namespaces: List[str],
top_k: Optional[int] = None,
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
include_values: Optional[bool] = None,
include_metadata: Optional[bool] = None,
sparse_vector: Optional[
Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
] = None,
**kwargs,
) -> QueryNamespacesResults:
if len(namespaces) == 0:
raise ValueError("At least one namespace must be specified")
if len(vector) == 0:
raise ValueError("Query vector must not be empty")

# The caller may only want the top_k=1 result across all queries,
# but we need to get at least 2 results from each query in order to
# aggregate them correctly. So we'll temporarily set topK to 2 for the
# subqueries, and then we'll take the topK=1 results from the aggregated
# results.
overall_topk = top_k if top_k is not None else 10

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this magic number 10?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just a default value. Could be anything, but the API requires a value to be passed.

aggregator = QueryResultsAggregator(top_k=overall_topk)
subquery_topk = overall_topk if overall_topk > 2 else 2

target_namespaces = set(namespaces) # dedup namespaces
async_results = [
self.query(
vector=vector,
namespace=ns,
top_k=subquery_topk,
filter=filter,
include_values=include_values,
include_metadata=include_metadata,
sparse_vector=sparse_vector,
async_req=True,
**kwargs,
)
for ns in target_namespaces
]

for result in async_results:
response = result.get()
aggregator.add_results(response)

final_results = aggregator.get_results()
return final_results

@validate_and_convert_errors
def update(
Expand Down
2 changes: 1 addition & 1 deletion pinecone/grpc/index_grpc_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
parse_sparse_values_arg,
)
from .vector_factory_grpc import VectorFactoryGRPC
from .query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults
from ..data.query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults


class GRPCIndexAsyncio(GRPCIndexBase):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pinecone.grpc.query_results_aggregator import (
from pinecone.data.query_results_aggregator import (
QueryResultsAggregator,
QueryResultsAggregatorInvalidTopKError,
QueryResultsAggregregatorNotEnoughResultsError,
Expand Down
Loading