Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow model selection in client #775

Merged
merged 5 commits into from
Jul 20, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions client/clip_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def encode(
*,
batch_size: Optional[int] = None,
show_progress: bool = False,
parameters: Optional[dict] = None,
) -> 'np.ndarray':
"""Encode images and texts into embeddings where the input is an iterable of raw strings.
Each image and text must be represented as a string. The following strings are acceptable:
Expand All @@ -79,6 +80,7 @@ def encode(
:param content: an iterator of image URIs or sentences, each element is an image or a text sentence as a string.
:param batch_size: the number of elements in each request when sending ``content``
:param show_progress: if set, show a progress bar
:param parameters: extra parameters for the encoding request; set the ``model`` key to choose which model to use when multiple models are available
:return: the embeddings as a numpy ndarray with shape ``[N, D]``, where ``N`` is the length of ``content``
"""
...
Expand All @@ -90,11 +92,13 @@ def encode(
*,
batch_size: Optional[int] = None,
show_progress: bool = False,
parameters: Optional[dict] = None,
) -> 'DocumentArray':
"""Encode images and texts into embeddings where the input is an iterable of :class:`docarray.Document`.
:param content: an iterable of :class:`docarray.Document`, each Document must be filled with `.uri`, `.text` or `.blob`.
:param batch_size: the number of elements in each request when sending ``content``
:param show_progress: if set, show a progress bar
:param parameters: extra parameters for the encoding request; set the ``model`` key to choose which model to use when multiple models are available
:return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` is in the same length of ``content``
"""
...
Expand Down Expand Up @@ -185,8 +189,10 @@ def _iter_doc(self, content) -> Generator['Document', None, None]:
)

def _get_post_payload(self, content, kwargs):
parameters = kwargs.get('parameters', {})
model_name = parameters.get('model', '')
payload = dict(
on='/',
on=f'/encode/{model_name}'.rstrip('/'),
inputs=self._iter_doc(content),
request_size=kwargs.get('batch_size', 8),
total_docs=len(content) if hasattr(content, '__len__') else None,
Expand Down Expand Up @@ -364,8 +370,10 @@ def _iter_rank_docs(
)

def _get_rank_payload(self, content, kwargs):
parameters = kwargs.get('parameters', {})
model_name = parameters.get('model', '')
payload = dict(
on='/rank',
on=f'/rank/{model_name}'.rstrip('/'),
inputs=self._iter_rank_docs(
content, _source=kwargs.get('source', 'matches')
),
Expand Down