Skip to content

Commit

Permalink
feat: use uid to ensure the order
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiniuYu committed Sep 13, 2022
1 parent 019f286 commit 4d41716
Showing 1 changed file with 4 additions and 18 deletions.
22 changes: 4 additions & 18 deletions client/clip_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ def _iter_doc(
),
)

d.tags['__ordered_by_CAS__'] = i
results.append(d)
yield d

Expand All @@ -212,10 +211,7 @@ def _gather_encode_result(self, response, results: 'DocumentArray'):
from rich import filesize

r = response.data.docs
for d in r:
index = int(d.tags['__ordered_by_CAS__'])
results[index].embedding = d.embedding
results[index].tags.pop('__ordered_by_CAS__')
results[r[:, 'id']].embeddings = r.embeddings

if not self._pbar._tasks[self._r_task].started:
self._pbar.start_task(self._r_task)
Expand Down Expand Up @@ -345,10 +341,7 @@ async def aencode(self, content, **kwargs):
async for da in self._async_client.post(
**self._get_post_payload(content, results, kwargs)
):
for d in da:
index = int(d.tags['__ordered_by_CAS__'])
results[index].embedding = d.embedding
results[index].tags.pop('__ordered_by_CAS__')
results[da[:, 'id']].embeddings = da.embeddings

if not self._pbar._tasks[self._r_task].started:
self._pbar.start_task(self._r_task)
Expand Down Expand Up @@ -395,7 +388,6 @@ def _iter_rank_docs(
),
)

d.tags['__ordered_by_CAS__'] = i
results.append(d)
yield d

Expand All @@ -420,10 +412,7 @@ def _gather_rank_result(self, response, results: 'DocumentArray'):
from rich import filesize

r = response.data.docs
for d in r:
index = int(d.tags['__ordered_by_CAS__'])
results[index].matches = d.matches
results[index].tags.pop('__ordered_by_CAS__')
results[r[:, 'id']][:, 'matches'] = r[:, 'matches']

if not self._pbar._tasks[self._r_task].started:
self._pbar.start_task(self._r_task)
Expand Down Expand Up @@ -513,10 +502,7 @@ async def arank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
async for da in self._async_client.post(
**self._get_rank_payload(docs, results, kwargs)
):
for d in da:
index = int(d.tags['__ordered_by_CAS__'])
results[index].matches = d.matches
results[index].tags.pop('__ordered_by_CAS__')
results[da[:, 'id']][:, 'matches'] = da[:, 'matches']

if not self._pbar._tasks[self._r_task].started:
self._pbar.start_task(self._r_task)
Expand Down

0 comments on commit 4d41716

Please sign in to comment.