Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamon committed Oct 21, 2024
1 parent 3780924 commit ddef712
Show file tree
Hide file tree
Showing 10 changed files with 308 additions and 11 deletions.
31 changes: 31 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from pinecone.grpc import PineconeGRPC, GRPCClientConfig

# Initialize a client. An API key must be passed, but the
# value does not matter.
pc = PineconeGRPC(api_key="test_api_key")

# Target the indexes. Use the host and port number along with disabling tls.
index1 = pc.Index(host="localhost:5081", grpc_config=GRPCClientConfig(secure=False))
index2 = pc.Index(host="localhost:5082", grpc_config=GRPCClientConfig(secure=False))

# You can now perform data plane operations with index1 and index2

dimension = 3


def upserts():
vectors = []
for i in range(0, 100):
vectors.append((f"vec{i}", [i] * dimension))

print(len(vectors))

index1.upsert(vectors=vectors, namespace="ns2")
index2.upsert(vectors=vectors, namespace="ns2")


upserts()
print(index1.describe_index_stats())

print(index1.query(id="vec1", top_k=2, namespace="ns2", include_values=True))
print(index1.query(id="vec1", top_k=10, namespace="", include_values=True))
16 changes: 16 additions & 0 deletions app2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from pinecone.grpc import PineconeGRPC
from pinecone import Pinecone

pc = Pinecone(api_key="b1cb8ba4-b3d1-458f-9c32-8dd10813459a")
pcg = PineconeGRPC(api_key="b1cb8ba4-b3d1-458f-9c32-8dd10813459a")

index = pc.Index("jen2")
indexg = pcg.Index(name="jen2", use_asyncio=True)

# Rest call fails
# print(index.upsert(vectors=[("vec1", [1, 2])]))

# GRPC succeeds
print(indexg.upsert(vectors=[("vec1", [1, 2])]))

# print(index.fetch(ids=['vec1']))
62 changes: 62 additions & 0 deletions app3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import asyncio
from pinecone.grpc import PineconeGRPC as Pinecone, Vector

import time
import random
import pandas as pd


# Enable gRPC tracing and verbosity for more detailed logs
# os.environ["GRPC_VERBOSITY"] = "DEBUG"
# os.environ["GRPC_TRACE"] = "all"


# Generate a large set of vectors (as an example)
def generate_vectors(num_vectors, dimension):
return [
Vector(id=f"vector_{i}", values=[random.random()] * dimension) for i in range(num_vectors)
]


def load_vectors():
df = pd.read_parquet("test_records_100k_dim1024.parquet")
df["values"] = df["values"].apply(lambda x: [float(v) for v in x])

vectors = [Vector(id=row.id, values=list(row.values)) for row in df.itertuples()]
return vectors


async def main():
# Create a semaphore to limit concurrency (e.g., max 5 concurrent requests)
s = time.time()
# all_vectors = load_vectors()
all_vectors = generate_vectors(1000, 1024)
f = time.time()
print(f"Loaded {len(all_vectors)} vectors in {f-s:.2f} seconds")

start_time = time.time()

# Same setup as before...
pc = Pinecone(api_key="b1cb8ba4-b3d1-458f-9c32-8dd10813459a")
index = pc.Index(
# index_host="jen2-dojoi3u.svc.aped-4627-b74a.pinecone.io"
host="jen1024-dojoi3u.svc.apw5-4e34-81fa.pinecone.io",
use_asyncio=True,
)

batch_size = 150
namespace = "asyncio-py7"
res = await index.upsert(
vectors=all_vectors, batch_size=batch_size, namespace=namespace, show_progress=True
)

print(res)

end_time = time.time()

total_time = end_time - start_time
print(f"All tasks completed in {total_time:.2f} seconds")
print(f"Namespace: {namespace}")


asyncio.run(main())
18 changes: 11 additions & 7 deletions pinecone/grpc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,32 @@
from pinecone import Config
from .config import GRPCClientConfig
from .grpc_runner import GrpcRunner
from .utils import normalize_endpoint


class GRPCIndexBase(ABC):
"""
Base class for grpc-based interaction with Pinecone indexes
"""

_pool = None

def __init__(
self,
index_name: str,
config: Config,
channel: Optional[Channel] = None,
grpc_config: Optional[GRPCClientConfig] = None,
_endpoint_override: Optional[str] = None,
use_asyncio: Optional[bool] = False,
):
self.config = config
self.grpc_client_config = grpc_config or GRPCClientConfig()

self._endpoint_override = _endpoint_override

self.runner = GrpcRunner(
index_name=index_name, config=config, grpc_config=self.grpc_client_config
)
self.channel_factory = GrpcChannelFactory(
config=self.config, grpc_client_config=self.grpc_client_config, use_asyncio=False
config=self.config, grpc_client_config=self.grpc_client_config, use_asyncio=use_asyncio
)
self._channel = channel or self._gen_channel()
self.stub = self.stub_class(self._channel)
Expand All @@ -46,9 +45,7 @@ def stub_class(self):
pass

def _endpoint(self):
grpc_host = self.config.host.replace("https://", "")
if ":" not in grpc_host:
grpc_host = f"{grpc_host}:443"
grpc_host = normalize_endpoint(self.config.host)
return self._endpoint_override if self._endpoint_override else grpc_host

def _gen_channel(self):
Expand Down Expand Up @@ -83,3 +80,10 @@ def __enter__(self):

def __exit__(self, exc_type, exc_value, traceback):
self.close()

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_value, traceback):
self.close()
return True
21 changes: 21 additions & 0 deletions pinecone/grpc/index_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
from .base import GRPCIndexBase
from .future import PineconeGrpcFuture

from .config import GRPCClientConfig
from pinecone.config import Config
from grpc._channel import Channel


__all__ = ["GRPCIndex", "GRPCVector", "GRPCQueryVector", "GRPCSparseValues"]

Expand All @@ -53,6 +57,23 @@ class SparseVectorTypedDict(TypedDict):
class GRPCIndex(GRPCIndexBase):
"""A client for interacting with a Pinecone index via GRPC API."""

def __init__(
self,
index_name: str,
config: Config,
channel: Optional[Channel] = None,
grpc_config: Optional[GRPCClientConfig] = None,
_endpoint_override: Optional[str] = None,
):
super().__init__(
index_name=index_name,
config=config,
channel=channel,
grpc_config=grpc_config,
_endpoint_override=_endpoint_override,
use_asyncio=False,
)

@property
def stub_class(self):
return VectorServiceStub
Expand Down
100 changes: 100 additions & 0 deletions pinecone/grpc/index_grpc_asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from typing import Optional, Union, List, Awaitable

from tqdm.asyncio import tqdm
from asyncio import Semaphore

from .vector_factory_grpc import VectorFactoryGRPC

from pinecone.core.grpc.protos.vector_service_pb2 import (
Vector as GRPCVector,
QueryVector as GRPCQueryVector,
UpsertRequest,
UpsertResponse,
SparseValues as GRPCSparseValues,
)
from .base import GRPCIndexBase
from pinecone import Vector as NonGRPCVector
from pinecone.core.grpc.protos.vector_service_pb2_grpc import VectorServiceStub
from pinecone.utils import parse_non_empty_args

from .config import GRPCClientConfig
from pinecone.config import Config
from grpc._channel import Channel

__all__ = ["GRPCIndexAsyncio", "GRPCVector", "GRPCQueryVector", "GRPCSparseValues"]


class GRPCIndexAsyncio(GRPCIndexBase):
"""A client for interacting with a Pinecone index over GRPC with asyncio."""

def __init__(
self,
index_name: str,
config: Config,
channel: Optional[Channel] = None,
grpc_config: Optional[GRPCClientConfig] = None,
_endpoint_override: Optional[str] = None,
):
super().__init__(
index_name=index_name,
config=config,
channel=channel,
grpc_config=grpc_config,
_endpoint_override=_endpoint_override,
use_asyncio=True,
)

@property
def stub_class(self):
return VectorServiceStub

async def upsert(
self,
vectors: Union[List[GRPCVector], List[NonGRPCVector], List[tuple], List[dict]],
namespace: Optional[str] = None,
batch_size: Optional[int] = None,
show_progress: bool = True,
**kwargs,
) -> Awaitable[UpsertResponse]:
timeout = kwargs.pop("timeout", None)
vectors = list(map(VectorFactoryGRPC.build, vectors))

if batch_size is None:
return await self._upsert_batch(vectors, namespace, timeout=timeout, **kwargs)

else:
if not isinstance(batch_size, int) or batch_size <= 0:
raise ValueError("batch_size must be a positive integer")

semaphore = Semaphore(25)
vector_batches = [
vectors[i : i + batch_size] for i in range(0, len(vectors), batch_size)
]
tasks = [
self._upsert_batch(
vectors=batch, namespace=namespace, timeout=100, semaphore=semaphore
)
for batch in vector_batches
]

return await tqdm.gather(*tasks, disable=not show_progress, desc="Upserted batches")

async def _upsert_batch(
self,
vectors: List[GRPCVector],
namespace: Optional[str],
timeout: Optional[int] = None,
semaphore: Optional[Semaphore] = None,
**kwargs,
) -> Awaitable[UpsertResponse]:
args_dict = parse_non_empty_args([("namespace", namespace)])
request = UpsertRequest(vectors=vectors, **args_dict)
if semaphore is not None:
async with semaphore:
return await self.runner.run_asyncio(
self.stub.Upsert, request, timeout=timeout, **kwargs
)
else:
return await self.runner.run_asyncio(
self.stub.Upsert, request, timeout=timeout, **kwargs
)
9 changes: 7 additions & 2 deletions pinecone/grpc/pinecone.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ..control.pinecone import Pinecone
from ..config.config import ConfigBuilder
from .index_grpc import GRPCIndex
from .index_grpc_asyncio import GRPCIndexAsyncio


class PineconeGRPC(Pinecone):
Expand Down Expand Up @@ -47,7 +48,7 @@ class PineconeGRPC(Pinecone):
"""

def Index(self, name: str = "", host: str = "", **kwargs):
def Index(self, name: str = "", host: str = "", use_asyncio=False, **kwargs):
"""
Target an index for data operations.
Expand Down Expand Up @@ -131,4 +132,8 @@ def Index(self, name: str = "", host: str = "", **kwargs):
proxy_url=self.config.proxy_url,
ssl_ca_certs=self.config.ssl_ca_certs,
)
return GRPCIndex(index_name=name, config=config, **kwargs)

if use_asyncio:
return GRPCIndexAsyncio(index_name=name, config=config, **kwargs)
else:
return GRPCIndex(index_name=name, config=config, **kwargs)
6 changes: 6 additions & 0 deletions pinecone/grpc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ def _generate_request_id() -> str:
return str(uuid.uuid4())


def normalize_endpoint(endpoint: str) -> str:
grpc_host = endpoint.replace("https://", "")
if ":" not in grpc_host:
grpc_host = f"{grpc_host}:443"


def dict_to_proto_struct(d: Optional[dict]) -> "Struct":
if not d:
d = {}
Expand Down
Loading

0 comments on commit ddef712

Please sign in to comment.