Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use cosine as the rank score #708

Merged
merged 7 commits into from
May 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 31 additions & 25 deletions server/clip_server/executors/clip_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@

from clip_server.model import clip
from clip_server.model.clip_onnx import CLIPOnnxModel
from clip_server.executors.helper import split_img_txt_da, preproc_image, preproc_text
from clip_server.executors.helper import (
split_img_txt_da,
preproc_image,
preproc_text,
numpy_softmax,
)


class CLIPEncoder(Executor):
Expand All @@ -20,7 +25,6 @@ def __init__(
device: Optional[str] = None,
num_worker_preprocess: int = 4,
minibatch_size: int = 16,
logit_scale: float = 4.60,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -31,7 +35,8 @@ def __init__(
self._minibatch_size = minibatch_size

self._model = CLIPOnnxModel(name)
self._logit_scale = logit_scale
# Note: hard coded here since all the pretrained clip model use the same logit_scale parameter
self._logit_scale = np.exp(4.60517)

import torch

Expand Down Expand Up @@ -80,14 +85,15 @@ def __init__(
@requests(on='/rank')
async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
_source = parameters.get('source', 'matches')
_get = lambda d: getattr(d, _source)

for d in docs:
_img_da = DocumentArray()
_txt_da = DocumentArray()
split_img_txt_da(d, _img_da, _txt_da)

for c in _get(d):
candidates = getattr(d, _source)

for c in candidates:
split_img_txt_da(c, _img_da, _txt_da)

if len(_img_da) != 1 and len(_txt_da) != 1:
Expand All @@ -98,7 +104,7 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
raise ValueError(
f'`d` and `d.{_source}` must be in different modality, one is image one is text'
)
elif len(_get(d)) <= 1:
elif len(candidates) <= 1:
raise ValueError(
f'`d.{_source}` must have more than one Documents to do ranking'
)
Expand All @@ -114,37 +120,37 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
_txt_da.embeddings, axis=1, keepdims=True
)

# cosine similarity as logits
logit_scale = np.exp(self._logit_scale)
logits_per_image = logit_scale * np.matmul(
image_features, text_features.T
)
logits_per_text = logits_per_image.T

def numpy_softmax(z):
s = np.max(z, axis=1)
s = s[:, np.newaxis]
e_x = np.exp(z - s)
div = np.sum(e_x, axis=1)
div = div[:, np.newaxis] # dito
return e_x / div
# paired cosine similarity
scores_per_text = np.matmul(image_features, text_features.T)
scores_per_image = scores_per_text.T

if len(_img_da) == 1:
probs = numpy_softmax(logits_per_image)[0]
cosine_scores = scores_per_text
elif len(_txt_da) == 1:
probs = numpy_softmax(logits_per_text)[0]
cosine_scores = scores_per_image

softmax_scores = numpy_softmax(self._logit_scale * cosine_scores)

# squeeze scores
cosine_scores = cosine_scores[0]
softmax_scores = softmax_scores[0]

# drop embeddings
_img_da.embeddings = None
_txt_da.embeddings = None

for c, v in zip(_get(d), probs):
c.scores['clip_score'].value = v
for c, p, o in zip(candidates, softmax_scores, cosine_scores):
c.scores['clip_score'].value = p
c.scores['clip_score'].op_name = 'softmax'

c.scores['clip_score_cosine'].value = o
c.scores['clip_score_cosine'].op_name = 'cosine'

setattr(
d,
_source,
sorted(
_get(d),
candidates,
key=lambda _m: _m.scores['clip_score'].value,
reverse=True,
),
Expand Down
47 changes: 26 additions & 21 deletions server/clip_server/executors/clip_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from functools import partial

from multiprocessing.pool import ThreadPool
from typing import Optional, List, Tuple, Dict
from typing import Optional, Dict

import numpy as np
import torch
Expand Down Expand Up @@ -53,6 +53,7 @@ def __init__(
self._model, self._preprocess_tensor = clip.load(
name, device=self._device, jit=jit
)
self._logit_scale = self._model.logit_scale.exp()

self._pool = ThreadPool(processes=num_worker_preprocess)

Expand All @@ -61,14 +62,15 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
import torch

_source = parameters.get('source', 'matches')
_get = lambda d: getattr(d, _source)

for d in docs:
_img_da = DocumentArray()
_txt_da = DocumentArray()
split_img_txt_da(d, _img_da, _txt_da)

for c in _get(d):
candidates = getattr(d, _source)

for c in candidates:
split_img_txt_da(c, _img_da, _txt_da)

if len(_img_da) != 1 and len(_txt_da) != 1:
Expand All @@ -79,7 +81,7 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
raise ValueError(
f'`d` and `d.{_source}` must be in different modality, one is image one is text'
)
elif len(_get(d)) <= 1:
elif len(candidates) <= 1:
raise ValueError(
f'`d.{_source}` must have more than one Documents to do ranking'
)
Expand All @@ -97,34 +99,37 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
dim=-1, keepdim=True
)

# cosine similarity as logits
logit_scale = self._model.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
# paired cosine between image and text
scores_per_text = image_features @ text_features.t()
scores_per_image = scores_per_text.t()

if len(_img_da) == 1:
probs = (
logits_per_image.softmax(dim=-1)
.cpu()
.detach()
.numpy()
.squeeze()
)
cosine_scores = scores_per_text
elif len(_txt_da) == 1:
probs = (
logits_per_text.softmax(dim=-1).cpu().detach().numpy().squeeze()
)
cosine_scores = scores_per_image

softmax_scores = self._logit_scale * cosine_scores
softmax_scores = softmax_scores.softmax(dim=-1)

# squeeze scores
cosine_scores = cosine_scores.cpu().detach().numpy().squeeze()
softmax_scores = softmax_scores.cpu().detach().numpy().squeeze()

_img_da.embeddings = None
_txt_da.embeddings = None

for c, v in zip(_get(d), probs):
c.scores['clip_score'].value = v
for c, p, o in zip(candidates, softmax_scores, cosine_scores):
c.scores['clip_score'].value = p
c.scores['clip_score'].op_name = 'softmax'

c.scores['clip_score_cosine'].value = o
c.scores['clip_score_cosine'].op_name = 'cosine'

setattr(
d,
_source,
sorted(
_get(d),
candidates,
key=lambda _m: _m.scores['clip_score'].value,
reverse=True,
),
Expand Down
69 changes: 37 additions & 32 deletions server/clip_server/executors/clip_trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@

from clip_server.model import clip
from clip_server.model.clip_trt import CLIPTensorRTModel
from clip_server.executors.helper import split_img_txt_da, preproc_image, preproc_text
from clip_server.executors.helper import (
split_img_txt_da,
preproc_image,
preproc_text,
numpy_softmax,
)


class CLIPEncoder(Executor):
Expand All @@ -16,7 +21,6 @@ def __init__(
device: str = 'cuda',
num_worker_preprocess: int = 4,
minibatch_size: int = 64,
logit_scale: float = 4.60,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -27,8 +31,6 @@ def __init__(
self._minibatch_size = minibatch_size
self._device = device

self._logit_scale = logit_scale

import torch

assert self._device.startswith('cuda'), (
Expand All @@ -44,19 +46,21 @@ def __init__(

self._model.start_engines()

# Note: hard coded here since all the pretrained clip model use the same logit_scale parameter
self._logit_scale = np.exp(4.60517)

@requests(on='/rank')
async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
import torch

_source = parameters.get('source', 'matches')
_get = lambda d: getattr(d, _source)

for d in docs:
_img_da = DocumentArray()
_txt_da = DocumentArray()
split_img_txt_da(d, _img_da, _txt_da)

for c in _get(d):
candidates = getattr(d, _source)

for c in candidates:
split_img_txt_da(c, _img_da, _txt_da)

if len(_img_da) != 1 and len(_txt_da) != 1:
Expand All @@ -67,52 +71,53 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
raise ValueError(
f'`d` and `d.{_source}` must be in different modality, one is image one is text'
)
elif len(_get(d)) <= 1:
elif len(candidates) <= 1:
raise ValueError(
f'`d.{_source}` must have more than one Documents to do ranking'
)
else:
_img_da = await self.encode(_img_da)
_txt_da = await self.encode(_txt_da)
_img_da.embeddings = torch.from_numpy(_img_da.embeddings)
_txt_da.embeddings = torch.from_numpy(_txt_da.embeddings)

# normalized features
image_features = _img_da.embeddings / _img_da.embeddings.norm(
dim=-1, keepdim=True
image_features = _img_da.embeddings / np.linalg.norm(
_img_da.embeddings, axis=1, keepdims=True
)
text_features = _txt_da.embeddings / _txt_da.embeddings.norm(
dim=-1, keepdim=True
text_features = _txt_da.embeddings / np.linalg.norm(
_txt_da.embeddings, axis=1, keepdims=True
)

# cosine similarity as logits
logit_scale = np.exp(self._logit_scale)
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
# cosine similarity as rank score
scores_per_text = np.matmul(image_features, text_features.T)
scores_per_image = scores_per_text.T
numb3r3 marked this conversation as resolved.
Show resolved Hide resolved

if len(_img_da) == 1:
probs = (
logits_per_image.softmax(dim=-1)
.cpu()
.detach()
.numpy()
.squeeze()
)
cosine_scores = scores_per_text
elif len(_txt_da) == 1:
probs = (
logits_per_text.softmax(dim=-1).cpu().detach().numpy().squeeze()
)
cosine_scores = scores_per_image

softmax_scores = numpy_softmax(self._logit_scale * cosine_scores)

# squeeze scores
softmax_scores = softmax_scores[0]
cosine_scores = cosine_scores[0]

# drop embeddings
_img_da.embeddings = None
_txt_da.embeddings = None

for c, v in zip(_get(d), probs):
c.scores['clip_score'].value = v
for c, p, o in zip(candidates, softmax_scores, cosine_scores):
c.scores['clip_score'].value = p
c.scores['clip_score'].op_name = 'softmax'

c.scores['clip_score_cosine'].value = o
c.scores['clip_score_cosine'].op_name = 'cosine'

setattr(
d,
_source,
sorted(
_get(d),
candidates,
key=lambda _m: _m.scores['clip_score'].value,
reverse=True,
),
Expand Down
8 changes: 8 additions & 0 deletions server/clip_server/executors/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@
from docarray import Document, DocumentArray


def numpy_softmax(x: 'np.ndarray', axis: int = -1) -> 'np.ndarray':
max = np.max(x, axis=axis, keepdims=True)
e_x = np.exp(x - max)
div = np.sum(e_x, axis=axis, keepdims=True)
f_x = e_x / div
return f_x


def preproc_image(
da: 'DocumentArray',
preprocess_fn: Callable,
Expand Down
19 changes: 19 additions & 0 deletions tests/test_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pytest
import numpy as np
from clip_server.executors.helper import numpy_softmax


@pytest.mark.parametrize('shape', [(5, 10), (5, 10, 10)])
@pytest.mark.parametrize('axis', [-1, 1, 0])
def test_numpy_softmax(shape, axis):
import torch

logits = np.random.random(shape)

np_softmax = numpy_softmax(logits, axis=axis)
torch_softmax = torch.from_numpy(logits).softmax(dim=axis).numpy()
np.testing.assert_array_almost_equal(np_softmax, torch_softmax)

np_softmax = numpy_softmax(logits, axis=axis)
torch_softmax = torch.from_numpy(logits).softmax(dim=axis).numpy()
np.testing.assert_array_almost_equal(np_softmax, torch_softmax)
Loading