Skip to content

Commit

Permalink
feat: add ranker endpoint for all backends (#707)
Browse files Browse the repository at this point in the history
* feat: add ranker endpoint for onnx

* fix: import

* feat: add ranker endpoint for trt

* fix: gpu test
  • Loading branch information
numb3r3 authored May 4, 2022
1 parent 618dbdb commit f66b145
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 20 deletions.
80 changes: 79 additions & 1 deletion server/clip_server/executors/clip_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import warnings
from functools import partial
from multiprocessing.pool import ThreadPool
from typing import Optional
from typing import Optional, Dict
import numpy as np
import onnxruntime as ort

from jina import Executor, requests, DocumentArray
Expand All @@ -19,6 +20,7 @@ def __init__(
device: Optional[str] = None,
num_worker_preprocess: int = 4,
minibatch_size: int = 16,
logit_scale: float = 4.60,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -29,6 +31,7 @@ def __init__(
self._minibatch_size = minibatch_size

self._model = CLIPOnnxModel(name)
self._logit_scale = logit_scale

import torch

Expand Down Expand Up @@ -74,6 +77,79 @@ def __init__(

self._model.start_sessions(sess_options=sess_options, providers=providers)

@requests(on='/rank')
async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
    """Score and re-order the candidates of every Document by CLIP similarity.

    For each ``d`` in ``docs`` the candidates are read from the attribute
    named by ``parameters['source']`` (default ``'matches'``). ``d`` and its
    candidates are encoded, the scaled cosine similarities are turned into a
    softmax distribution, each candidate receives its probability in
    ``c.scores['clip_score'].value`` and the candidates are written back to
    ``d`` in descending score order.

    :param docs: Documents to rank; modified in place.
    :param parameters: request parameters; only the ``'source'`` key is read.
    :param kwargs: accepted for the endpoint signature, not used here.
    :raises ValueError: if the candidates are of mixed modality, if ``d`` and
        its candidates share the same modality, or if there are fewer than
        two candidates.
    """
    _source = parameters.get('source', 'matches')

    def _get(doc):
        return getattr(doc, _source)

    def numpy_softmax(z):
        # Subtract the row-wise max before exponentiating so np.exp cannot
        # overflow; the result of the softmax is unchanged by this shift.
        e_x = np.exp(z - np.max(z, axis=1, keepdims=True))
        # ditto: keep the reduced axis so the division broadcasts per row
        return e_x / np.sum(e_x, axis=1, keepdims=True)

    # Loop-invariant: the stored logit scale is kept in log space.
    logit_scale = np.exp(self._logit_scale)

    for d in docs:
        _img_da = DocumentArray()
        _txt_da = DocumentArray()
        split_img_txt_da(d, _img_da, _txt_da)

        # Remember which side the query itself landed on *before* the
        # candidates are mixed in. Counting alone cannot tell
        # "image query + [text, image]" apart from "text query + 2 images".
        # NOTE(review): assumes split_img_txt_da appends the doc to one of
        # the two arrays, as the length checks below already rely on.
        query_is_image = len(_img_da) == 1
        query_is_text = len(_txt_da) == 1

        for c in _get(d):
            split_img_txt_da(c, _img_da, _txt_da)

        if len(_img_da) != 1 and len(_txt_da) != 1:
            raise ValueError(
                f'`d.{_source}` must be all in same modality, either all images or all text'
            )
        if not _img_da or not _txt_da:
            raise ValueError(
                f'`d` and `d.{_source}` must be in different modality, one is image one is text'
            )
        if len(_get(d)) <= 1:
            raise ValueError(
                f'`d.{_source}` must have more than one Documents to do ranking'
            )
        if (query_is_image and len(_img_da) != 1) or (
            query_is_text and len(_txt_da) != 1
        ):
            # A candidate shares the query's modality while the other side
            # happens to hold exactly one document. Without this guard the
            # scores would be computed against the wrong "query" and zipped
            # onto the candidates silently.
            raise ValueError(
                f'`d.{_source}` must be all in same modality, either all images or all text'
            )

        _img_da = await self.encode(_img_da)
        _txt_da = await self.encode(_txt_da)

        # normalized features
        image_features = _img_da.embeddings / np.linalg.norm(
            _img_da.embeddings, axis=1, keepdims=True
        )
        text_features = _txt_da.embeddings / np.linalg.norm(
            _txt_da.embeddings, axis=1, keepdims=True
        )

        # cosine similarity as logits
        logits_per_image = logit_scale * np.matmul(image_features, text_features.T)

        if len(_img_da) == 1:
            # one image vs. many texts: distribution over the texts
            probs = numpy_softmax(logits_per_image)[0]
        else:
            # one text vs. many images: distribution over the images
            probs = numpy_softmax(logits_per_image.T)[0]

        # drop embeddings
        _img_da.embeddings = None
        _txt_da.embeddings = None

        for c, v in zip(_get(d), probs):
            c.scores['clip_score'].value = v
        setattr(
            d,
            _source,
            sorted(
                _get(d),
                key=lambda _m: _m.scores['clip_score'].value,
                reverse=True,
            ),
        )

@requests
async def encode(self, docs: 'DocumentArray', **kwargs):
_img_da = DocumentArray()
Expand Down Expand Up @@ -104,3 +180,5 @@ async def encode(self, docs: 'DocumentArray', **kwargs):

# drop tensors
docs.tensors = None

return docs
18 changes: 10 additions & 8 deletions server/clip_server/executors/clip_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,18 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()

probs_image = (
logits_per_image.softmax(dim=-1).cpu().detach().numpy().squeeze()
)
probs_text = (
logits_per_text.softmax(dim=-1).cpu().detach().numpy().squeeze()
)
if len(_img_da) == 1:
probs = probs_image
probs = (
logits_per_image.softmax(dim=-1)
.cpu()
.detach()
.numpy()
.squeeze()
)
elif len(_txt_da) == 1:
probs = probs_text
probs = (
logits_per_text.softmax(dim=-1).cpu().detach().numpy().squeeze()
)

_img_da.embeddings = None
_txt_da.embeddings = None
Expand Down
82 changes: 80 additions & 2 deletions server/clip_server/executors/clip_trt.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Dict
from multiprocessing.pool import ThreadPool
from functools import partial
import numpy as np
from jina import Executor, requests, DocumentArray
from jina.logging.logger import JinaLogger

from clip_server.model import clip
from clip_server.model.clip_trt import CLIPTensorRTModel
Expand All @@ -16,17 +16,19 @@ def __init__(
device: str = 'cuda',
num_worker_preprocess: int = 4,
minibatch_size: int = 64,
logit_scale: float = 4.60,
**kwargs,
):
super().__init__(**kwargs)
self.logger = JinaLogger(self.__class__.__name__)

self._preprocess_tensor = clip._transform_ndarray(clip.MODEL_SIZE[name])
self._pool = ThreadPool(processes=num_worker_preprocess)

self._minibatch_size = minibatch_size
self._device = device

self._logit_scale = logit_scale

import torch

assert self._device.startswith('cuda'), (
Expand All @@ -42,6 +44,80 @@ def __init__(

self._model.start_engines()

@requests(on='/rank')
async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
    """Score and re-order the candidates of every Document by CLIP similarity.

    For each ``d`` in ``docs`` the candidates are read from the attribute
    named by ``parameters['source']`` (default ``'matches'``). ``d`` and its
    candidates are encoded, the scaled cosine similarities are turned into a
    softmax distribution, each candidate receives its probability in
    ``c.scores['clip_score'].value`` and the candidates are written back to
    ``d`` in descending score order.

    :param docs: Documents to rank; modified in place.
    :param parameters: request parameters; only the ``'source'`` key is read.
    :param kwargs: accepted for the endpoint signature, not used here.
    :raises ValueError: if the candidates are of mixed modality, if ``d`` and
        its candidates share the same modality, or if there are fewer than
        two candidates.
    """
    import torch

    _source = parameters.get('source', 'matches')

    def _get(doc):
        return getattr(doc, _source)

    # Loop-invariant: the stored logit scale is kept in log space.
    logit_scale = np.exp(self._logit_scale)

    for d in docs:
        _img_da = DocumentArray()
        _txt_da = DocumentArray()
        split_img_txt_da(d, _img_da, _txt_da)

        # Remember which side the query itself landed on *before* the
        # candidates are mixed in. Counting alone cannot tell
        # "image query + [text, image]" apart from "text query + 2 images".
        # NOTE(review): assumes split_img_txt_da appends the doc to one of
        # the two arrays, as the length checks below already rely on.
        query_is_image = len(_img_da) == 1
        query_is_text = len(_txt_da) == 1

        for c in _get(d):
            split_img_txt_da(c, _img_da, _txt_da)

        if len(_img_da) != 1 and len(_txt_da) != 1:
            raise ValueError(
                f'`d.{_source}` must be all in same modality, either all images or all text'
            )
        if not _img_da or not _txt_da:
            raise ValueError(
                f'`d` and `d.{_source}` must be in different modality, one is image one is text'
            )
        if len(_get(d)) <= 1:
            raise ValueError(
                f'`d.{_source}` must have more than one Documents to do ranking'
            )
        if (query_is_image and len(_img_da) != 1) or (
            query_is_text and len(_txt_da) != 1
        ):
            # A candidate shares the query's modality while the other side
            # happens to hold exactly one document. Without this guard the
            # scores would be computed against the wrong "query" and zipped
            # onto the candidates silently.
            raise ValueError(
                f'`d.{_source}` must be all in same modality, either all images or all text'
            )

        _img_da = await self.encode(_img_da)
        _txt_da = await self.encode(_txt_da)

        # Work on local tensors instead of writing torch tensors back onto
        # the DocumentArrays; the embeddings are cleared below anyway.
        img_emb = torch.from_numpy(_img_da.embeddings)
        txt_emb = torch.from_numpy(_txt_da.embeddings)

        # normalized features
        image_features = img_emb / img_emb.norm(dim=-1, keepdim=True)
        text_features = txt_emb / txt_emb.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logits_per_image = logit_scale * image_features @ text_features.t()

        if len(_img_da) == 1:
            # one image vs. many texts: distribution over the texts
            logits = logits_per_image
        else:
            # one text vs. many images: distribution over the images
            logits = logits_per_image.t()
        probs = logits.softmax(dim=-1).cpu().detach().numpy().squeeze()

        # drop embeddings
        _img_da.embeddings = None
        _txt_da.embeddings = None

        for c, v in zip(_get(d), probs):
            c.scores['clip_score'].value = v
        setattr(
            d,
            _source,
            sorted(
                _get(d),
                key=lambda _m: _m.scores['clip_score'].value,
                reverse=True,
            ),
        )

@requests
async def encode(self, docs: 'DocumentArray', **kwargs):
_img_da = DocumentArray()
Expand Down Expand Up @@ -87,3 +163,5 @@ async def encode(self, docs: 'DocumentArray', **kwargs):

# drop tensors
docs.tensors = None

return docs
21 changes: 12 additions & 9 deletions tests/test_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import pytest
from clip_client import Client
from clip_server.executors.clip_torch import CLIPEncoder
from clip_server.executors.clip_torch import CLIPEncoder as TorchCLIPEncoder
from clip_server.executors.clip_onnx import CLIPEncoder as ONNXCLILPEncoder
from docarray import DocumentArray, Document


@pytest.mark.asyncio
async def test_torch_executor_rank_img2texts():
ce = CLIPEncoder()
@pytest.mark.parametrize('encoder_class', [TorchCLIPEncoder, ONNXCLILPEncoder])
async def test_torch_executor_rank_img2texts(encoder_class):
ce = encoder_class()

da = DocumentArray.from_files(
f'{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg'
Expand All @@ -25,8 +27,9 @@ async def test_torch_executor_rank_img2texts():


@pytest.mark.asyncio
async def test_torch_executor_rank_text2imgs():
ce = CLIPEncoder()
@pytest.mark.parametrize('encoder_class', [TorchCLIPEncoder, ONNXCLILPEncoder])
async def test_torch_executor_rank_text2imgs(encoder_class):
ce = encoder_class()
db = DocumentArray(
[Document(text='hello, world!'), Document(text='goodbye, world!')]
)
Expand Down Expand Up @@ -61,8 +64,8 @@ async def test_torch_executor_rank_text2imgs():
),
],
)
def test_docarray_inputs(make_torch_flow, d):
c = Client(server=f'grpc://0.0.0.0:{make_torch_flow.port}')
def test_docarray_inputs(make_flow, d):
c = Client(server=f'grpc://0.0.0.0:{make_flow.port}')
r = c.rank([d])
assert isinstance(r, DocumentArray)
rv = r['@m', 'scores__clip_score__value']
Expand Down Expand Up @@ -90,8 +93,8 @@ def test_docarray_inputs(make_torch_flow, d):
],
)
@pytest.mark.asyncio
async def test_async_arank(make_torch_flow, d):
c = Client(server=f'grpc://0.0.0.0:{make_torch_flow.port}')
async def test_async_arank(make_flow, d):
c = Client(server=f'grpc://0.0.0.0:{make_flow.port}')
r = await c.arank([d])
assert isinstance(r, DocumentArray)
rv = r['@m', 'scores__clip_score__value']
Expand Down
30 changes: 30 additions & 0 deletions tests/test_tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,33 @@ def test_docarray_inputs(make_trt_flow, inputs):
r = c.encode(inputs if not callable(inputs) else inputs())
assert isinstance(r, DocumentArray)
assert r.embeddings.shape


@pytest.mark.gpu
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'd',
    [
        # image query ranked against two text candidates
        Document(
            uri='https://docarray.jina.ai/_static/favicon.png',
            matches=[Document(text='hello, world'), Document(text='goodbye, world')],
        ),
        # text query ranked against two image candidates
        Document(
            text='hello, world',
            matches=[
                Document(uri='https://docarray.jina.ai/_static/favicon.png'),
                Document(
                    uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg'
                ),
            ],
        ),
    ],
)
async def test_async_arank(make_trt_flow, d):
    """Async ranking against the TensorRT flow gives every match a positive clip score."""
    client = Client(server=f'grpc://0.0.0.0:{make_trt_flow.port}')
    ranked = await client.arank([d])
    assert isinstance(ranked, DocumentArray)

    # '@m' selects the matches; the second key reads each match's clip score value
    clip_scores = ranked['@m', 'scores__clip_score__value']
    for score in clip_scores:
        assert score is not None
        assert score > 0

0 comments on commit f66b145

Please sign in to comment.