diff --git a/client/clip_client/client.py b/client/clip_client/client.py index 126b5a167..7b1bf1e0e 100644 --- a/client/clip_client/client.py +++ b/client/clip_client/client.py @@ -22,34 +22,36 @@ class Client: - def __init__(self, server: str): + def __init__(self, server: str, credential: dict = {}, **kwargs): """Create a Clip client object that connects to the Clip server. Server scheme is in the format of `scheme://netloc:port`, where - scheme: one of grpc, websocket, http, grpcs, websockets, https - netloc: the server ip address or hostname - port: the public port of the server :param server: the server URI + :param credential: the credential for authentication {'Authentication': ''} """ try: r = urlparse(server) _port = r.port - _scheme = r.scheme - if not _scheme: - raise + self._scheme = r.scheme except: raise ValueError(f'{server} is not a valid scheme') _tls = False - - if _scheme in ('grpcs', 'https', 'wss'): - _scheme = _scheme[:-1] + if self._scheme in ('grpcs', 'https', 'wss'): + self._scheme = self._scheme[:-1] _tls = True - if _scheme == 'ws': - _scheme = 'websocket' # temp fix for the core + if self._scheme == 'ws': + self._scheme = 'websocket' # temp fix for the core + if credential: + warnings.warn( + 'Credential is not supported for websocket, please use grpc or http' + ) - if _scheme in ('grpc', 'http', 'websocket'): - _kwargs = dict(host=r.hostname, port=_port, protocol=_scheme, tls=_tls) + if self._scheme in ('grpc', 'http', 'websocket'): + _kwargs = dict(host=r.hostname, port=_port, protocol=self._scheme, tls=_tls) from jina import Client @@ -58,6 +60,8 @@ def __init__(self, server: str): else: raise ValueError(f'{server} is not a valid scheme') + self._authorization = credential.get('Authorization', None) + @overload def encode( self, @@ -181,12 +185,17 @@ def _iter_doc(self, content) -> Generator['Document', None, None]: ) def _get_post_payload(self, content, kwargs): - return dict( + payload = dict( on='/', inputs=self._iter_doc(content), request_size=kwargs.get('batch_size', 8), total_docs=len(content) if hasattr(content, '__len__') else None, ) + if self._scheme == 'grpc' and self._authorization: + payload.update(metadata=('authorization', self._authorization)) + elif self._scheme == 'http' and self._authorization: + payload.update(headers={'Authorization': self._authorization}) + return payload def profile(self, content: Optional[str] = '') -> Dict[str, float]: """Profiling a single query's roundtrip including network and computation latency. Results is summarized in a table. @@ -355,7 +364,7 @@ def _iter_rank_docs( ) def _get_rank_payload(self, content, kwargs): - return dict( + payload = dict( on='/rank', inputs=self._iter_rank_docs( content, _source=kwargs.get('source', 'matches') @@ -363,6 +372,11 @@ def _get_rank_payload(self, content, kwargs): request_size=kwargs.get('batch_size', 8), total_docs=len(content) if hasattr(content, '__len__') else None, ) + if self._scheme == 'grpc' and self._authorization: + payload.update(metadata=('authorization', self._authorization)) + elif self._scheme == 'http' and self._authorization: + payload.update(headers={'Authorization': self._authorization}) + return payload def rank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray': """Rank image-text matches according to the server CLIP model.