Skip to content

Commit

Permalink
perf(server): use await gather in rank function (#716)
Browse files: browse the repository at this point in the history
  • Loading branch information
hanxiao authored May 11, 2022
1 parent 66b14fc commit 72f1bc4
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 30 deletions.
10 changes: 3 additions & 7 deletions server/clip_server/executors/clip_onnx.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import asyncio
import os
import warnings
from functools import partial
from multiprocessing.pool import ThreadPool
from typing import Optional, Dict

import onnxruntime as ort
from jina import Executor, requests, DocumentArray

from clip_server.executors.helper import (
split_img_txt_da,
preproc_image,
Expand All @@ -16,6 +13,7 @@
)
from clip_server.model import clip
from clip_server.model.clip_onnx import CLIPOnnxModel
from jina import Executor, requests, DocumentArray


class CLIPEncoder(Executor):
Expand Down Expand Up @@ -82,11 +80,9 @@ def __init__(

@requests(on='/rank')
async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
_source = parameters.get('source', '@m')

await asyncio.gather(self.encode(docs), self.encode(docs[_source]))
await self.encode(docs['@r,m'])

set_rank(docs, _source)
set_rank(docs)

@requests
async def encode(self, docs: 'DocumentArray', **kwargs):
Expand Down
10 changes: 3 additions & 7 deletions server/clip_server/executors/clip_tensorrt.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import asyncio
from functools import partial
from multiprocessing.pool import ThreadPool
from typing import Dict

import numpy as np
from jina import Executor, requests, DocumentArray

from clip_server.executors.helper import (
split_img_txt_da,
preproc_image,
Expand All @@ -14,6 +11,7 @@
)
from clip_server.model import clip
from clip_server.model.clip_trt import CLIPTensorRTModel
from jina import Executor, requests, DocumentArray


class CLIPEncoder(Executor):
Expand Down Expand Up @@ -50,11 +48,9 @@ def __init__(

@requests(on='/rank')
async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
_source = parameters.get('source', '@m')

await asyncio.gather(self.encode(docs), self.encode(docs[_source]))
await self.encode(docs['@r,m'])

set_rank(docs, _source)
set_rank(docs)

@requests
async def encode(self, docs: 'DocumentArray', **kwargs):
Expand Down
10 changes: 3 additions & 7 deletions server/clip_server/executors/clip_torch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import os
import warnings
from functools import partial
Expand All @@ -7,15 +6,14 @@

import numpy as np
import torch
from jina import Executor, requests, DocumentArray

from clip_server.executors.helper import (
split_img_txt_da,
preproc_image,
preproc_text,
set_rank,
)
from clip_server.model import clip
from jina import Executor, requests, DocumentArray


class CLIPEncoder(Executor):
Expand Down Expand Up @@ -64,11 +62,9 @@ def __init__(
@requests(on='/rank')
async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):

_source = parameters.get('source', '@m')

await asyncio.gather(self.encode(docs), self.encode(docs[_source]))
await self.encode(docs['@r,m'])

set_rank(docs, _source)
set_rank(docs)

@requests
async def encode(self, docs: 'DocumentArray', **kwargs):
Expand Down
14 changes: 5 additions & 9 deletions server/clip_server/executors/helper.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Tuple, List, Callable

import numpy as np
from clip_server.model import clip
from docarray import Document, DocumentArray
from docarray.math.distance.numpy import cosine

from clip_server.model import clip


def numpy_softmax(x: 'np.ndarray', axis: int = -1) -> 'np.ndarray':
max = np.max(x, axis=axis, keepdims=True)
Expand Down Expand Up @@ -61,9 +60,9 @@ def split_img_txt_da(doc: 'Document', img_da: 'DocumentArray', txt_da: 'Document
txt_da.append(doc)


def set_rank(docs, _source, _logit_scale=np.exp(4.60517)):
def set_rank(docs, _logit_scale=np.exp(4.60517)):
queries = docs
candidates = docs[_source]
candidates = docs['@m']

query_embeddings = queries.embeddings # Q X D
candidate_embeddings = candidates.embeddings # C = Sum(C_q1, C_q2, C_q3,...) x D
Expand All @@ -73,7 +72,7 @@ def set_rank(docs, _source, _logit_scale=np.exp(4.60517)):
start_idx = 0
for q, _cosine_scores in zip(docs, cosine_scores):

_candidates = DocumentArray(q)[_source]
_candidates = q.matches

end_idx = start_idx + len(_candidates)

Expand All @@ -94,7 +93,4 @@ def set_rank(docs, _source, _logit_scale=np.exp(4.60517)):
_candidates, key=lambda _m: _m.scores['clip_score'].value, reverse=True
)

if _source == '@m':
q.matches = final
elif _source == '@c':
q.chunks = final
q.matches = final

0 comments on commit 72f1bc4

Please sign in to comment.