From bb520d14b6c5172fce9a971b51c4125b60418119 Mon Sep 17 00:00:00 2001 From: felix-wang <35718120+numb3r3@users.noreply.github.com> Date: Mon, 9 May 2022 18:22:08 +0800 Subject: [PATCH] fix: keep logit_scale on same device (#710) * fix: keep logit_scale on cpu * fix: use cuda when computing ranked score --- server/clip_server/executors/clip_torch.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/server/clip_server/executors/clip_torch.py b/server/clip_server/executors/clip_torch.py index 3ce03298a..ff64d1c0e 100644 --- a/server/clip_server/executors/clip_torch.py +++ b/server/clip_server/executors/clip_torch.py @@ -88,8 +88,12 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs): else: _img_da = await self.encode(_img_da) _txt_da = await self.encode(_txt_da) - _img_da.embeddings = torch.from_numpy(_img_da.embeddings) - _txt_da.embeddings = torch.from_numpy(_txt_da.embeddings) + _img_da.embeddings = torch.from_numpy(_img_da.embeddings).to( + self._device, non_blocking=True + ) + _txt_da.embeddings = torch.from_numpy(_txt_da.embeddings).to( + self._device, non_blocking=True + ) # normalized features image_features = _img_da.embeddings / _img_da.embeddings.norm(