Skip to content

Commit

Permalink
feat: support clip retrieval (#816)
Browse files Browse the repository at this point in the history
* feat: add demo

* fix: fix _get_post_payload

* fix: fix client post

* fix: remove workspace folder

* fix: fix parameters

* fix: revert yml change

* feat: add streamlit

* feat: add index and search api

* fix: add test for index and search

* fix: ci

* fix: address comments

Co-authored-by: Ziniu Yu <[email protected]>

* fix: remove streamlit script

* fix: remove search flow

Co-authored-by: jemmyshin <[email protected]>
Co-authored-by: Ziniu Yu <[email protected]>
  • Loading branch information
3 people authored Sep 9, 2022
1 parent 47144c2 commit a07a521
Show file tree
Hide file tree
Showing 6 changed files with 384 additions and 9 deletions.
1 change: 1 addition & 0 deletions .github/workflows/cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ jobs:
pip install --no-cache-dir "client/[test]"
pip install --no-cache-dir "server/[onnx]"
pip install --no-cache-dir "server/[transformers]"
pip install --no-cache-dir "server/[search]"
- name: Test
id: test
run: |
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ jobs:
pip install --no-cache-dir "client/[test]"
pip install --no-cache-dir "server/[onnx]"
pip install --no-cache-dir "server/[transformers]"
pip install --no-cache-dir "server/[search]"
- name: Test
id: test
run: |
Expand Down
289 changes: 280 additions & 9 deletions client/clip_client/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import mimetypes
import os
import time
import types
import warnings
from typing import (
overload,
Expand Down Expand Up @@ -118,9 +117,13 @@ def encode(self, content, **kwargs):
)
results = DocumentArray()
with self._pbar:
parameters = kwargs.pop('parameters', None)
model_name = parameters.pop('model_name', '') if parameters else ''
self._client.post(
on=f'/encode/{model_name}'.rstrip('/'),
**self._get_post_payload(content, kwargs),
on_done=partial(self._gather_result, results=results),
parameters=parameters,
)

for c in content:
Expand Down Expand Up @@ -199,10 +202,7 @@ def _iter_doc(self, content) -> Generator['Document', None, None]:
)

def _get_post_payload(self, content, kwargs):
parameters = kwargs.get('parameters', {})
model_name = parameters.get('model', '')
payload = dict(
on=f'/encode/{model_name}'.rstrip('/'),
inputs=self._iter_doc(content),
request_size=kwargs.get('batch_size', 8),
total_docs=len(content) if hasattr(content, '__len__') else None,
Expand Down Expand Up @@ -273,6 +273,7 @@ async def aencode(
*,
batch_size: Optional[int] = None,
show_progress: bool = False,
parameters: Optional[dict] = None,
) -> 'np.ndarray':
...

Expand All @@ -283,6 +284,7 @@ async def aencode(
*,
batch_size: Optional[int] = None,
show_progress: bool = False,
parameters: Optional[dict] = None,
) -> 'DocumentArray':
...

Expand All @@ -296,8 +298,13 @@ async def aencode(self, content, **kwargs):

results = DocumentArray()
with self._pbar:
parameters = kwargs.pop('parameters', None)
model_name = parameters.get('model_name', '') if parameters else ''

async for da in self._async_client.post(
**self._get_post_payload(content, kwargs)
on=f'/encode/{model_name}'.rstrip('/'),
**self._get_post_payload(content, kwargs),
parameters=parameters,
):
if not results:
self._pbar.start_task(self._r_task)
Expand Down Expand Up @@ -405,10 +412,7 @@ def _iter_rank_docs(
)

def _get_rank_payload(self, content, kwargs):
parameters = kwargs.get('parameters', {})
model_name = parameters.get('model', '')
payload = dict(
on=f'/rank/{model_name}'.rstrip('/'),
inputs=self._iter_rank_docs(
content, _source=kwargs.get('source', 'matches')
),
Expand Down Expand Up @@ -436,9 +440,13 @@ def rank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
)
results = DocumentArray()
with self._pbar:
parameters = kwargs.pop('parameters', None)
model_name = parameters.get('model_name', '') if parameters else ''
self._client.post(
on=f'/rank/{model_name}'.rstrip('/'),
**self._get_rank_payload(docs, kwargs),
on_done=partial(self._gather_result, results=results),
parameters=parameters,
)
for d in docs:
self._reset_rank_doc(d, _source=kwargs.get('source', 'matches'))
Expand All @@ -454,8 +462,12 @@ async def arank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
)
results = DocumentArray()
with self._pbar:
parameters = kwargs.pop('parameters', None)
model_name = parameters.get('model_name', '') if parameters else ''
async for da in self._async_client.post(
**self._get_rank_payload(docs, kwargs)
on=f'/rank/{model_name}'.rstrip('/'),
**self._get_rank_payload(docs, kwargs),
parameters=parameters,
):
if not results:
self._pbar.start_task(self._r_task)
Expand All @@ -474,3 +486,262 @@ async def arank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
self._reset_rank_doc(d, _source=kwargs.get('source', 'matches'))

return results

@overload
def index(
    self,
    content: Iterable[str],
    *,
    batch_size: Optional[int] = None,
    show_progress: bool = False,
    parameters: Optional[dict] = None,
):
    """Index the images or texts where their embeddings are computed by the server CLIP model.

    Each image and text must be represented as a string. The following strings are acceptable:
        - local image filepath, will be considered as an image
        - remote image http/https, will be considered as an image
        - a dataURI, will be considered as an image
        - plain text, will be considered as a sentence

    :param content: an iterator of image URIs or sentences, each element is an image or a text sentence as a string.
    :param batch_size: the number of elements in each request when sending ``content``.
    :param show_progress: if set, show a progress bar.
    :param parameters: extra parameters for indexing; can be used to select the model when the server hosts multiple models.
    :return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` is in the same length of ``content``.
    """
    ...

@overload
def index(
    self,
    content: Union['DocumentArray', Iterable['Document']],
    *,
    batch_size: Optional[int] = None,
    show_progress: bool = False,
    parameters: Optional[dict] = None,
) -> 'DocumentArray':
    """Index the images or texts where their embeddings are computed by the server CLIP model.

    :param content: an iterable of :class:`docarray.Document`, each Document must be filled with `.uri`, `.text` or `.blob`.
    :param batch_size: the number of elements in each request when sending ``content``.
    :param show_progress: if set, show a progress bar.
    :param parameters: extra parameters for indexing; can be used to select the model when the server hosts multiple models.
    :return: the indexed Documents as a :class:`DocumentArray`.
    """
    ...

def index(self, content, **kwargs):
    """Send ``content`` to the server's ``/index`` endpoint and gather the results.

    Rejects a bare string (a single query must be wrapped in a list), streams
    the inputs with a progress bar, and strips any blob that was loaded
    locally (marked with the ``__loaded_by_CAS__`` tag) before returning.
    """
    if isinstance(content, str):
        raise TypeError(
            f'content must be an Iterable of [str, Document], try `.index(["{content}"])` instead'
        )

    total_docs = len(content) if hasattr(content, '__len__') else None
    self._prepare_streaming(not kwargs.get('show_progress'), total=total_docs)

    collected = DocumentArray()
    with self._pbar:
        params = kwargs.pop('parameters', None)
        self._client.post(
            on='/index',
            **self._get_post_payload(content, kwargs),
            on_done=partial(self._gather_result, results=collected),
            parameters=params,
        )

    # Remove blobs that were only loaded client-side to build the request.
    for doc in content:
        if hasattr(doc, 'tags') and doc.tags.pop('__loaded_by_CAS__', False):
            doc.pop('blob')

    return self._unboxed_result(collected)

@overload
async def aindex(
    self,
    content: Iterable[str],
    *,
    batch_size: Optional[int] = None,
    show_progress: bool = False,
    parameters: Optional[dict] = None,
):
    """Async version of :meth:`index` for string content.

    :param content: an iterator of image URIs or sentences, each element is an image or a text sentence as a string.
    :param batch_size: the number of elements in each request when sending ``content``.
    :param show_progress: if set, show a progress bar.
    :param parameters: extra parameters for indexing; can be used to select the model when the server hosts multiple models.
    """
    ...

@overload
async def aindex(
    self,
    content: Union['DocumentArray', Iterable['Document']],
    *,
    batch_size: Optional[int] = None,
    show_progress: bool = False,
    parameters: Optional[dict] = None,
):
    """Async version of :meth:`index` for Document content.

    :param content: an iterable of :class:`docarray.Document`, each Document must be filled with `.uri`, `.text` or `.blob`.
    :param batch_size: the number of elements in each request when sending ``content``.
    :param show_progress: if set, show a progress bar.
    :param parameters: extra parameters for indexing; can be used to select the model when the server hosts multiple models.
    """
    ...

async def aindex(self, content, **kwargs):
    """Asynchronously stream ``content`` to the server's ``/index`` endpoint.

    Consumes the async response stream batch by batch, updating the progress
    bar with the running gRPC byte count, then strips locally-loaded blobs
    (marked with the ``__loaded_by_CAS__`` tag) before returning the result.
    """
    from rich import filesize

    has_len = hasattr(content, '__len__')
    self._prepare_streaming(
        not kwargs.get('show_progress'),
        total=len(content) if has_len else None,
    )

    params = kwargs.pop('parameters', None)
    collected = DocumentArray()
    with self._pbar:
        stream = self._async_client.post(
            on='/index',
            **self._get_post_payload(content, kwargs),
            parameters=params,
        )
        async for batch in stream:
            if not collected:
                # First batch arrived: start the receive-side progress task.
                self._pbar.start_task(self._r_task)
            collected.extend(batch)
            recv_bytes = int(os.environ.get('JINA_GRPC_RECV_BYTES', '0'))
            self._pbar.update(
                self._r_task,
                advance=len(batch),
                total_size=str(filesize.decimal(recv_bytes)),
            )

    # Remove blobs that were only loaded client-side to build the request.
    for doc in content:
        if hasattr(doc, 'tags') and doc.tags.pop('__loaded_by_CAS__', False):
            doc.pop('blob')

    return self._unboxed_result(collected)

@overload
def search(
    self,
    content: Iterable[str],
    *,
    limit: int = 20,
    batch_size: Optional[int] = None,
    show_progress: bool = False,
    parameters: Optional[dict] = None,
) -> 'DocumentArray':
    """Search for top k results for given query string or ``Document``.

    If the input is a string, will use this string as query. If the input is a ``Document``,
    will use this ``Document`` as query.

    :param content: list of queries.
    :param limit: the number of results to return.
    :param batch_size: the number of elements in each request when sending ``content``.
    :param show_progress: if set, show a progress bar.
    :param parameters: parameters passed to search function.
    """
    ...

@overload
def search(
    self,
    content: Union['DocumentArray', Iterable['Document']],
    *,
    limit: int = 20,
    batch_size: Optional[int] = None,
    show_progress: bool = False,
    parameters: Optional[dict] = None,
) -> 'DocumentArray':
    """Search for top k results for given query string or ``Document``.

    If the input is a string, will use this string as query. If the input is a ``Document``,
    will use this ``Document`` as query.

    :param content: list of queries.
    :param limit: the number of results to return.
    :param batch_size: the number of elements in each request when sending ``content``.
    :param show_progress: if set, show a progress bar.
    :param parameters: parameters passed to search function.
    """
    ...

def search(self, content, **kwargs) -> 'DocumentArray':
    """Send queries to the server's ``/search`` endpoint and return the matches.

    Rejects a bare string (a single query must be wrapped in a list), streams
    the queries with a progress bar, and strips any blob that was loaded
    locally (marked with the ``__loaded_by_CAS__`` tag) before returning.

    :param content: an iterable of query strings or :class:`Document` objects.
    :return: the search results via ``self._unboxed_result``.
    """
    if isinstance(content, str):
        raise TypeError(
            f'content must be an Iterable of [str, Document], try `.search(["{content}"])` instead'
        )

    self._prepare_streaming(
        not kwargs.get('show_progress'),
        total=len(content) if hasattr(content, '__len__') else None,
    )
    results = DocumentArray()
    with self._pbar:
        parameters = kwargs.pop('parameters', {})
        # BUGFIX: the previous code did ``parameters['limit'] = kwargs.get('limit')``,
        # which sent ``limit=None`` (and clobbered a caller-supplied
        # ``parameters['limit']``) whenever ``limit`` was omitted, contradicting
        # the documented default of 20. An explicit ``limit`` kwarg still wins;
        # otherwise the default only fills in when absent.
        if 'limit' in kwargs:
            parameters['limit'] = kwargs.pop('limit')
        else:
            parameters.setdefault('limit', 20)

        self._client.post(
            on='/search',
            **self._get_post_payload(content, kwargs),
            parameters=parameters,
            on_done=partial(self._gather_result, results=results),
        )

    # Remove blobs that were only loaded client-side to build the request.
    for c in content:
        if hasattr(c, 'tags') and c.tags.pop('__loaded_by_CAS__', False):
            c.pop('blob')

    return self._unboxed_result(results)

@overload
async def asearch(
    self,
    content: Iterable[str],
    *,
    limit: int = 20,
    batch_size: Optional[int] = None,
    show_progress: bool = False,
    parameters: Optional[dict] = None,
):
    """Async version of :meth:`search` for string queries.

    :param content: list of queries.
    :param limit: the number of results to return.
    :param batch_size: the number of elements in each request when sending ``content``.
    :param show_progress: if set, show a progress bar.
    :param parameters: parameters passed to search function.
    """
    ...

@overload
async def asearch(
    self,
    content: Union['DocumentArray', Iterable['Document']],
    *,
    limit: int = 20,
    batch_size: Optional[int] = None,
    show_progress: bool = False,
    parameters: Optional[dict] = None,
):
    """Async version of :meth:`search` for Document queries.

    :param content: list of queries as :class:`docarray.Document` objects.
    :param limit: the number of results to return.
    :param batch_size: the number of elements in each request when sending ``content``.
    :param show_progress: if set, show a progress bar.
    :param parameters: parameters passed to search function.
    """
    ...

async def asearch(self, content, **kwargs):
    """Asynchronously stream queries to the server's ``/search`` endpoint.

    Consumes the async response stream batch by batch, updating the progress
    bar with the running gRPC byte count, then strips locally-loaded blobs
    (marked with the ``__loaded_by_CAS__`` tag) before returning the result.

    :param content: an iterable of query strings or :class:`Document` objects.
    :return: the search results via ``self._unboxed_result``.
    """
    from rich import filesize

    self._prepare_streaming(
        not kwargs.get('show_progress'),
        total=len(content) if hasattr(content, '__len__') else None,
    )
    results = DocumentArray()

    with self._pbar:
        parameters = kwargs.pop('parameters', {})
        # BUGFIX: the previous code did ``parameters['limit'] = kwargs.get('limit')``,
        # which sent ``limit=None`` (and clobbered a caller-supplied
        # ``parameters['limit']``) whenever ``limit`` was omitted, contradicting
        # the documented default of 20. An explicit ``limit`` kwarg still wins;
        # otherwise the default only fills in when absent.
        if 'limit' in kwargs:
            parameters['limit'] = kwargs.pop('limit')
        else:
            parameters.setdefault('limit', 20)

        async for da in self._async_client.post(
            on='/search',
            **self._get_post_payload(content, kwargs),
            parameters=parameters,
        ):
            if not results:
                # First batch arrived: start the receive-side progress task.
                self._pbar.start_task(self._r_task)
            results.extend(da)
            self._pbar.update(
                self._r_task,
                advance=len(da),
                total_size=str(
                    filesize.decimal(
                        int(os.environ.get('JINA_GRPC_RECV_BYTES', '0'))
                    )
                ),
            )

    # Remove blobs that were only loaded client-side to build the request.
    for c in content:
        if hasattr(c, 'tags') and c.tags.pop('__loaded_by_CAS__', False):
            c.pop('blob')

    return self._unboxed_result(results)
1 change: 1 addition & 0 deletions server/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
+ (['onnxruntime-gpu>=1.8.0'] if sys.platform != 'darwin' else []),
'tensorrt': ['nvidia-tensorrt'],
'transformers': ['transformers>=4.16.2'],
'search': ['annlite>=0.3.10'],
},
classifiers=[
'Development Status :: 5 - Production/Stable',
Expand Down
Loading

0 comments on commit a07a521

Please sign in to comment.