diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index dc5b5b388..0a7e13059 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -43,6 +43,7 @@ jobs: pip install --no-cache-dir "client/[test]" pip install --no-cache-dir "server/[onnx]" pip install --no-cache-dir "server/[transformers]" + pip install --no-cache-dir "server/[search]" - name: Test id: test run: | diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 039d28793..65e32dbb8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -115,6 +115,7 @@ jobs: pip install --no-cache-dir "client/[test]" pip install --no-cache-dir "server/[onnx]" pip install --no-cache-dir "server/[transformers]" + pip install --no-cache-dir "server/[search]" - name: Test id: test run: | diff --git a/client/clip_client/client.py b/client/clip_client/client.py index f7a0a736d..f30d7e5da 100644 --- a/client/clip_client/client.py +++ b/client/clip_client/client.py @@ -1,7 +1,6 @@ import mimetypes import os import time -import types import warnings from typing import ( overload, @@ -118,9 +117,13 @@ def encode(self, content, **kwargs): ) results = DocumentArray() with self._pbar: + parameters = kwargs.pop('parameters', None) + model_name = parameters.pop('model_name', '') if parameters else '' self._client.post( + on=f'/encode/{model_name}'.rstrip('/'), **self._get_post_payload(content, kwargs), on_done=partial(self._gather_result, results=results), + parameters=parameters, ) for c in content: @@ -199,10 +202,7 @@ def _iter_doc(self, content) -> Generator['Document', None, None]: ) def _get_post_payload(self, content, kwargs): - parameters = kwargs.get('parameters', {}) - model_name = parameters.get('model', '') payload = dict( - on=f'/encode/{model_name}'.rstrip('/'), inputs=self._iter_doc(content), request_size=kwargs.get('batch_size', 8), total_docs=len(content) if hasattr(content, '__len__') else None, @@ -273,6 +273,7 @@ async def aencode( *, batch_size: 
Optional[int] = None, show_progress: bool = False, + parameters: Optional[dict] = None, ) -> 'np.ndarray': ... @@ -283,6 +284,7 @@ async def aencode( *, batch_size: Optional[int] = None, show_progress: bool = False, + parameters: Optional[dict] = None, ) -> 'DocumentArray': ... @@ -296,8 +298,13 @@ async def aencode(self, content, **kwargs): results = DocumentArray() with self._pbar: + parameters = kwargs.pop('parameters', None) + model_name = parameters.get('model_name', '') if parameters else '' + async for da in self._async_client.post( - **self._get_post_payload(content, kwargs) + on=f'/encode/{model_name}'.rstrip('/'), + **self._get_post_payload(content, kwargs), + parameters=parameters, ): if not results: self._pbar.start_task(self._r_task) @@ -405,10 +412,7 @@ def _iter_rank_docs( ) def _get_rank_payload(self, content, kwargs): - parameters = kwargs.get('parameters', {}) - model_name = parameters.get('model', '') payload = dict( - on=f'/rank/{model_name}'.rstrip('/'), inputs=self._iter_rank_docs( content, _source=kwargs.get('source', 'matches') ), @@ -436,9 +440,13 @@ def rank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray': ) results = DocumentArray() with self._pbar: + parameters = kwargs.pop('parameters', None) + model_name = parameters.get('model_name', '') if parameters else '' self._client.post( + on=f'/rank/{model_name}'.rstrip('/'), **self._get_rank_payload(docs, kwargs), on_done=partial(self._gather_result, results=results), + parameters=parameters, ) for d in docs: self._reset_rank_doc(d, _source=kwargs.get('source', 'matches')) @@ -454,8 +462,12 @@ async def arank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray': ) results = DocumentArray() with self._pbar: + parameters = kwargs.pop('parameters', None) + model_name = parameters.get('model_name', '') if parameters else '' async for da in self._async_client.post( - **self._get_rank_payload(docs, kwargs) + on=f'/rank/{model_name}'.rstrip('/'), + 
**self._get_rank_payload(docs, kwargs), + parameters=parameters, ): if not results: self._pbar.start_task(self._r_task) @@ -474,3 +486,262 @@ async def arank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray': self._reset_rank_doc(d, _source=kwargs.get('source', 'matches')) return results + + @overload + def index( + self, + content: Iterable[str], + *, + batch_size: Optional[int] = None, + show_progress: bool = False, + parameters: Optional[Dict] = None, + ): + """Index the images or texts where their embeddings are computed by the server CLIP model. + + Each image and text must be represented as a string. The following strings are acceptable: + - local image filepath, will be considered as an image + - remote image http/https, will be considered as an image + - a dataURI, will be considered as an image + - plain text, will be considered as a sentence + :param content: an iterator of image URIs or sentences, each element is an image or a text sentence as a string. + :param batch_size: the number of elements in each request when sending ``content`` + :param show_progress: if set, show a progress bar + :param parameters: the parameters for the indexing, you can specify the model to use when you have multiple models + :return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` is in the same length of ``content`` + """ + ... + + @overload + def index( + self, + content: Union['DocumentArray', Iterable['Document']], + *, + batch_size: Optional[int] = None, + show_progress: bool = False, + parameters: Optional[dict] = None, + ) -> 'DocumentArray': + """Index the images or texts where their embeddings are computed by the server CLIP model. + + :param content: an iterable of :class:`docarray.Document`, each Document must be filled with `.uri`, `.text` or `.blob`. 
def index(self, content, **kwargs):
    """Post ``content`` to the server's ``/index`` endpoint and return the gathered result.

    :param content: an iterable of strings or Documents; a single ``str`` is rejected.
    :param kwargs: ``show_progress`` toggles the progress bar, ``batch_size`` is read
        by ``_get_post_payload`` as the request size, and ``parameters`` is forwarded
        to the request unchanged.
    :return: the collected responses after passing through ``_unboxed_result``.
    :raises TypeError: if ``content`` is a single string rather than an iterable of items.
    """
    if isinstance(content, str):
        raise TypeError(
            f'content must be an Iterable of [str, Document], try `.index(["{content}"])` instead'
        )

    hide_progress = not kwargs.get('show_progress')
    total = len(content) if hasattr(content, '__len__') else None
    self._prepare_streaming(hide_progress, total=total)

    results = DocumentArray()
    with self._pbar:
        # ``parameters`` is taken out of kwargs before the payload is built from them.
        request_parameters = kwargs.pop('parameters', None)
        payload = self._get_post_payload(content, kwargs)
        collect = partial(self._gather_result, results=results)
        self._client.post(
            on='/index',
            **payload,
            on_done=collect,
            parameters=request_parameters,
        )

    # Drop the blob from every document carrying the '__loaded_by_CAS__' tag.
    for doc in content:
        if not hasattr(doc, 'tags'):
            continue
        if doc.tags.pop('__loaded_by_CAS__', False):
            doc.pop('blob')

    return self._unboxed_result(results)
async def aindex(self, content, **kwargs):
    """Asynchronously post ``content`` to the ``/index`` endpoint and return the gathered result.

    Asynchronous counterpart of :meth:`index`: responses are streamed from
    ``self._async_client`` and accumulated while the progress bar is updated.

    :param content: an iterable of strings or Documents; a single ``str`` is rejected.
    :param kwargs: ``show_progress`` toggles the progress bar, ``batch_size`` is read
        by ``_get_post_payload`` as the request size, and ``parameters`` is forwarded
        to the request unchanged.
    :return: the collected responses after passing through ``_unboxed_result``.
    :raises TypeError: if ``content`` is a single string rather than an iterable of items.
    """
    from rich import filesize

    # Same guard as the synchronous ``index``: without it a bare string would be
    # iterated character by character and every character sent as its own item.
    if isinstance(content, str):
        raise TypeError(
            f'content must be an Iterable of [str, Document], try `.aindex(["{content}"])` instead'
        )

    self._prepare_streaming(
        not kwargs.get('show_progress'),
        total=len(content) if hasattr(content, '__len__') else None,
    )
    results = DocumentArray()
    with self._pbar:
        # Pop before building the payload, as the sibling request methods do.
        parameters = kwargs.pop('parameters', None)
        async for da in self._async_client.post(
            on='/index',
            **self._get_post_payload(content, kwargs),
            parameters=parameters,
        ):
            if not results:
                # First response received: start the receiving task on the progress bar.
                self._pbar.start_task(self._r_task)
            results.extend(da)
            self._pbar.update(
                self._r_task,
                advance=len(da),
                total_size=str(
                    filesize.decimal(
                        int(os.environ.get('JINA_GRPC_RECV_BYTES', '0'))
                    )
                ),
            )

    # Drop the blob from every document carrying the '__loaded_by_CAS__' tag.
    for c in content:
        if hasattr(c, 'tags') and c.tags.pop('__loaded_by_CAS__', False):
            c.pop('blob')

    return self._unboxed_result(results)
def search(self, content, **kwargs) -> 'DocumentArray':
    """Post ``content`` as queries to the ``/search`` endpoint and return the gathered result.

    :param content: an iterable of strings or Documents; a single ``str`` is rejected.
    :param kwargs: ``limit`` is the number of results requested per query,
        ``show_progress`` toggles the progress bar, ``batch_size`` is read by
        ``_get_post_payload`` as the request size, and ``parameters`` holds extra
        request parameters.
    :return: the collected responses after passing through ``_unboxed_result``.
    :raises TypeError: if ``content`` is a single string rather than an iterable of items.
    """
    if isinstance(content, str):
        raise TypeError(
            f'content must be an Iterable of [str, Document], try `.search(["{content}"])` instead'
        )

    self._prepare_streaming(
        not kwargs.get('show_progress'),
        total=len(content) if hasattr(content, '__len__') else None,
    )
    results = DocumentArray()
    with self._pbar:
        # Work on a copy: ``parameters=None`` is the documented default and must not
        # crash, and the caller's own dict must not gain a 'limit' key as a side effect.
        parameters = dict(kwargs.pop('parameters', None) or {})
        limit = kwargs.get('limit')
        if limit is not None:
            parameters['limit'] = limit
        elif parameters.get('limit') is None:
            # No explicit limit anywhere: use the default documented in the overloads.
            parameters['limit'] = 20

        self._client.post(
            on='/search',
            **self._get_post_payload(content, kwargs),
            parameters=parameters,
            on_done=partial(self._gather_result, results=results),
        )

    # Drop the blob from every document carrying the '__loaded_by_CAS__' tag.
    for c in content:
        if hasattr(c, 'tags') and c.tags.pop('__loaded_by_CAS__', False):
            c.pop('blob')

    return self._unboxed_result(results)
async def asearch(self, content, **kwargs):
    """Asynchronously post ``content`` as queries to ``/search`` and return the gathered result.

    Asynchronous counterpart of :meth:`search`: responses are streamed from
    ``self._async_client`` and accumulated while the progress bar is updated.

    :param content: an iterable of strings or Documents; a single ``str`` is rejected.
    :param kwargs: ``limit`` is the number of results requested per query,
        ``show_progress`` toggles the progress bar, ``batch_size`` is read by
        ``_get_post_payload`` as the request size, and ``parameters`` holds extra
        request parameters.
    :return: the collected responses after passing through ``_unboxed_result``.
    :raises TypeError: if ``content`` is a single string rather than an iterable of items.
    """
    from rich import filesize

    # Same guard as the synchronous ``search``: without it a bare string would be
    # iterated character by character and every character sent as its own query.
    if isinstance(content, str):
        raise TypeError(
            f'content must be an Iterable of [str, Document], try `.asearch(["{content}"])` instead'
        )

    self._prepare_streaming(
        not kwargs.get('show_progress'),
        total=len(content) if hasattr(content, '__len__') else None,
    )
    results = DocumentArray()

    with self._pbar:
        # Work on a copy: ``parameters=None`` is the documented default and must not
        # crash, and the caller's own dict must not gain a 'limit' key as a side effect.
        parameters = dict(kwargs.pop('parameters', None) or {})
        limit = kwargs.get('limit')
        if limit is not None:
            parameters['limit'] = limit
        elif parameters.get('limit') is None:
            # No explicit limit anywhere: use the default documented in the overloads.
            parameters['limit'] = 20

        async for da in self._async_client.post(
            on='/search',
            **self._get_post_payload(content, kwargs),
            parameters=parameters,
        ):
            if not results:
                # First response received: start the receiving task on the progress bar.
                self._pbar.start_task(self._r_task)
            results.extend(da)
            self._pbar.update(
                self._r_task,
                advance=len(da),
                total_size=str(
                    filesize.decimal(
                        int(os.environ.get('JINA_GRPC_RECV_BYTES', '0'))
                    )
                ),
            )

    # Drop the blob from every document carrying the '__loaded_by_CAS__' tag.
    for c in content:
        if hasattr(c, 'tags') and c.tags.pop('__loaded_by_CAS__', False):
            c.pop('blob')

    return self._unboxed_result(results)
import os

import numpy as np
import pytest
from docarray import DocumentArray, Document

from clip_client import Client


def _query_inputs():
    """Build a fresh list of the input variants used to parametrise both tests.

    Called once per test so that neither test shares Document instances with the other.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    local_image = f'{here}/img/00000.jpg'
    return [
        [Document(text='hello, world'), Document(text='goodbye, world')],
        DocumentArray([Document(text='hello, world'), Document(text='goodbye, world')]),
        # A factory, so each use gets an unconsumed generator.
        lambda: (Document(text='hello, world') for _ in range(10)),
        DocumentArray(
            [
                Document(uri='https://docarray.jina.ai/_static/favicon.png'),
                Document(uri=local_image),
                Document(text='hello, world'),
                Document(uri=local_image).load_uri_to_image_tensor(),
            ]
        ),
        DocumentArray.from_files(f'{here}/**/*.jpg'),
    ]


def _materialize(inputs):
    """Return ``inputs`` itself, or the result of calling it when it is callable."""
    return inputs() if callable(inputs) else inputs


@pytest.mark.parametrize('inputs', _query_inputs())
@pytest.mark.parametrize('limit', [1, 2])
def test_index_search(make_search_flow, inputs, limit):
    """Index the inputs, then search with the same inputs and check the match count."""
    client = Client(server=f'grpc://0.0.0.0:{make_search_flow.port}')

    indexed = client.index(_materialize(inputs))
    assert isinstance(indexed, DocumentArray)
    assert indexed.embeddings.shape[1] == 512

    found = client.search(_materialize(inputs), limit=limit)
    assert isinstance(found, DocumentArray)
    for doc in found:
        assert len(doc.matches) == limit


@pytest.mark.parametrize('inputs', _query_inputs())
@pytest.mark.parametrize('limit', [1, 2])
@pytest.mark.asyncio
async def test_async_index_search(make_search_flow, inputs, limit):
    """Async variant: index, then search with the same inputs and check the match count."""
    client = Client(server=f'grpc://0.0.0.0:{make_search_flow.port}')

    indexed = await client.aindex(_materialize(inputs))
    assert isinstance(indexed, DocumentArray)
    assert indexed.embeddings.shape[1] == 512

    found = await client.asearch(_materialize(inputs), limit=limit)
    assert isinstance(found, DocumentArray)
    for doc in found:
        assert len(doc.matches) == limit