feat: add custom tracing spans with jina>=3.11.0
Girish Chandrashekar committed Nov 16, 2022
1 parent 9bb7d1f commit 9f197b5
Showing 5 changed files with 265 additions and 170 deletions.
2 changes: 1 addition & 1 deletion client/setup.py
@@ -41,7 +41,7 @@
long_description_content_type='text/markdown',
zip_safe=False,
setup_requires=['setuptools>=18.0', 'wheel'],
install_requires=['jina>=3.8.0', 'docarray[common]>=0.16.1', 'packaging'],
install_requires=['jina>=3.11.0', 'docarray[common]>=0.19.0', 'packaging'],
extras_require={
'test': [
'pytest',
116 changes: 71 additions & 45 deletions server/clip_server/executors/clip_onnx.py
@@ -15,6 +15,7 @@
from clip_server.model.clip_onnx import CLIPOnnxModel
from clip_server.model.tokenization import Tokenizer
from jina import Executor, requests, DocumentArray
from opentelemetry.trace import NoOpTracer, Span


class CLIPEncoder(Executor):
@@ -51,6 +52,7 @@ def __init__(
)
self._access_paths = kwargs['traversal_paths']

self._num_worker_preprocess = num_worker_preprocess
self._pool = ThreadPool(processes=num_worker_preprocess)

self._model = CLIPOnnxModel(name, model_path)
@@ -100,24 +102,29 @@ def __init__(

self._model.start_sessions(sess_options=sess_options, providers=providers)

if not self.tracer:
self.tracer = NoOpTracer()

def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool):
with self.monitor(
name='preprocess_images_seconds',
documentation='images preprocess time in seconds',
):
return preproc_image(
docs,
preprocess_fn=self._image_transform,
return_np=True,
drop_image_content=drop_image_content,
)
with self.tracer.start_as_current_span('preprocess_images'):
return preproc_image(
docs,
preprocess_fn=self._image_transform,
return_np=True,
drop_image_content=drop_image_content,
)

def _preproc_texts(self, docs: 'DocumentArray'):
with self.monitor(
name='preprocess_texts_seconds',
documentation='texts preprocess time in seconds',
):
return preproc_text(docs, tokenizer=self._tokenizer, return_np=True)
with self.tracer.start_as_current_span('preprocess_texts'):
return preproc_text(docs, tokenizer=self._tokenizer, return_np=True)

@requests(on='/rank')
async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
@@ -128,43 +135,62 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):

@requests
async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs):
access_paths = parameters.get('access_paths', self._access_paths)
if 'traversal_paths' in parameters:
warnings.warn(
f'`traversal_paths` is deprecated. Use `access_paths` instead.'
)
access_paths = parameters['traversal_paths']
_drop_image_content = parameters.get('drop_image_content', False)

_img_da = DocumentArray()
_txt_da = DocumentArray()
for d in docs[access_paths]:
split_img_txt_da(d, _img_da, _txt_da)

# for image
if _img_da:
for minibatch, batch_data in _img_da.map_batch(
partial(self._preproc_images, drop_image_content=_drop_image_content),
batch_size=self._minibatch_size,
pool=self._pool,
):
with self.monitor(
name='encode_images_seconds',
documentation='images encode time in seconds',
):
minibatch.embeddings = self._model.encode_image(batch_data)

# for text
if _txt_da:
for minibatch, batch_data in _txt_da.map_batch(
self._preproc_texts,
batch_size=self._minibatch_size,
pool=self._pool,
):
with self.monitor(
name='encode_texts_seconds',
documentation='texts encode time in seconds',
):
minibatch.embeddings = self._model.encode_text(batch_data)
async def encode(
self, docs: 'DocumentArray', tracing_context, parameters: Dict = {}, **kwargs
):
with self.tracer.start_as_current_span(
'encode', context=tracing_context
) as span:
span.set_attribute('device', self._device)
span.set_attribute('runtime', 'onnx')
access_paths = parameters.get('access_paths', self._access_paths)
if 'traversal_paths' in parameters:
warnings.warn(
f'`traversal_paths` is deprecated. Use `access_paths` instead.'
)
access_paths = parameters['traversal_paths']
_drop_image_content = parameters.get('drop_image_content', False)

_img_da = DocumentArray()
_txt_da = DocumentArray()
for d in docs[access_paths]:
split_img_txt_da(d, _img_da, _txt_da)

with self.tracer.start_as_current_span('inference') as inference_span:
# for image
if _img_da:
with self.tracer.start_as_current_span(
'img_minibatch_encoding'
) as img_encode_span:
for minibatch, batch_data in _img_da.map_batch(
partial(
self._preproc_images,
drop_image_content=_drop_image_content,
),
batch_size=self._minibatch_size,
pool=self._pool,
):
with self.monitor(
name='encode_images_seconds',
documentation='images encode time in seconds',
):
minibatch.embeddings = self._model.encode_image(
batch_data
)

# for text
if _txt_da:
with self.tracer.start_as_current_span(
'txt_minibatch_encoding'
) as txt_encode_span:
for minibatch, batch_data in _txt_da.map_batch(
self._preproc_texts,
batch_size=self._minibatch_size,
pool=self._pool,
):
with self.monitor(
name='encode_texts_seconds',
documentation='texts encode time in seconds',
):
minibatch.embeddings = self._model.encode_text(
batch_data
)

return docs
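
The pattern the executor adopts above can be reduced to a small standalone sketch: fall back to a NoOpTracer when no tracer is injected, open a request-level 'encode' span tagged with attributes, and nest an 'inference' span plus per-modality minibatch spans inside it. The FakeEncoder class, the 'demo' attribute value, and the dummy batches below are illustrative only; the opentelemetry calls are the same ones used in the diff.

from opentelemetry.trace import NoOpTracer


class FakeEncoder:
    def __init__(self, tracer=None):
        # The runtime injects a tracer when tracing is enabled; otherwise a
        # NoOpTracer keeps the span-wrapped code path working with negligible overhead.
        self.tracer = tracer or NoOpTracer()

    def encode(self, image_batches, text_batches, tracing_context=None):
        # Request-level span, linked to the incoming request context.
        with self.tracer.start_as_current_span(
            'encode', context=tracing_context
        ) as span:
            span.set_attribute('runtime', 'demo')
            with self.tracer.start_as_current_span('inference'):
                if image_batches:
                    with self.tracer.start_as_current_span('img_minibatch_encoding'):
                        for batch in image_batches:
                            pass  # model.encode_image(batch) would run here
                if text_batches:
                    with self.tracer.start_as_current_span('txt_minibatch_encoding'):
                        for batch in text_batches:
                            pass  # model.encode_text(batch) would run here


FakeEncoder().encode([[1, 2]], [['hello']])

Falling back to NoOpTracer means the executor never branches on whether tracing is configured; the same nested with-blocks run either way.
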
149 changes: 88 additions & 61 deletions server/clip_server/executors/clip_tensorrt.py
@@ -14,6 +14,7 @@
from clip_server.model.tokenization import Tokenizer
from clip_server.model.clip_trt import CLIPTensorRTModel
from jina import Executor, requests, DocumentArray
from opentelemetry.trace import NoOpTracer, Span


class CLIPEncoder(Executor):
@@ -38,6 +39,7 @@ def __init__(
"""
super().__init__(**kwargs)

self._num_worker_preprocess = num_worker_preprocess
self._pool = ThreadPool(processes=num_worker_preprocess)

self._minibatch_size = minibatch_size
@@ -68,27 +70,35 @@ def __init__(
self._tokenizer = Tokenizer(name)
self._image_transform = clip._transform_ndarray(self._model.image_size)

if not self.tracer:
self.tracer = NoOpTracer()

def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool):
with self.monitor(
name='preprocess_images_seconds',
documentation='images preprocess time in seconds',
):
return preproc_image(
docs,
preprocess_fn=self._image_transform,
device=self._device,
return_np=False,
drop_image_content=drop_image_content,
)
with self.tracer.start_as_current_span('preprocess_images'):
return preproc_image(
docs,
preprocess_fn=self._image_transform,
device=self._device,
return_np=False,
drop_image_content=drop_image_content,
)

def _preproc_texts(self, docs: 'DocumentArray'):
with self.monitor(
name='preprocess_texts_seconds',
documentation='texts preprocess time in seconds',
):
return preproc_text(
docs, tokenizer=self._tokenizer, device=self._device, return_np=False
)
with self.tracer.start_as_current_span('preprocess_texts'):
return preproc_text(
docs,
tokenizer=self._tokenizer,
device=self._device,
return_np=False,
)

@requests(on='/rank')
async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
@@ -98,56 +108,73 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
set_rank(docs)

@requests
async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs):
access_paths = parameters.get('access_paths', self._access_paths)
if 'traversal_paths' in parameters:
warnings.warn(
f'`traversal_paths` is deprecated. Use `access_paths` instead.'
)
access_paths = parameters['traversal_paths']
_drop_image_content = parameters.get('drop_image_content', False)

_img_da = DocumentArray()
_txt_da = DocumentArray()
for d in docs[access_paths]:
split_img_txt_da(d, _img_da, _txt_da)

# for image
if _img_da:
for minibatch, batch_data in _img_da.map_batch(
partial(self._preproc_images, drop_image_content=_drop_image_content),
batch_size=self._minibatch_size,
pool=self._pool,
):
with self.monitor(
name='encode_images_seconds',
documentation='images encode time in seconds',
):
minibatch.embeddings = (
self._model.encode_image(batch_data)
.detach()
.cpu()
.numpy()
.astype(np.float32)
)

# for text
if _txt_da:
for minibatch, batch_data in _txt_da.map_batch(
self._preproc_texts,
batch_size=self._minibatch_size,
pool=self._pool,
):
with self.monitor(
name='encode_texts_seconds',
documentation='texts encode time in seconds',
):
minibatch.embeddings = (
self._model.encode_text(batch_data)
.detach()
.cpu()
.numpy()
.astype(np.float32)
)
async def encode(
self, docs: 'DocumentArray', tracing_context, parameters: Dict = {}, **kwargs
):
with self.tracer.start_as_current_span(
'encode', context=tracing_context
) as span:
span.set_attribute('device', self._device)
span.set_attribute('runtime', 'tensorrt')
access_paths = parameters.get('access_paths', self._access_paths)
if 'traversal_paths' in parameters:
warnings.warn(
f'`traversal_paths` is deprecated. Use `access_paths` instead.'
)
access_paths = parameters['traversal_paths']
_drop_image_content = parameters.get('drop_image_content', False)

_img_da = DocumentArray()
_txt_da = DocumentArray()
for d in docs[access_paths]:
split_img_txt_da(d, _img_da, _txt_da)

with self.tracer.start_as_current_span('inference') as inference_span:
# for image
if _img_da:
with self.tracer.start_as_current_span(
'img_minibatch_encoding'
) as img_encode_span:
for minibatch, batch_data in _img_da.map_batch(
partial(
self._preproc_images,
drop_image_content=_drop_image_content,
),
batch_size=self._minibatch_size,
pool=self._pool,
):
with self.monitor(
name='encode_images_seconds',
documentation='images encode time in seconds',
):
minibatch.embeddings = (
self._model.encode_image(batch_data)
.detach()
.cpu()
.numpy()
.astype(np.float32)
)

# for text
if _txt_da:
with self.tracer.start_as_current_span(
'txt_minibatch_encoding'
) as txt_encode_span:
for minibatch, batch_data in _txt_da.map_batch(
self._preproc_texts,
batch_size=self._minibatch_size,
pool=self._pool,
):
with self.monitor(
name='encode_texts_seconds',
documentation='texts encode time in seconds',
):
minibatch.embeddings = (
self._model.encode_text(batch_data)
.detach()
.cpu()
.numpy()
.astype(np.float32)
)

return docs
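
For these spans to reach a tracing backend, tracing has to be enabled on the Flow that serves the executors. A minimal sketch under stated assumptions: the argument names tracing, traces_exporter_host and traces_exporter_port are taken from jina>=3.11's OpenTelemetry support, and the OTLP collector address is a placeholder, not a value from this commit.

from jina import Flow
from clip_server.executors.clip_onnx import CLIPEncoder

# Assumption: jina>=3.11 exposes these OpenTelemetry arguments; 'localhost:4317'
# is a placeholder for an OTLP collector endpoint.
f = Flow(
    tracing=True,
    traces_exporter_host='localhost',
    traces_exporter_port=4317,
).add(uses=CLIPEncoder)

with f:
    f.block()  # serve; the custom executor spans appear under each request trace
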