From 2d6c70fc389d2928714576ef1d58afb5805ce838 Mon Sep 17 00:00:00 2001 From: hanhxiao Date: Mon, 2 Sep 2019 16:29:37 +0800 Subject: [PATCH] fix(indexer): fix vec np.concat --- gnes/indexer/base.py | 2 +- gnes/service/indexer.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/gnes/indexer/base.py b/gnes/indexer/base.py index 9763048e..ba1b4eed 100644 --- a/gnes/indexer/base.py +++ b/gnes/indexer/base.py @@ -51,7 +51,7 @@ def query(self, keys: np.ndarray, top_k: int, *args, **kwargs) -> List[List[Tupl def query_and_score(self, q_chunks: List['gnes_pb2.Chunk'], top_k: int, *args, **kwargs) -> List[ 'gnes_pb2.Response.QueryResponse.ScoredResult']: vecs = [blob2array(c.embedding) for c in q_chunks] - queried_results = self.query(np.concatenate(vecs, 0), top_k=top_k) + queried_results = self.query(np.stack(vecs), top_k=top_k) results = [] for q_chunk, topk_chunks in zip(q_chunks, queried_results): for _doc_id, _offset, _weight, _relevance in topk_chunks: diff --git a/gnes/service/indexer.py b/gnes/service/indexer.py index e401df33..8550d951 100644 --- a/gnes/service/indexer.py +++ b/gnes/service/indexer.py @@ -48,17 +48,15 @@ def _handler_chunk_index(self, msg: 'gnes_pb2.Message'): self.logger.warning('document (doc_id=%s) contains no chunks!' % d.doc_id) continue - for c in d.chunks: - self.logger.info(c.embedding) vecs += [blob2array(c.embedding) for c in d.chunks] doc_ids += [d.doc_id] * len(d.chunks) offsets += [c.offset for c in d.chunks] weights += [c.weight for c in d.chunks] - self.logger.info('%d %d %d %d' % (len(vecs), len(doc_ids), len(offsets), len(weights))) - self.logger.info(np.concatenate(vecs, 0).shape) + # self.logger.info('%d %d %d %d' % (len(vecs), len(doc_ids), len(offsets), len(weights))) + # self.logger.info(np.stack(vecs).shape) if vecs: - self._model.add(list(zip(doc_ids, offsets)), np.concatenate(vecs, 0), weights) + self._model.add(list(zip(doc_ids, offsets)), np.stack(vecs), weights) def _handler_doc_index(self, msg: 'gnes_pb2.Message'): self._model.add([d.doc_id for d in msg.request.index.docs],