Skip to content

Commit

Permalink
feat: allow credential in client (#765)
Browse files Browse the repository at this point in the history
* feat: allow credential in client

* feat: add credential wrapper

* feat: add credential wrapper

* fix: credential is a dict

* fix: default credential

* fix: remove redundancy

* fix: remove unused import

* fix: warning at ws

* fix: typo
  • Loading branch information
ZiniuYu authored Jul 20, 2022
1 parent ca03dca commit 0ff4e25
Showing 1 changed file with 27 additions and 13 deletions.
40 changes: 27 additions & 13 deletions client/clip_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,34 +22,36 @@


class Client:
def __init__(self, server: str):
def __init__(self, server: str, credential: dict = {}, **kwargs):
"""Create a Clip client object that connects to the Clip server.
Server scheme is in the format of `scheme://netloc:port`, where
- scheme: one of grpc, websocket, http, grpcs, websockets, https
- netloc: the server ip address or hostname
- port: the public port of the server
:param server: the server URI
:param credential: the credential for authentication {'Authentication': '<token>'}
"""
try:
r = urlparse(server)
_port = r.port
_scheme = r.scheme
if not _scheme:
raise
self._scheme = r.scheme
except:
raise ValueError(f'{server} is not a valid scheme')

_tls = False

if _scheme in ('grpcs', 'https', 'wss'):
_scheme = _scheme[:-1]
if self._scheme in ('grpcs', 'https', 'wss'):
self._scheme = self._scheme[:-1]
_tls = True

if _scheme == 'ws':
_scheme = 'websocket' # temp fix for the core
if self._scheme == 'ws':
self._scheme = 'websocket' # temp fix for the core
if credential:
warnings.warn(
'Credential is not supported for websocket, please use grpc or http'
)

if _scheme in ('grpc', 'http', 'websocket'):
_kwargs = dict(host=r.hostname, port=_port, protocol=_scheme, tls=_tls)
if self._scheme in ('grpc', 'http', 'websocket'):
_kwargs = dict(host=r.hostname, port=_port, protocol=self._scheme, tls=_tls)

from jina import Client

Expand All @@ -58,6 +60,8 @@ def __init__(self, server: str):
else:
raise ValueError(f'{server} is not a valid scheme')

self._authorization = credential.get('Authorization', None)

@overload
def encode(
self,
Expand Down Expand Up @@ -181,12 +185,17 @@ def _iter_doc(self, content) -> Generator['Document', None, None]:
)

def _get_post_payload(self, content, kwargs):
return dict(
payload = dict(
on='/',
inputs=self._iter_doc(content),
request_size=kwargs.get('batch_size', 8),
total_docs=len(content) if hasattr(content, '__len__') else None,
)
if self._scheme == 'grpc' and self._authorization:
payload.update(metadata=('authorization', self._authorization))
elif self._scheme == 'http' and self._authorization:
payload.update(headers={'Authorization': self._authorization})
return payload

def profile(self, content: Optional[str] = '') -> Dict[str, float]:
"""Profiling a single query's roundtrip including network and computation latency. Results is summarized in a table.
Expand Down Expand Up @@ -355,14 +364,19 @@ def _iter_rank_docs(
)

def _get_rank_payload(self, content, kwargs):
return dict(
payload = dict(
on='/rank',
inputs=self._iter_rank_docs(
content, _source=kwargs.get('source', 'matches')
),
request_size=kwargs.get('batch_size', 8),
total_docs=len(content) if hasattr(content, '__len__') else None,
)
if self._scheme == 'grpc' and self._authorization:
payload.update(metadata=('authorization', self._authorization))
elif self._scheme == 'http' and self._authorization:
payload.update(headers={'Authorization': self._authorization})
return payload

def rank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
"""Rank image-text matches according to the server CLIP model.
Expand Down

0 comments on commit 0ff4e25

Please sign in to comment.