Skip to content

Commit

Permalink
feat(client): add rank endpoint (#695)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao authored Apr 25, 2022
1 parent 5e1dd60 commit b727086
Show file tree
Hide file tree
Showing 12 changed files with 279 additions and 36 deletions.
1 change: 1 addition & 0 deletions .github/README-img/rerank-chart.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added .github/README-img/rerank.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
40 changes: 39 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -482,8 +482,46 @@ Fun time! Note, unlike the previous example, here the input is an image and the
</table>


Intrigued? That's only scratching the surface of what CLIP-as-service is capable of. [Read our docs to learn more](https://clip-as-service.jina.ai).
### Rerank image-text matches via CLIP model

From `0.3.0` CLIP-as-service adds a new `/rerank` endpoint that re-ranks cross-modal matches according to their joint likelihood in CLIP model. For example, given an image Document with some predefined sentence matches as below:

```python
from clip_client import Client
from docarray import Document

c = Client(server='grpc://demo-cas.jina.ai:51000')
r = c.rerank(
[
Document(
uri='.github/README-img/rerank.png',
matches=[
Document(text=f'a photo of a {p}')
for p in (
'control room',
'lecture room',
'conference room',
'podium indoor',
'television studio',
)
],
)
]
)

print(r['@m', ['text', 'scores__clip_score__value']])
```

```text
[['a photo of a television studio', 'a photo of a conference room', 'a photo of a lecture room', 'a photo of a control room', 'a photo of a podium indoor'],
[0.9920725226402283, 0.006038925610482693, 0.0009973491542041302, 0.00078492151806131, 0.00010626466246321797]]
```

One can see now `a photo of a television studio` is ranked to the top with `clip_score` score at `0.992`. In practice, one can use this endpoint to re-rank the matching result from another search system, for improving the cross-modal search quality.

<img src="https://github.com/jina-ai/clip-as-service/blob/main/.github/README-img/rerank.png?raw=true" alt="Rerank endpoint image input" width="40%"><img src="https://github.com/jina-ai/clip-as-service/blob/main/.github/README-img/rerank-chart.svg?raw=true" alt="Rerank endpoint output" width="50%">

Intrigued? That's only scratching the surface of what CLIP-as-service is capable of. [Read our docs to learn more](https://clip-as-service.jina.ai).

<!-- start support-pitch -->
## Support
Expand Down
2 changes: 1 addition & 1 deletion client/clip_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__version__ = '0.2.4'
__version__ = '0.3.0'

from .client import Client
83 changes: 82 additions & 1 deletion client/clip_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
)
from urllib.parse import urlparse


if TYPE_CHECKING:
import numpy as np
from docarray import DocumentArray, Document
Expand Down Expand Up @@ -316,3 +315,85 @@ def _prepare_streaming(self, disable, total):
from docarray import DocumentArray

self._results = DocumentArray()

@staticmethod
def _prepare_single_doc(d: 'Document'):
if d.content_type in ('text', 'blob'):
return d
elif not d.blob and d.uri:
d.load_uri_to_blob()
return d
elif d.tensor is not None:
return d
else:
raise TypeError(f'unsupported input type {d!r} {d.content_type}')

@staticmethod
def _prepare_rank_doc(d: 'Document', _source: str = 'matches'):
_get = lambda d: getattr(d, _source)
if not _get(d):
raise ValueError(f'`.rerank()` requires every doc to have `.{_source}`')
d = Client._prepare_single_doc(d)
setattr(d, _source, [Client._prepare_single_doc(c) for c in _get(d)])
return d

def _iter_rank_docs(
self, content, _source='matches'
) -> Generator['Document', None, None]:
from rich import filesize
from docarray import Document

self._return_plain = True

if hasattr(self, '_pbar'):
self._pbar.start_task(self._s_task)

for c in content:
if isinstance(c, Document):
yield self._prepare_rank_doc(c, _source)
else:
raise TypeError(f'unsupported input type {c!r}')

if hasattr(self, '_pbar'):
self._pbar.update(
self._s_task,
advance=1,
total_size=str(
filesize.decimal(
int(os.environ.get('JINA_GRPC_SEND_BYTES', '0'))
)
),
)

def _get_rank_payload(self, content, kwargs):
return dict(
on='/rerank',
inputs=self._iter_rank_docs(
content, _source=kwargs.get('source', 'matches')
),
request_size=kwargs.get('batch_size', 8),
total_docs=len(content) if hasattr(content, '__len__') else None,
)

def rerank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
"""Rerank image-text matches according to the server CLIP model.
Given a Document with nested matches, where the root is image/text and the matches is in another modality, i.e.
text/image; this method reranks the matches according to the CLIP model.
Each match now has a new score inside ``clip_score`` and matches are sorted descendingly according to this score.
More details can be found in: https://github.com/openai/CLIP#usage
:param docs: the input Documents
:return: the reranked Documents in a DocumentArray.
"""
self._prepare_streaming(
not kwargs.get('show_progress'),
total=len(docs),
)
with self._pbar:
self._client.post(
**self._get_rank_payload(docs, kwargs), on_done=self._gather_result
)
return self._results
68 changes: 68 additions & 0 deletions docs/user-guides/client.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,74 @@ asyncio.run(main())

The final time cost will be less than `3s + time(t2)`.

## Reranking

```{tip}
This feature is only available with `clip_server>=0.3.0` and the server is running with PyTorch backend.
```

One can also rerank cross-modal matches via {meth}`~clip_client.client.Client.rerank`. First construct a cross-modal Document where the root contains an image and `.matches` contain sentences to rerank. One can also construct text-to-image rerank as below:

````{tab} Given image, rerank sentences
```python
from docarray import Document
d = Document(
uri='.github/README-img/rerank.png',
matches=[
Document(text=f'a photo of a {p}')
for p in (
'control room',
'lecture room',
'conference room',
'podium indoor',
'television studio',
)
],
)
```
````

````{tab} Given sentence, rerank images
```python
from docarray import Document
d = Document(
text='a photo of conference room',
matches=[
Document(uri='.github/README-img/4.png'),
Document(uri='.github/README-img/9.png'),
Document(uri='https://clip-as-service.jina.ai/_static/favicon.png'),
],
)
```
````



Then call `rerank`, you can feed it with multiple Documents as a list:

```python
from clip_client import Client

c = Client(server='grpc://demo-cas.jina.ai:51000')
r = c.rerank([d])

print(r['@m', ['text', 'scores__clip_score__value']])
```

Finally, in the return you can observe the matches are re-ranked according to `.scores['clip_score']`:

```text
[['a photo of a television studio', 'a photo of a conference room', 'a photo of a lecture room', 'a photo of a control room', 'a photo of a podium indoor'],
[0.9920725226402283, 0.006038925610482693, 0.0009973491542041302, 0.00078492151806131, 0.00010626466246321797]]
```


(profiling)=
## Profiling

Expand Down
2 changes: 1 addition & 1 deletion server/clip_server/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.2.4'
__version__ = '0.3.0'
9 changes: 4 additions & 5 deletions server/clip_server/executors/clip_onnx.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os
import warnings
from multiprocessing.pool import ThreadPool, Pool
from typing import List, Tuple, Optional
import numpy as np
import onnxruntime as ort

from jina import Executor, requests, DocumentArray
from jina.logging.logger import JinaLogger

from clip_server.model import clip
from clip_server.model.clip_onnx import CLIPOnnxModel
Expand Down Expand Up @@ -33,7 +33,6 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)
self.logger = JinaLogger(self.__class__.__name__)

self._preprocess_blob = clip._transform_blob(_SIZE[name])
self._preprocess_tensor = clip._transform_ndarray(_SIZE[name])
Expand Down Expand Up @@ -67,13 +66,13 @@ def __init__(
)

if not self._device.startswith('cuda') and (
not os.environ.get('OMP_NUM_THREADS')
'OMP_NUM_THREADS' not in os.environ
and hasattr(self.runtime_args, 'replicas')
):
replicas = getattr(self.runtime_args, 'replicas', 1)
num_threads = max(1, torch.get_num_threads() // replicas)
if num_threads < 2:
self.logger.warning(
warnings.warn(
f'Too many replicas ({replicas}) vs too few threads {num_threads} may result in '
f'sub-optimal performance.'
)
Expand Down Expand Up @@ -117,7 +116,7 @@ async def encode(self, docs: 'DocumentArray', **kwargs):
elif d.uri:
_img_da.append(d)
else:
self.logger.warning(
warnings.warn(
f'The content of document {d.id} is empty, cannot be processed'
)

Expand Down
47 changes: 31 additions & 16 deletions server/clip_server/executors/clip_torch.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
import warnings
from multiprocessing.pool import ThreadPool
from typing import Optional, List, Tuple
from typing import Optional, List, Tuple, Dict

import numpy as np
from jina import Executor, requests, DocumentArray
from jina.logging.logger import JinaLogger
from jina import Executor, requests, DocumentArray, Document

from clip_server.model import clip

Expand All @@ -20,7 +20,6 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)
self.logger = JinaLogger(self.__class__.__name__)

import torch

Expand All @@ -30,13 +29,13 @@ def __init__(
self._device = device

if not self._device.startswith('cuda') and (
not os.environ.get('OMP_NUM_THREADS')
'OMP_NUM_THREADS' not in os.environ
and hasattr(self.runtime_args, 'replicas')
):
replicas = getattr(self.runtime_args, 'replicas', 1)
num_threads = max(1, torch.get_num_threads() // replicas)
if num_threads < 2:
self.logger.warning(
warnings.warn(
f'Too many replicas ({replicas}) vs too few threads {num_threads} may result in '
f'sub-optimal performance.'
)
Expand Down Expand Up @@ -82,26 +81,31 @@ def _split_img_txt_da(d, _img_da, _txt_da):
elif d.uri:
_img_da.append(d)

@requests(on='/rank')
async def rank(self, docs: 'DocumentArray', **kwargs):
@requests(on='/rerank')
async def rerank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
_source = parameters.get('source', 'matches')
_get = lambda d: getattr(d, _source)

for d in docs:
_img_da = DocumentArray()
_txt_da = DocumentArray()
self._split_img_txt_da(d, _img_da, _txt_da)

for c in d.chunks:
for c in _get(d):
self._split_img_txt_da(c, _img_da, _txt_da)

if len(_img_da) != 1 and len(_txt_da) != 1:
raise ValueError(
'chunks must be all in same modality, either all images or all text'
f'`d.{_source}` must be all in same modality, either all images or all text'
)
elif not _img_da or not _txt_da:
raise ValueError(
'root and chunks must be in different modality, one is image one is text'
f'`d` and `d.{_source}` must be in different modality, one is image one is text'
)
elif len(_get(d)) <= 1:
raise ValueError(
f'`d.{_source}` must have more than one Documents to do ranking'
)
elif len(d.chunks) <= 1:
raise ValueError('must have more than one chunks to rank over chunks')
else:
_img_da = self._preproc_image(_img_da)
_txt_da, texts = self._preproc_text(_txt_da)
Expand All @@ -120,10 +124,21 @@ async def rank(self, docs: 'DocumentArray', **kwargs):
elif len(_txt_da) == 1:
probs = probs_text

for c, v in zip(d.chunks, probs):
c.scores['clip-rank'].value = v

_txt_da.texts = texts
_img_da.tensors = None
_txt_da.tensors = None

for c, v in zip(_get(d), probs):
c.scores['clip_score'].value = v
setattr(
d,
_source,
sorted(
_get(d),
key=lambda _m: _m.scores['clip_score'].value,
reverse=True,
),
)

@requests
async def encode(self, docs: 'DocumentArray', **kwargs):
Expand Down
3 changes: 3 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import os

os.environ['OMP_NUM_THREADS'] = '1'
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,12 @@ def make_flow(port_generator, request):
f = Flow(port=port_generator()).add(name=request.param, uses=CLIPEncoder)
with f:
yield f


@pytest.fixture(scope='session', params=['torch'])
def make_torch_flow(port_generator, request):
from clip_server.executors.clip_torch import CLIPEncoder

f = Flow(port=port_generator()).add(name=request.param, uses=CLIPEncoder)
with f:
yield f
Loading

0 comments on commit b727086

Please sign in to comment.