From cd53a24bfb708adf4c7df6b673a1515550fe6022 Mon Sep 17 00:00:00 2001 From: hanhxiao Date: Thu, 5 Sep 2019 16:53:28 +0800 Subject: [PATCH] fix(indexer): fix numpy indexer --- gnes/indexer/chunk/numpy.py | 31 ++++++++++++++++++------------- gnes/service/base.py | 2 +- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/gnes/indexer/chunk/numpy.py b/gnes/indexer/chunk/numpy.py index e700e72b..7b6254ba 100644 --- a/gnes/indexer/chunk/numpy.py +++ b/gnes/indexer/chunk/numpy.py @@ -23,11 +23,15 @@ class NumpyIndexer(BaseChunkIndexer): + """An exhaustive search indexer using numpy + The distance is computed as L1 distance normalized by the number of dimension + """ - def __init__(self, num_bytes: int = None, *args, **kwargs): + def __init__(self, is_binary: bool = False, *args, **kwargs): super().__init__() - self.num_bytes = num_bytes + self._num_dim = None self._vectors = None # type: np.ndarray + self._is_binary = is_binary self._key_info_indexer = ListKeyIndexer() def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[float], *args, @@ -35,12 +39,12 @@ def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[fl if len(vectors) % len(keys) != 0: raise ValueError('vectors bytes should be divided by doc_ids') - if not self.num_bytes: - self.num_bytes = vectors.shape[1] - elif self.num_bytes != vectors.shape[1]: + if not self._num_dim: + self._num_dim = vectors.shape[1] + elif self._num_dim != vectors.shape[1]: raise ValueError( "vectors' shape [%d, %d] does not match with indexer's dim: %d" % - (vectors.shape[0], vectors.shape[1], self.num_bytes)) + (vectors.shape[0], vectors.shape[1], self._num_dim)) if self._vectors is not None: self._vectors = np.concatenate([self._vectors, vectors], axis=0) @@ -48,16 +52,17 @@ def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[fl self._vectors = vectors self._key_info_indexer.add(keys, weights) - def query(self, keys: np.ndarray, top_k: int, *args, **kwargs - ) -> List[List[Tuple]]: - keys = np.expand_dims(keys, axis=1) - dist = keys - np.expand_dims(self._vectors, axis=0) - score = np.sum(np.minimum(np.abs(dist), 1), -1) / self.num_bytes + def query(self, keys: np.ndarray, top_k: int, *args, **kwargs) -> List[List[Tuple]]: + dist = np.abs(np.expand_dims(keys, axis=1) - np.expand_dims(self._vectors, axis=0)) + + if self._is_binary: + dist = np.minimum(dist, 1) + + score = np.sum(dist, -1) / self._num_dim ret = [] for ids in score: - rk = sorted(enumerate(ids), key=lambda x: -x[1]) + rk = sorted(enumerate(ids), key=lambda x: x[1])[:top_k] chunk_info = self._key_info_indexer.query([j[0] for j in rk]) - ret.append([(*r, s) for r, s in zip(chunk_info, [j[1] for j in rk])]) return ret diff --git a/gnes/service/base.py b/gnes/service/base.py index 10dcf1a8..34d696ea 100644 --- a/gnes/service/base.py +++ b/gnes/service/base.py @@ -267,7 +267,7 @@ def post_handler(self, msg: 'gnes_pb2.Message', *args, **kwargs): def _hook_warn_body_type_change(self, msg: 'gnes_pb2.Message', old_body_type: str, *args, **kwargs): new_type = msg.WhichOneof('body') if new_type != old_body_type: - self.logger.warning('message body is changed from %s to %s' % (new_type, old_body_type)) + self.logger.warning('message body is changed from %s to %s' % (old_body_type, new_type)) def _hook_sort_response(self, msg: 'gnes_pb2.Message', *args, **kwargs): if 'sorted_response' in self.args and self.args.sorted_response and msg.response.search.topk_results: