diff --git a/server/clip_server/executors/clip_torch.py b/server/clip_server/executors/clip_torch.py index 3ce03298a..ff64d1c0e 100644 --- a/server/clip_server/executors/clip_torch.py +++ b/server/clip_server/executors/clip_torch.py @@ -88,8 +88,12 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs): else: _img_da = await self.encode(_img_da) _txt_da = await self.encode(_txt_da) - _img_da.embeddings = torch.from_numpy(_img_da.embeddings) - _txt_da.embeddings = torch.from_numpy(_txt_da.embeddings) + _img_da.embeddings = torch.from_numpy(_img_da.embeddings).to( + self._device, non_blocking=True + ) + _txt_da.embeddings = torch.from_numpy(_txt_da.embeddings).to( + self._device, non_blocking=True + ) # normalized features image_features = _img_da.embeddings / _img_da.embeddings.norm(