From 32b11cd64bb76bca5075fbcbc84b9334952c236c Mon Sep 17 00:00:00 2001 From: Ziniu Yu Date: Wed, 20 Jul 2022 22:32:12 +0800 Subject: [PATCH] feat: allow model selection in client (#775) * feat: allow model selection in client * docs: update client model selection * docs: revert * fix: improve endpoint * fix: rstrip endpoint --- client/clip_client/client.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/client/clip_client/client.py b/client/clip_client/client.py index 7b1bf1e0e..9a2b457c2 100644 --- a/client/clip_client/client.py +++ b/client/clip_client/client.py @@ -69,6 +69,7 @@ def encode( *, batch_size: Optional[int] = None, show_progress: bool = False, + parameters: Optional[dict] = None, ) -> 'np.ndarray': """Encode images and texts into embeddings where the input is an iterable of raw strings. Each image and text must be represented as a string. The following strings are acceptable: @@ -79,6 +80,7 @@ def encode( :param content: an iterator of image URIs or sentences, each element is an image or a text sentence as a string. :param batch_size: the number of elements in each request when sending ``content`` :param show_progress: if set, show a progress bar + :param parameters: the parameters for the encoding, you can specify the model to use when you have multiple models :return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` is in the same length of ``content`` """ ... @@ -90,11 +92,13 @@ def encode( *, batch_size: Optional[int] = None, show_progress: bool = False, + parameters: Optional[dict] = None, ) -> 'DocumentArray': """Encode images and texts into embeddings where the input is an iterable of :class:`docarray.Document`. :param content: an iterable of :class:`docarray.Document`, each Document must be filled with `.uri`, `.text` or `.blob`. :param batch_size: the number of elements in each request when sending ``content`` :param show_progress: if set, show a progress bar + :param parameters: the parameters for the encoding, you can specify the model to use when you have multiple models :return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` is in the same length of ``content`` """ ... @@ -185,8 +189,10 @@ def _iter_doc(self, content) -> Generator['Document', None, None]: ) def _get_post_payload(self, content, kwargs): + parameters = kwargs.get('parameters', {}) + model_name = parameters.get('model', '') payload = dict( - on='/', + on=f'/encode/{model_name}'.rstrip('/'), inputs=self._iter_doc(content), request_size=kwargs.get('batch_size', 8), total_docs=len(content) if hasattr(content, '__len__') else None, @@ -364,8 +370,10 @@ def _iter_rank_docs( ) def _get_rank_payload(self, content, kwargs): + parameters = kwargs.get('parameters', {}) + model_name = parameters.get('model', '') payload = dict( - on='/rank', + on=f'/rank/{model_name}'.rstrip('/'), inputs=self._iter_rank_docs( content, _source=kwargs.get('source', 'matches') ),