feat(server): add rank endpoint (#694)
hanxiao authored Apr 25, 2022
1 parent 22cfffa commit 0ebc4c0
Showing 6 changed files with 164 additions and 70 deletions.
51 changes: 51 additions & 0 deletions scripts/onnx_helper.py
@@ -0,0 +1,51 @@
def convert_float_to_float16(model_path: str, output_model_path: str):
    import onnx
    from onnxmltools.utils.float16_converter import (
        convert_float_to_float16_model_path,
    )

    new_onnx_model = convert_float_to_float16_model_path(model_path)

    onnx.save(new_onnx_model, output_model_path)

    # Alternate approach
    # from onnx import load_model
    # from onnxruntime.transformers import optimizer, onnx_model
    #
    # # optimized_model = optimizer.optimize_model(model_path, model_type='bert')
    #
    # model = load_model(model_path)
    # optimized_model = onnx_model.OnnxModel(model)
    #
    # if hasattr(optimized_model, 'convert_float32_to_float16'):
    #     optimized_model.convert_float_to_float16()
    # else:
    #     optimized_model.convert_model_float32_to_float16()
    #
    # self._textual_path = f'{self._textual_path[:-5]}_optimized.onnx'
    # optimized_model.save_model_to_file(output_model_path)


def quantize(model_path: str, output_model_path: str):
    """
    Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPUs.
    Uses unsigned ints for activation values and signed ints for weights; per
    https://onnxruntime.ai/docs/performance/quantization.html#data-type-selection
    this is faster on most CPU architectures.
    Args:
        model_path: Path to the exported ONNX model
        output_model_path: Path where the quantized model is written
    Returns: None; the quantized model is saved to `output_model_path`.
    """
    from onnxruntime.quantization import quantize_dynamic, QuantType

    quantize_dynamic(
        model_input=model_path,
        model_output=output_model_path,
        per_channel=True,
        reduce_range=True,  # should be the same as per_channel
        activation_type=QuantType.QUInt8,
        weight_type=QuantType.QInt8,  # per docs, signed is faster on most CPUs
        optimize_model=True,
        op_types_to_quantize=["MatMul", "Attention", "Mul", "Add"],
        extra_options={"WeightSymmetric": False, "MatMulConstBOnly": True},
    )  # op_types_to_quantize=['MatMul', 'Relu', 'Add', 'Mul' ],
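
For context, a minimal sketch of how these helpers might be invoked on an exported CLIP ONNX model. The file paths are placeholders, the snippet assumes it is run from the repository root so that scripts/onnx_helper.py resolves as a module, and it requires onnx, onnxmltools and onnxruntime to be installed:

```python
# Hypothetical usage sketch: paths are placeholders, not files shipped with the repo.
from scripts.onnx_helper import convert_float_to_float16, quantize

# Produce an fp16 copy of the exported visual model.
convert_float_to_float16('ViT-B-32/visual.onnx', 'ViT-B-32/visual-fp16.onnx')

# Produce a dynamically quantized int8 copy for CPU inference.
quantize('ViT-B-32/visual.onnx', 'ViT-B-32/visual-int8.onnx')
```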
8 changes: 6 additions & 2 deletions server/clip_server/executors/clip_onnx.py
@@ -19,6 +19,7 @@
    'ViT-B/32': 224,
    'ViT-B/16': 224,
    'ViT-L/14': 224,
    'ViT-L/14@336px': 336,
}


@@ -67,11 +68,14 @@ def __init__(

        if not self._device.startswith('cuda') and (
            not os.environ.get('OMP_NUM_THREADS')
            and hasattr(self.runtime_args, 'replicas')
        ):
            num_threads = torch.get_num_threads() // self.runtime_args.replicas
            replicas = getattr(self.runtime_args, 'replicas', 1)
            num_threads = max(1, torch.get_num_threads() // replicas)
            if num_threads < 2:
                self.logger.warning(
                    f'Too many encoder replicas (replicas={self.runtime_args.replicas})'
                    f'Too many replicas ({replicas}) vs too few threads {num_threads} may result in '
                    f'sub-optimal performance.'
                )

        # Run the operators in the graph in parallel (not supported by the CUDA Execution Provider)
72 changes: 59 additions & 13 deletions server/clip_server/executors/clip_torch.py
@@ -1,8 +1,8 @@
import os
import numpy as np
from multiprocessing.pool import ThreadPool
from typing import Optional, List, Tuple

import numpy as np
from jina import Executor, requests, DocumentArray
from jina.logging.logger import JinaLogger

@@ -31,11 +31,14 @@ def __init__(

        if not self._device.startswith('cuda') and (
            not os.environ.get('OMP_NUM_THREADS')
            and hasattr(self.runtime_args, 'replicas')
        ):
            num_threads = torch.get_num_threads() // self.runtime_args.replicas
            replicas = getattr(self.runtime_args, 'replicas', 1)
            num_threads = max(1, torch.get_num_threads() // replicas)
            if num_threads < 2:
                self.logger.warning(
                    f'Too many encoder replicas (replicas={self.runtime_args.replicas})'
                    f'Too many replicas ({replicas}) vs too few threads {num_threads} may result in '
                    f'sub-optimal performance.'
                )

        # NOTE: make sure to set the threads right after the torch import,
@@ -70,21 +73,64 @@ def _preproc_text(self, da: 'DocumentArray') -> Tuple['DocumentArray', List[str]
        da[:, 'mime_type'] = 'text'
        return da, texts

    @staticmethod
    def _split_img_txt_da(d, _img_da, _txt_da):
        if d.text:
            _txt_da.append(d)
        elif (d.blob is not None) or (d.tensor is not None):
            _img_da.append(d)
        elif d.uri:
            _img_da.append(d)

    @requests(on='/rank')
    async def rank(self, docs: 'DocumentArray', **kwargs):
        for d in docs:
            _img_da = DocumentArray()
            _txt_da = DocumentArray()
            self._split_img_txt_da(d, _img_da, _txt_da)

            for c in d.chunks:
                self._split_img_txt_da(c, _img_da, _txt_da)

            if len(_img_da) != 1 and len(_txt_da) != 1:
                raise ValueError(
                    'chunks must be all in same modality, either all images or all text'
                )
            elif not _img_da or not _txt_da:
                raise ValueError(
                    'root and chunks must be in different modality, one is image one is text'
                )
            elif len(d.chunks) <= 1:
                raise ValueError('must have more than one chunk to rank over')
            else:
                _img_da = self._preproc_image(_img_da)
                _txt_da, texts = self._preproc_text(_txt_da)

                logits_per_image, logits_per_text = self._model(
                    _img_da.tensors, _txt_da.tensors
                )
                probs_image = (
                    logits_per_image.softmax(dim=-1).cpu().detach().numpy().squeeze()
                )
                probs_text = (
                    logits_per_text.softmax(dim=-1).cpu().detach().numpy().squeeze()
                )
                if len(_img_da) == 1:
                    probs = probs_image
                elif len(_txt_da) == 1:
                    probs = probs_text

                for c, v in zip(d.chunks, probs):
                    c.scores['clip-rank'].value = v

                _txt_da.texts = texts

    @requests
    async def encode(self, docs: 'DocumentArray', **kwargs):
        _img_da = DocumentArray()
        _txt_da = DocumentArray()
        for d in docs:
            if d.text:
                _txt_da.append(d)
            elif (d.blob is not None) or (d.tensor is not None):
                _img_da.append(d)
            elif d.uri:
                _img_da.append(d)
            else:
                self.logger.warning(
                    f'The content of document {d.id} is empty, cannot be processed'
                )
            self._split_img_txt_da(d, _img_da, _txt_da)

        import torch

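The new /rank endpoint scores every chunk of a document against its root: an image root is ranked against text chunks (and vice versa), and the softmax scores are written into each chunk's scores['clip-rank']. Below is a minimal offline sketch mirroring tests/test_ranker.py; the image path is a placeholder and the executor is called directly rather than through a served Flow:

```python
import asyncio

from clip_server.executors.clip_torch import CLIPEncoder
from docarray import Document, DocumentArray

encoder = CLIPEncoder()

# One image at the root, candidate texts as chunks ('apple.jpg' is a placeholder path).
doc = Document(
    uri='apple.jpg',
    chunks=[
        Document(text='a photo of an apple'),
        Document(text='a photo of a banana'),
    ],
)

# rank() is a coroutine, so drive it with asyncio when calling the executor directly.
asyncio.run(encoder.rank(DocumentArray([doc])))

for chunk in doc.chunks:
    print(chunk.text, chunk.scores['clip-rank'].value)
```

When the executor is served behind a Jina Flow, the same logic is reachable by posting documents to the `/rank` endpoint with a client.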
5 changes: 3 additions & 2 deletions server/clip_server/model/clip.py
@@ -1,9 +1,9 @@
# Originally from https://github.com/openai/CLIP. MIT License, Copyright (c) 2021 OpenAI

import os
import io
import urllib
import os
import shutil
import urllib
import warnings
from typing import Union, List

@@ -34,6 +34,7 @@
    'ViT-B/32': 'ViT-B-32.pt',
    'ViT-B/16': 'ViT-B-16.pt',
    'ViT-L/14': 'ViT-L-14.pt',
    'ViT-L/14@336px': 'ViT-L-14-336px.pt',
}


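With the additional table entry, the 336px variant can be selected like any other model. A hedged sketch follows; passing the model name as the executor's first constructor argument is part of the existing interface and is not shown in this diff, and the checkpoint download is not verified here:

```python
from clip_server.executors.clip_torch import CLIPEncoder

# Assumption: the executor accepts the model name as its first constructor argument.
encoder = CLIPEncoder('ViT-L/14@336px')  # the 336px checkpoint added by this commit
```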
56 changes: 3 additions & 53 deletions server/clip_server/model/clip_onnx.py
@@ -1,6 +1,7 @@
import os
import onnx

import onnxruntime as ort

from .clip import _download, available_models

_S3_BUCKET = 'https://clip-as-service.s3.us-east-2.amazonaws.com/models/onnx/'
@@ -13,6 +14,7 @@
    'ViT-B/32': ('ViT-B-32/textual.onnx', 'ViT-B-32/visual.onnx'),
    'ViT-B/16': ('ViT-B-16/textual.onnx', 'ViT-B-16/visual.onnx'),
    'ViT-L/14': ('ViT-L-14/textual.onnx', 'ViT-L-14/visual.onnx'),
    'ViT-L/14@336px': ('ViT-L-14@336px/textual.onnx', 'ViT-L-14@336px/visual.onnx'),
}


@@ -50,55 +52,3 @@ def encode_text(self, onnx_text):
        onnx_input_text = {self._textual_session.get_inputs()[0].name: onnx_text}
        (textual_output,) = self._textual_session.run(None, onnx_input_text)
        return textual_output


def convert_float_to_float16(model_path: str, output_model_path: str):
    from onnxmltools.utils.float16_converter import (
        convert_float_to_float16_model_path,
    )

    new_onnx_model = convert_float_to_float16_model_path(model_path)

    onnx.save(new_onnx_model, output_model_path)

    # Alternate approach
    # from onnx import load_model
    # from onnxruntime.transformers import optimizer, onnx_model
    #
    # # optimized_model = optimizer.optimize_model(model_path, model_type='bert')
    #
    # model = load_model(model_path)
    # optimized_model = onnx_model.OnnxModel(model)
    #
    # if hasattr(optimized_model, 'convert_float32_to_float16'):
    #     optimized_model.convert_float_to_float16()
    # else:
    #     optimized_model.convert_model_float32_to_float16()
    #
    # self._textual_path = f'{self._textual_path[:-5]}_optimized.onnx'
    # optimized_model.save_model_to_file(output_model_path)


def quantize(model_path: str, output_model_path: str):
    """
    Quantize the weights of the model from float32 to in8 to allow very efficient inference on modern CPU
    Uses unsigned ints for activation values, signed ints for weights, per
    https://onnxruntime.ai/docs/performance/quantization.html#data-type-selection
    it is faster on most CPU architectures
    Args:
        onnx_model_path: Path to location the exported ONNX model is stored
    Returns: The Path generated for the quantized
    """
    from onnxruntime.quantization import quantize_dynamic, QuantType

    quantize_dynamic(
        model_input=model_path,
        model_output=output_model_path,
        per_channel=True,
        reduce_range=True,  # should be the same as per_channel
        activation_type=QuantType.QUInt8,
        weight_type=QuantType.QInt8,  # per docs, signed is faster on most CPUs
        optimize_model=True,
        op_types_to_quantize=["MatMul", "Attention", "Mul", "Add"],
        extra_options={"WeightSymmetric": False, "MatMulConstBOnly": True},
    )  # op_types_to_quantize=['MatMul', 'Relu', 'Add', 'Mul' ],
42 changes: 42 additions & 0 deletions tests/test_ranker.py
@@ -0,0 +1,42 @@
import os

import pytest
from clip_server.executors.clip_torch import CLIPEncoder
from docarray import DocumentArray, Document


@pytest.mark.asyncio
async def test_torch_executor_rank_img2texts():
    ce = CLIPEncoder()

    da = DocumentArray.from_files(
        f'{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg'
    )
    for d in da:
        d.chunks.append(Document(text='hello, world!'))
        d.chunks.append(Document(text='goodbye, world!'))

    await ce.rank(da)
    print(da['@c', 'scores__clip-rank__value'])
    for d in da:
        for c in d.chunks:
            assert c.scores['clip-rank'].value is not None


@pytest.mark.asyncio
async def test_torch_executor_rank_text2imgs():
    ce = CLIPEncoder()
    db = DocumentArray(
        [Document(text='hello, world!'), Document(text='goodbye, world!')]
    )
    for d in db:
        d.chunks.extend(
            DocumentArray.from_files(
                f'{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg'
            )
        )
    await ce.rank(db)
    print(db['@c', 'scores__clip-rank__value'])
    for d in db:
        for c in d.chunks:
            assert c.scores['clip-rank'].value is not None
