Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
Merge pull request #216 from gnes-ai/fix-numpy-indexer
Browse files Browse the repository at this point in the history
fix(indexer): fix numpy indexer
  • Loading branch information
mergify[bot] authored Sep 5, 2019
2 parents 91762ff + cd53a24 commit d805b56
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 14 deletions.
31 changes: 18 additions & 13 deletions gnes/indexer/chunk/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,41 +23,46 @@


class NumpyIndexer(BaseChunkIndexer):
"""An exhaustive search indexer using numpy
The distance is computed as L1 distance normalized by the number of dimension
"""

def __init__(self, num_bytes: int = None, *args, **kwargs):
def __init__(self, is_binary: bool = False, *args, **kwargs):
super().__init__()
self.num_bytes = num_bytes
self._num_dim = None
self._vectors = None # type: np.ndarray
self._is_binary = is_binary
self._key_info_indexer = ListKeyIndexer()

def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[float], *args,
**kwargs):
if len(vectors) % len(keys) != 0:
raise ValueError('vectors bytes should be divided by doc_ids')

if not self.num_bytes:
self.num_bytes = vectors.shape[1]
elif self.num_bytes != vectors.shape[1]:
if not self._num_dim:
self._num_dim = vectors.shape[1]
elif self._num_dim != vectors.shape[1]:
raise ValueError(
"vectors' shape [%d, %d] does not match with indexer's dim: %d" %
(vectors.shape[0], vectors.shape[1], self.num_bytes))
(vectors.shape[0], vectors.shape[1], self._num_dim))

if self._vectors is not None:
self._vectors = np.concatenate([self._vectors, vectors], axis=0)
else:
self._vectors = vectors
self._key_info_indexer.add(keys, weights)

def query(self, keys: np.ndarray, top_k: int, *args, **kwargs
) -> List[List[Tuple]]:
keys = np.expand_dims(keys, axis=1)
dist = keys - np.expand_dims(self._vectors, axis=0)
score = np.sum(np.minimum(np.abs(dist), 1), -1) / self.num_bytes
def query(self, keys: np.ndarray, top_k: int, *args, **kwargs) -> List[List[Tuple]]:
dist = np.abs(np.expand_dims(keys, axis=1) - np.expand_dims(self._vectors, axis=0))

if self._is_binary:
dist = np.minimum(dist, 1)

score = np.sum(dist, -1) / self._num_dim

ret = []
for ids in score:
rk = sorted(enumerate(ids), key=lambda x: -x[1])
rk = sorted(enumerate(ids), key=lambda x: x[1])[:top_k]
chunk_info = self._key_info_indexer.query([j[0] for j in rk])

ret.append([(*r, s) for r, s in zip(chunk_info, [j[1] for j in rk])])
return ret
2 changes: 1 addition & 1 deletion gnes/service/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def post_handler(self, msg: 'gnes_pb2.Message', *args, **kwargs):
def _hook_warn_body_type_change(self, msg: 'gnes_pb2.Message', old_body_type: str, *args, **kwargs):
new_type = msg.WhichOneof('body')
if new_type != old_body_type:
self.logger.warning('message body is changed from %s to %s' % (new_type, old_body_type))
self.logger.warning('message body is changed from %s to %s' % (old_body_type, new_type))

def _hook_sort_response(self, msg: 'gnes_pb2.Message', *args, **kwargs):
if 'sorted_response' in self.args and self.args.sorted_response and msg.response.search.topk_results:
Expand Down

0 comments on commit d805b56

Please sign in to comment.