Skip to content

Commit

Permalink
fix: typo
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiniuYu committed Sep 9, 2022
1 parent fa2c2be commit 019f286
Showing 1 changed file with 32 additions and 32 deletions.
64 changes: 32 additions & 32 deletions client/clip_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _prepare_streaming(self, disable, total):
)

def _iter_doc(
self, content, result: 'DocumentArray'
self, content, results: 'DocumentArray'
) -> Generator['Document', None, None]:
from rich import filesize
from docarray import Document
Expand Down Expand Up @@ -189,15 +189,15 @@ def _iter_doc(
)

d.tags['__ordered_by_CAS__'] = i
result.append(d)
results.append(d)
yield d

def _get_post_payload(self, content, result: 'DocumentArray', kwargs):
def _get_post_payload(self, content, results: 'DocumentArray', kwargs):
parameters = kwargs.get('parameters', {})
model_name = parameters.get('model', '')
payload = dict(
on=f'/encode/{model_name}'.rstrip('/'),
inputs=self._iter_doc(content, result),
inputs=self._iter_doc(content, results),
request_size=kwargs.get('batch_size', 8),
total_docs=len(content) if hasattr(content, '__len__') else None,
)
Expand All @@ -208,14 +208,14 @@ def _get_post_payload(self, content, result: 'DocumentArray', kwargs):
payload.update(headers={'Authorization': self._authorization})
return payload

def _gather_encode_result(self, response, result: 'DocumentArray'):
def _gather_encode_result(self, response, results: 'DocumentArray'):
from rich import filesize

r = response.data.docs
for d in r:
index = int(d.tags['__ordered_by_CAS__'])
result[index].embedding = d.embedding
result[index].tags.pop('__ordered_by_CAS__')
results[index].embedding = d.embedding
results[index].tags.pop('__ordered_by_CAS__')

if not self._pbar._tasks[self._r_task].started:
self._pbar.start_task(self._r_task)
Expand Down Expand Up @@ -291,19 +291,19 @@ def encode(self, content, **kwargs):
total=len(content) if hasattr(content, '__len__') else None,
)

result = DocumentArray()
results = DocumentArray()
with self._pbar:
self._client.post(
**self._get_post_payload(content, result, kwargs),
on_done=partial(self._gather_encode_result, result=result),
**self._get_post_payload(content, results, kwargs),
on_done=partial(self._gather_encode_result, results=results),
)

for c in content:
if hasattr(c, 'tags') and c.tags.pop('__loaded_by_CAS__', False):
c.pop('blob')

unbox = hasattr(content, '__len__') and isinstance(content[0], str)
return self._unboxed_result(result, unbox)
return self._unboxed_result(results, unbox)

@overload
async def aencode(
Expand Down Expand Up @@ -340,15 +340,15 @@ async def aencode(self, content, **kwargs):
total=len(content) if hasattr(content, '__len__') else None,
)

result = DocumentArray()
results = DocumentArray()
with self._pbar:
async for da in self._async_client.post(
**self._get_post_payload(content, result, kwargs)
**self._get_post_payload(content, results, kwargs)
):
for d in da:
index = int(d.tags['__ordered_by_CAS__'])
result[index].embedding = d.embedding
result[index].tags.pop('__ordered_by_CAS__')
results[index].embedding = d.embedding
results[index].tags.pop('__ordered_by_CAS__')

if not self._pbar._tasks[self._r_task].started:
self._pbar.start_task(self._r_task)
Expand All @@ -367,10 +367,10 @@ async def aencode(self, content, **kwargs):
c.pop('blob')

unbox = hasattr(content, '__len__') and isinstance(content[0], str)
return self._unboxed_result(result, unbox)
return self._unboxed_result(results, unbox)

def _iter_rank_docs(
self, content, result: 'DocumentArray', source='matches'
self, content, results: 'DocumentArray', source='matches'
) -> Generator['Document', None, None]:
from rich import filesize
from docarray import Document
Expand All @@ -396,16 +396,16 @@ def _iter_rank_docs(
)

d.tags['__ordered_by_CAS__'] = i
result.append(d)
results.append(d)
yield d

def _get_rank_payload(self, content, result: 'DocumentArray', kwargs):
def _get_rank_payload(self, content, results: 'DocumentArray', kwargs):
parameters = kwargs.get('parameters', {})
model_name = parameters.get('model', '')
payload = dict(
on=f'/rank/{model_name}'.rstrip('/'),
inputs=self._iter_rank_docs(
content, result, source=kwargs.get('source', 'matches')
content, results, source=kwargs.get('source', 'matches')
),
request_size=kwargs.get('batch_size', 8),
total_docs=len(content) if hasattr(content, '__len__') else None,
Expand All @@ -416,14 +416,14 @@ def _get_rank_payload(self, content, result: 'DocumentArray', kwargs):
payload.update(headers={'Authorization': self._authorization})
return payload

def _gather_rank_result(self, response, result: 'DocumentArray'):
def _gather_rank_result(self, response, results: 'DocumentArray'):
from rich import filesize

r = response.data.docs
for d in r:
index = int(d.tags['__ordered_by_CAS__'])
result[index].matches = d.matches
result[index].tags.pop('__ordered_by_CAS__')
results[index].matches = d.matches
results[index].tags.pop('__ordered_by_CAS__')

if not self._pbar._tasks[self._r_task].started:
self._pbar.start_task(self._r_task)
Expand Down Expand Up @@ -486,16 +486,16 @@ def rank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
total=len(docs) if hasattr(docs, '__len__') else None,
)

result = DocumentArray()
results = DocumentArray()
with self._pbar:
self._client.post(
**self._get_rank_payload(docs, result, kwargs),
on_done=partial(self._gather_rank_result, result=result),
**self._get_rank_payload(docs, results, kwargs),
on_done=partial(self._gather_rank_result, results=results),
)
for d in docs:
self._reset_rank_doc(d, _source=kwargs.get('source', 'matches'))

return result
return results

async def arank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
from rich import filesize
Expand All @@ -508,15 +508,15 @@ async def arank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
total=len(docs) if hasattr(docs, '__len__') else None,
)

result = DocumentArray()
results = DocumentArray()
with self._pbar:
async for da in self._async_client.post(
**self._get_rank_payload(docs, result, kwargs)
**self._get_rank_payload(docs, results, kwargs)
):
for d in da:
index = int(d.tags['__ordered_by_CAS__'])
result[index].matches = d.matches
result[index].tags.pop('__ordered_by_CAS__')
results[index].matches = d.matches
results[index].tags.pop('__ordered_by_CAS__')

if not self._pbar._tasks[self._r_task].started:
self._pbar.start_task(self._r_task)
Expand All @@ -533,4 +533,4 @@ async def arank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
for d in docs:
self._reset_rank_doc(d, _source=kwargs.get('source', 'matches'))

return result
return results

0 comments on commit 019f286

Please sign in to comment.