diff --git a/server/clip_server/executors/clip_onnx.py b/server/clip_server/executors/clip_onnx.py
index a41b28f1d..8e1822a59 100644
--- a/server/clip_server/executors/clip_onnx.py
+++ b/server/clip_server/executors/clip_onnx.py
@@ -3,19 +3,18 @@
 from functools import partial
 from multiprocessing.pool import ThreadPool
 from typing import Optional, Dict
-import numpy as np
-import onnxruntime as ort
 
+import onnxruntime as ort
 from jina import Executor, requests, DocumentArray
 
-from clip_server.model import clip
-from clip_server.model.clip_onnx import CLIPOnnxModel
 from clip_server.executors.helper import (
     split_img_txt_da,
     preproc_image,
     preproc_text,
-    numpy_softmax,
+    set_rank,
 )
+from clip_server.model import clip
+from clip_server.model.clip_onnx import CLIPOnnxModel
 
 
 class CLIPEncoder(Executor):
@@ -35,8 +34,6 @@ def __init__(
         self._minibatch_size = minibatch_size
 
         self._model = CLIPOnnxModel(name)
-        # Note: hard coded here since all the pretrained clip model use the same logit_scale parameter
-        self._logit_scale = np.exp(4.60517)
 
         import torch
 
@@ -84,77 +81,12 @@ def __init__(
 
     @requests(on='/rank')
     async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
-        _source = parameters.get('source', 'matches')
+        _source = parameters.get('source', '@m')
 
-        for d in docs:
-            _img_da = DocumentArray()
-            _txt_da = DocumentArray()
-            split_img_txt_da(d, _img_da, _txt_da)
-
-            candidates = getattr(d, _source)
-
-            for c in candidates:
-                split_img_txt_da(c, _img_da, _txt_da)
+        await self.encode(docs)
+        await self.encode(docs[_source])
 
-            if len(_img_da) != 1 and len(_txt_da) != 1:
-                raise ValueError(
-                    f'`d.{_source}` must be all in same modality, either all images or all text'
-                )
-            elif not _img_da or not _txt_da:
-                raise ValueError(
-                    f'`d` and `d.{_source}` must be in different modality, one is image one is text'
-                )
-            elif len(candidates) <= 1:
-                raise ValueError(
-                    f'`d.{_source}` must have more than one Documents to do ranking'
-                )
-            else:
-                _img_da = await self.encode(_img_da)
-                _txt_da = await self.encode(_txt_da)
-
-                # normalized features
-                image_features = _img_da.embeddings / np.linalg.norm(
-                    _img_da.embeddings, axis=1, keepdims=True
-                )
-                text_features = _txt_da.embeddings / np.linalg.norm(
-                    _txt_da.embeddings, axis=1, keepdims=True
-                )
-
-                # paired cosine similarity
-                scores_per_text = np.matmul(image_features, text_features.T)
-                scores_per_image = scores_per_text.T
-
-                if len(_img_da) == 1:
-                    cosine_scores = scores_per_text
-                elif len(_txt_da) == 1:
-                    cosine_scores = scores_per_image
-
-                softmax_scores = numpy_softmax(self._logit_scale * cosine_scores)
-
-                # squeeze scores
-                cosine_scores = cosine_scores[0]
-                softmax_scores = softmax_scores[0]
-
-                # drop embeddings
-                _img_da.embeddings = None
-                _txt_da.embeddings = None
-
-                for c, p, o in zip(candidates, softmax_scores, cosine_scores):
-                    c.scores['clip_score'].value = p
-                    c.scores['clip_score'].op_name = 'softmax'
-
-                    c.scores['clip_score_cosine'].value = o
-                    c.scores['clip_score_cosine'].op_name = 'cosine'
-
-                setattr(
-                    d,
-                    _source,
-                    sorted(
-                        candidates,
-                        key=lambda _m: _m.scores['clip_score'].value,
-                        reverse=True,
-                    ),
-                )
+        set_rank(docs, _source)
 
     @requests
     async def encode(self, docs: 'DocumentArray', **kwargs):
diff --git a/server/clip_server/executors/clip_tensorrt.py b/server/clip_server/executors/clip_tensorrt.py
index 5615249bd..98fd09cbf 100644
--- a/server/clip_server/executors/clip_tensorrt.py
+++ b/server/clip_server/executors/clip_tensorrt.py
@@ -1,17 +1,18 @@
-from typing import Dict
-from multiprocessing.pool import ThreadPool
 from functools import partial
+from multiprocessing.pool import ThreadPool
+from typing import Dict
+
 import numpy as np
 from jina import Executor, requests, DocumentArray
 
-from clip_server.model import clip
-from clip_server.model.clip_trt import CLIPTensorRTModel
 from clip_server.executors.helper import (
     split_img_txt_da,
     preproc_image,
     preproc_text,
-    numpy_softmax,
+    set_rank,
 )
+from clip_server.model import clip
+from clip_server.model.clip_trt import CLIPTensorRTModel
 
 
 class CLIPEncoder(Executor):
@@ -46,82 +47,14 @@ def __init__(
 
         self._model.start_engines()
 
-        # Note: hard coded here since all the pretrained clip model use the same logit_scale parameter
-        self._logit_scale = np.exp(4.60517)
-
     @requests(on='/rank')
     async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
-        _source = parameters.get('source', 'matches')
+        _source = parameters.get('source', '@m')
 
-        for d in docs:
-            _img_da = DocumentArray()
-            _txt_da = DocumentArray()
-            split_img_txt_da(d, _img_da, _txt_da)
+        await self.encode(docs)
+        await self.encode(docs[_source])
 
-            candidates = getattr(d, _source)
-
-            for c in candidates:
-                split_img_txt_da(c, _img_da, _txt_da)
-
-            if len(_img_da) != 1 and len(_txt_da) != 1:
-                raise ValueError(
-                    f'`d.{_source}` must be all in same modality, either all images or all text'
-                )
-            elif not _img_da or not _txt_da:
-                raise ValueError(
-                    f'`d` and `d.{_source}` must be in different modality, one is image one is text'
-                )
-            elif len(candidates) <= 1:
-                raise ValueError(
-                    f'`d.{_source}` must have more than one Documents to do ranking'
-                )
-            else:
-                _img_da = await self.encode(_img_da)
-                _txt_da = await self.encode(_txt_da)
-
-                # normalized features
-                image_features = _img_da.embeddings / np.linalg.norm(
-                    _img_da.embeddings, axis=1, keepdims=True
-                )
-                text_features = _txt_da.embeddings / np.linalg.norm(
-                    _txt_da.embeddings, axis=1, keepdims=True
-                )
-
-                # cosine similarity as rank score
-                scores_per_text = np.matmul(image_features, text_features.T)
-                scores_per_image = scores_per_text.T
-
-                if len(_img_da) == 1:
-                    cosine_scores = scores_per_text
-                elif len(_txt_da) == 1:
-                    cosine_scores = scores_per_image
-
-                softmax_scores = numpy_softmax(self._logit_scale * cosine_scores)
-
-                # squeeze scores
-                softmax_scores = softmax_scores[0]
-                cosine_scores = cosine_scores[0]
-
-                # drop embeddings
-                _img_da.embeddings = None
-                _txt_da.embeddings = None
-
-                for c, p, o in zip(candidates, softmax_scores, cosine_scores):
-                    c.scores['clip_score'].value = p
-                    c.scores['clip_score'].op_name = 'softmax'
-
-                    c.scores['clip_score_cosine'].value = o
-                    c.scores['clip_score_cosine'].op_name = 'cosine'
-
-                setattr(
-                    d,
-                    _source,
-                    sorted(
-                        candidates,
-                        key=lambda _m: _m.scores['clip_score'].value,
-                        reverse=True,
-                    ),
-                )
+        set_rank(docs, _source)
 
     @requests
     async def encode(self, docs: 'DocumentArray', **kwargs):
diff --git a/server/clip_server/executors/clip_torch.py b/server/clip_server/executors/clip_torch.py
index ff64d1c0e..4fa6d2ca3 100644
--- a/server/clip_server/executors/clip_torch.py
+++ b/server/clip_server/executors/clip_torch.py
@@ -1,17 +1,21 @@
 import os
 import warnings
 from functools import partial
-
 from multiprocessing.pool import ThreadPool
 from typing import Optional, Dict
 
 import numpy as np
 import torch
-from clip_server.model import clip
-from clip_server.executors.helper import split_img_txt_da, preproc_image, preproc_text
-
 from jina import Executor, requests, DocumentArray
 
+from clip_server.executors.helper import (
+    split_img_txt_da,
+    preproc_image,
+    preproc_text,
+    set_rank,
+)
+from clip_server.model import clip
+
 
 class CLIPEncoder(Executor):
     def __init__(
@@ -53,91 +57,18 @@ def __init__(
         self._model, self._preprocess_tensor = clip.load(
             name, device=self._device, jit=jit
         )
-        self._logit_scale = self._model.logit_scale.exp()
 
         self._pool = ThreadPool(processes=num_worker_preprocess)
 
     @requests(on='/rank')
     async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
-        import torch
-
-        _source = parameters.get('source', 'matches')
-
-        for d in docs:
-            _img_da = DocumentArray()
-            _txt_da = DocumentArray()
-            split_img_txt_da(d, _img_da, _txt_da)
-            candidates = getattr(d, _source)
+        _source = parameters.get('source', '@m')
 
-            for c in candidates:
-                split_img_txt_da(c, _img_da, _txt_da)
+        await self.encode(docs)
+        await self.encode(docs[_source])
 
-            if len(_img_da) != 1 and len(_txt_da) != 1:
-                raise ValueError(
-                    f'`d.{_source}` must be all in same modality, either all images or all text'
-                )
-            elif not _img_da or not _txt_da:
-                raise ValueError(
-                    f'`d` and `d.{_source}` must be in different modality, one is image one is text'
-                )
-            elif len(candidates) <= 1:
-                raise ValueError(
-                    f'`d.{_source}` must have more than one Documents to do ranking'
-                )
-            else:
-                _img_da = await self.encode(_img_da)
-                _txt_da = await self.encode(_txt_da)
-                _img_da.embeddings = torch.from_numpy(_img_da.embeddings).to(
-                    self._device, non_blocking=True
-                )
-                _txt_da.embeddings = torch.from_numpy(_txt_da.embeddings).to(
-                    self._device, non_blocking=True
-                )
-
-                # normalized features
-                image_features = _img_da.embeddings / _img_da.embeddings.norm(
-                    dim=-1, keepdim=True
-                )
-                text_features = _txt_da.embeddings / _txt_da.embeddings.norm(
-                    dim=-1, keepdim=True
-                )
-
-                # paired cosine between image and text
-                scores_per_text = image_features @ text_features.t()
-                scores_per_image = scores_per_text.t()
-
-                if len(_img_da) == 1:
-                    cosine_scores = scores_per_text
-                elif len(_txt_da) == 1:
-                    cosine_scores = scores_per_image
-
-                softmax_scores = self._logit_scale * cosine_scores
-                softmax_scores = softmax_scores.softmax(dim=-1)
-
-                # squeeze scores
-                cosine_scores = cosine_scores.cpu().detach().numpy().squeeze()
-                softmax_scores = softmax_scores.cpu().detach().numpy().squeeze()
-
-                _img_da.embeddings = None
-                _txt_da.embeddings = None
-
-                for c, p, o in zip(candidates, softmax_scores, cosine_scores):
-                    c.scores['clip_score'].value = p
-                    c.scores['clip_score'].op_name = 'softmax'
-
-                    c.scores['clip_score_cosine'].value = o
-                    c.scores['clip_score_cosine'].op_name = 'cosine'
-
-                setattr(
-                    d,
-                    _source,
-                    sorted(
-                        candidates,
-                        key=lambda _m: _m.scores['clip_score'].value,
-                        reverse=True,
-                    ),
-                )
+        set_rank(docs, _source)
 
     @requests
     async def encode(self, docs: 'DocumentArray', **kwargs):
diff --git a/server/clip_server/executors/helper.py b/server/clip_server/executors/helper.py
index 41882816d..f36290083 100644
--- a/server/clip_server/executors/helper.py
+++ b/server/clip_server/executors/helper.py
@@ -1,9 +1,10 @@
-from typing import Tuple, List, Callable, TYPE_CHECKING
+from typing import Tuple, List, Callable
+
 import numpy as np
-from clip_server.model import clip
+from docarray import Document, DocumentArray
+from docarray.math.distance.numpy import cosine
 
-if TYPE_CHECKING:
-    from docarray import Document, DocumentArray
+from clip_server.model import clip
 
 
 def numpy_softmax(x: 'np.ndarray', axis: int = -1) -> 'np.ndarray':
@@ -52,9 +53,46 @@ def preproc_text(
 
 
 def split_img_txt_da(doc: 'Document', img_da: 'DocumentArray', txt_da: 'DocumentArray'):
-    if doc.text:
-        txt_da.append(doc)
-    elif doc.blob or (doc.tensor is not None):
+    if doc.uri:
         img_da.append(doc)
-    elif doc.uri:
+    elif doc.blob or (doc.tensor is not None):
         img_da.append(doc)
+    elif doc.text:
+        txt_da.append(doc)
+
+
+def set_rank(docs, _source, _logit_scale=np.exp(4.60517)):
+    queries = docs
+    candidates = docs[_source]
+
+    query_embeddings = queries.embeddings  # Q X D
+    candidate_embeddings = candidates.embeddings  # C = Sum(C_q1, C_q2, C_q3,...) x D
+    cosine_scores = cosine(query_embeddings, candidate_embeddings)  # Q x C Block matix
+    start_idx = 0
+    for q, _cosine_scores in zip(docs, cosine_scores):
+
+        _candidates = DocumentArray(q)[_source]
+
+        end_idx = start_idx + len(_candidates)
+
+        _candidate_cosines = _cosine_scores[start_idx:end_idx]
+        _candidate_softmaxs = numpy_softmax(_logit_scale * _candidate_cosines)
+        for c, _c_score, _s_score in zip(
+            _candidates, _candidate_cosines, _candidate_softmaxs
+        ):
+            c.scores['clip_score'].value = _s_score
+            c.scores['clip_score'].op_name = 'softmax'
+
+            c.scores['clip_score_cosine'].value = _c_score
+            c.scores['clip_score_cosine'].op_name = 'cosine'
+
+        start_idx = end_idx
+
+        final = sorted(
+            _candidates, key=lambda _m: _m.scores['clip_score'].value, reverse=True
+        )
+
+        if _source == '@m':
+            q.matches = final
+        elif _source == '@c':
+            q.chunks = final
diff --git a/tests/test_ranker.py b/tests/test_ranker.py
index 5fa2b966d..2269035e3 100644
--- a/tests/test_ranker.py
+++ b/tests/test_ranker.py
@@ -1,11 +1,12 @@
 import os
 
-import pytest
 import numpy as np
+import pytest
+from docarray import DocumentArray, Document
+
 from clip_client import Client
-from clip_server.executors.clip_torch import CLIPEncoder as TorchCLIPEncoder
 from clip_server.executors.clip_onnx import CLIPEncoder as ONNXCLILPEncoder
-from docarray import DocumentArray, Document
+from clip_server.executors.clip_torch import CLIPEncoder as TorchCLIPEncoder
 
 
 @pytest.mark.asyncio
@@ -19,12 +20,18 @@ async def test_torch_executor_rank_img2texts(encoder_class):
     for d in da:
         d.matches.append(Document(text='hello, world!'))
         d.matches.append(Document(text='goodbye, world!'))
+        d.matches.append(Document(text='goodbye,!'))
+        d.matches.append(Document(text='good world!'))
+        d.matches.append(Document(text='good!'))
+        d.matches.append(Document(text='world!'))
 
     await ce.rank(da, {})
     print(da['@m', 'scores__clip_score__value'])
     for d in da:
         for c in d.matches:
             assert c.scores['clip_score'].value is not None
+        org_score = d.matches[:, 'scores__clip_score__value']
+        assert org_score == list(sorted(org_score, reverse=True))
 
 
 @pytest.mark.asyncio
@@ -45,6 +52,10 @@ async def test_torch_executor_rank_text2imgs(encoder_class):
     for d in db:
         for c in d.matches:
             assert c.scores['clip_score'].value is not None
+            assert c.scores['clip_score_cosine'].value is not None
+        np.testing.assert_almost_equal(
+            sum(c.scores['clip_score'].value for c in d.matches), 1
+        )
 
 
 @pytest.mark.parametrize(
@@ -69,10 +80,13 @@ def test_docarray_inputs(make_flow, d):
     c = Client(server=f'grpc://0.0.0.0:{make_flow.port}')
     r = c.rank([d])
     assert isinstance(r, DocumentArray)
-    rv = r['@m', 'scores__clip_score__value']
-    for v in rv:
-        assert v is not None
-        assert v > 0
+    rv1 = r['@m', 'scores__clip_score__value']
+    rv2 = r['@m', 'scores__clip_score_cosine__value']
+    for v1, v2 in zip(rv1, rv2):
+        assert v1 is not None
+        assert v1 > 0
+        assert v2 is not None
+        assert v2 > 0
 
 
 @pytest.mark.parametrize(