Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
refactor(indexer): add counter and move key_only to chunk helper
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Sep 11, 2019
1 parent 2c2a4d8 commit 421e21f
Show file tree
Hide file tree
Showing 16 changed files with 132 additions and 249 deletions.
2 changes: 1 addition & 1 deletion gnes/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
BaseChunkIndexer = indexer_base.BaseChunkIndexer
BaseIndexer = indexer_base.BaseIndexer
BaseDocIndexer = indexer_base.BaseDocIndexer
BaseKeyIndexer = indexer_base.BaseKeyIndexer
BaseKeyIndexer = indexer_base.BaseChunkIndexerHelper
JointIndexer = indexer_base.JointIndexer

# Preprocessor
Expand Down
84 changes: 75 additions & 9 deletions gnes/indexer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import wraps
from typing import List, Any, Union, Callable, Tuple

import numpy as np
Expand All @@ -37,7 +38,7 @@ def __init__(self,
self.normalize_fn = normalize_fn
self.score_fn = score_fn
self.is_big_score_similar = is_big_score_similar
self._num_doc = 0
self._num_docs = 0
self._num_chunks = 0

def add(self, keys: Any, docs: Any, weights: List[float], *args, **kwargs):
Expand All @@ -50,13 +51,29 @@ def query_and_score(self, q_chunks: List[Union['gnes_pb2.Chunk', 'gnes_pb2.Docum
'gnes_pb2.Response.QueryResponse.ScoredResult']:
raise NotImplementedError

def update_counter(self, *args, **kwargs):
pass
@property
def num_docs(self):
    """Number of documents indexed so far, as tracked by the `_num_docs` counter."""
    return self._num_docs

@property
def num_chunks(self):
    """Number of chunks indexed so far, as tracked by the `_num_chunks` counter."""
    return self._num_chunks


class BaseChunkIndexer(BaseIndexer):
"""Storing chunks and their vector representations """

def __init__(self, helper_indexer: 'BaseChunkIndexerHelper' = None, *args, **kwargs):
    """
    Create a chunk indexer.

    :param helper_indexer: optional helper that stores per-chunk key/weight info
        on behalf of this indexer (used when the backend cannot store it itself)
    """
    super().__init__(*args, **kwargs)
    self.helper_indexer = helper_indexer

def add(self, keys: List[Tuple[int, int]], vectors: np.ndarray, weights: List[float], *args, **kwargs):
    """
    Add new chunks and their vector representations to the index.

    Abstract stub; concrete chunk indexers override this.

    :param keys: list of (doc_id, offset) tuples, one per chunk
    :param vectors: vector representations, one row per chunk
    :param weights: weight of each chunk
    """
    pass

def query(self, keys: np.ndarray, top_k: int, *args, **kwargs) -> List[List[Tuple]]:
Expand Down Expand Up @@ -87,13 +104,52 @@ def query_and_score(self, q_chunks: List['gnes_pb2.Chunk'], top_k: int, *args, *
results.append(r)
return results

def update_counter(self, keys: List[Tuple[int, int]], *args, **kwargs):
pass
@staticmethod
def update_counter(func):
@wraps(func)
def arg_wrapper(self, keys: List[Tuple[int, int]], *args, **kwargs):
doc_ids, _ = zip(*keys)
self._num_docs += len(set(doc_ids))
self._num_chunks += len(keys)
return func(self, keys, *args, **kwargs)

return arg_wrapper

@staticmethod
def update_helper_indexer(func):
@wraps(func)
def arg_wrapper(self, keys: List[Tuple[int, int]], *args, **kwargs):
r = func(self, keys, *args, **kwargs)
if self.helper_indexer:
self.helper_indexer.add(keys, *args, **kwargs)
return r

return arg_wrapper

@property
def num_docs(self):
if self.helper_indexer:
return self.helper_indexer._num_docs
else:
return self._num_docs

@property
def num_chunks(self):
if self.helper_indexer:
return self.helper_indexer._num_chunks
else:
return self._num_chunks


class BaseDocIndexer(BaseIndexer):
"""Storing documents and contents """

def add(self, keys: List[int], docs: Any, weights: List[float], *args, **kwargs):
def add(self, keys: List[int], docs: List['gnes_pb2.Document'], *args, **kwargs):
    """
    Add new documents and their protobuf representations to the store.

    Abstract stub; concrete doc indexers override this.

    :param keys: list of doc_id
    :param docs: list of protobuf Document objects, parallel to ``keys``
    """
    pass

def query(self, keys: List[int], *args, **kwargs) -> List['gnes_pb2.Document']:
Expand All @@ -113,11 +169,21 @@ def query_and_score(self, docs: List['gnes_pb2.Response.QueryResponse.ScoredResu
results.append(r)
return results

def update_counter(self, docs: List['gnes_pb2.Document'], *args, **kwargs):
pass
@staticmethod
def update_counter(func):
@wraps(func)
def arg_wrapper(self, keys: List[int], docs: List['gnes_pb2.Document'], *args, **kwargs):
self._num_docs += len(keys)
self._num_chunks += sum(len(d.chunks) for d in docs)
return func(self, keys, docs, *args, **kwargs)

return arg_wrapper


class BaseKeyIndexer(BaseIndexer):
class BaseChunkIndexerHelper(BaseChunkIndexer):
"""A helper class for storing chunk info, doc mapping, weights.
This is especially useful when ChunkIndexer can not store these information by itself
"""

def add(self, keys: List[Tuple[int, int]], weights: List[float], *args, **kwargs) -> int:
    """
    Store (doc_id, offset) keys together with their chunk weights.

    Abstract stub; concrete helpers override this.
    # NOTE(review): annotated to return int — presumably the running total of
    # stored chunks; confirm against concrete helper implementations
    """
    pass
Expand Down
27 changes: 7 additions & 20 deletions gnes/indexer/chunk/annoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@

import numpy as np

from ..base import BaseChunkIndexer
from ..key_only import ListKeyIndexer
from .helper import ListKeyIndexer
from ..base import BaseChunkIndexer as BCI


class AnnoyIndexer(BaseChunkIndexer):
class AnnoyIndexer(BCI):

def __init__(self, num_dim: int, data_path: str, metric: str = 'angular', n_trees=10, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_dim = num_dim
self.data_path = data_path
self.metric = metric
self.n_trees = n_trees
self._key_info_indexer = ListKeyIndexer()
self.helper_indexer = ListKeyIndexer()

def post_init(self):
from annoy import AnnoyIndex
Expand All @@ -44,8 +44,9 @@ def post_init(self):
except:
self.logger.warning('fail to load model from %s, will create an empty one' % self.data_path)

@BCI.update_helper_indexer
def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[float], *args, **kwargs):
last_idx = self._key_info_indexer.num_chunks
last_idx = self.helper_indexer.num_chunks

if len(vectors) != len(keys):
raise ValueError('vectors length should be equal to doc_ids')
Expand All @@ -56,31 +57,17 @@ def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[fl
for idx, vec in enumerate(vectors):
self._index.add_item(last_idx + idx, vec)

self._key_info_indexer.add(keys, weights)

def query(self, keys: 'np.ndarray', top_k: int, *args, **kwargs) -> List[List[Tuple]]:
self._index.build(self.n_trees)
if keys.dtype != np.float32:
raise ValueError('vectors should be ndarray of float32')
res = []
for k in keys:
ret, relevance_score = self._index.get_nns_by_vector(k, top_k, include_distances=True)
chunk_info = self._key_info_indexer.query(ret)
chunk_info = self.helper_indexer.query(ret)
res.append([(*r, s) for r, s in zip(chunk_info, relevance_score)])
return res

@property
def num_chunks(self):
return self._key_info_indexer.num_chunks

@property
def num_doc(self):
return self._key_info_indexer.num_doc

@property
def num_chunks_avg(self):
return self._key_info_indexer.num_chunks_avg

def __getstate__(self):
d = super().__getstate__()
self._index.save(self.data_path)
Expand Down
30 changes: 5 additions & 25 deletions gnes/indexer/chunk/bindexer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
import numpy as np

from .cython import IndexCore
from ...base import BaseChunkIndexer
from ..helper import ListKeyIndexer
from ...base import BaseChunkIndexer as BCI


class BIndexer(BaseChunkIndexer):
class BIndexer(BCI):

def __init__(self,
num_bytes: int = None,
Expand All @@ -39,7 +40,7 @@ def __init__(self,
self.insert_iterations = insert_iterations
self.query_iterations = query_iterations
self.data_path = data_path
self._all_docs = []
self.helper_indexer = ListKeyIndexer()

def post_init(self):
self.bindexer = IndexCore(self.num_bytes, 4, self.ef,
Expand All @@ -54,6 +55,7 @@ def post_init(self):
except (FileNotFoundError, IsADirectoryError):
self.logger.warning('fail to load model from %s, will create an empty one' % self.data_path)

@BCI.update_helper_indexer
def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[float], *args,
**kwargs):
if len(vectors) != len(keys):
Expand All @@ -62,7 +64,6 @@ def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[fl
if vectors.dtype != np.uint8:
raise ValueError('vectors should be ndarray of uint8')

self.update_counter(keys)
num_rows = len(keys)
keys, offsets = zip(*keys)
keys = np.array(keys, dtype=np.uint32).tobytes()
Expand Down Expand Up @@ -120,27 +121,6 @@ def query(self,
result[q].append((i, o, self.uint2float_weight(w), d))
return result

def update_counter(self, keys: List[Tuple[int, Any]], *args, **kwargs):
self._update_docs(keys)
self._num_chunks += len(keys)

def _update_docs(self, keys: List[Tuple[int, Any]]):
for key in keys:
if key[0] not in self._all_docs:
self._all_docs.append(key[0])

@property
def num_doc(self):
return len(self._all_docs)

@property
def num_chunks(self):
return self._num_chunks

@property
def num_chunks_avg(self):
return self._num_chunks / len(self._all_docs)

def __getstate__(self):
self.bindexer.save(self.data_path)
d = super().__getstate__()
Expand Down
24 changes: 6 additions & 18 deletions gnes/indexer/chunk/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@

import numpy as np

from ..base import BaseChunkIndexer
from ..key_only import ListKeyIndexer
from .helper import ListKeyIndexer
from ..base import BaseChunkIndexer as BCI


class FaissIndexer(BaseChunkIndexer):
class FaissIndexer(BCI):

def __init__(self, num_dim: int, index_key: str, data_path: str, *args, **kwargs):
super().__init__(*args, **kwargs)
self.data_path = data_path
self.num_dim = num_dim
self.index_key = index_key
self._key_info_indexer = ListKeyIndexer()
self.helper_indexer = ListKeyIndexer()

def post_init(self):
import faiss
Expand All @@ -44,14 +44,14 @@ def post_init(self):
self.logger.warning('fail to load model from %s, will init an empty one' % self.data_path)
self._faiss_index = faiss.index_factory(self.num_dim, self.index_key)

@BCI.update_helper_indexer
def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[float], *args, **kwargs):
if len(vectors) != len(keys):
raise ValueError("vectors length should be equal to doc_ids")

if vectors.dtype != np.float32:
raise ValueError("vectors should be ndarray of float32")

self._key_info_indexer.add(keys, weights)
self._faiss_index.add(vectors)

def query(self, keys: np.ndarray, top_k: int, *args, **kwargs) -> List[List[Tuple]]:
Expand All @@ -62,25 +62,13 @@ def query(self, keys: np.ndarray, top_k: int, *args, **kwargs) -> List[List[Tupl
ret = []
for _id, _score in zip(ids, score):
ret_i = []
chunk_info = self._key_info_indexer.query(_id)
chunk_info = self.helper_indexer.query(_id)
for c_info, _score_i in zip(chunk_info, _score):
ret_i.append((*c_info, _score_i))
ret.append(ret_i)

return ret

@property
def num_chunks(self):
return self._key_info_indexer.num_chunks

@property
def num_doc(self):
return self._key_info_indexer.num_doc

@property
def num_chunks_avg(self):
return self._key_info_indexer.num_chunks_avg

def __getstate__(self):
import faiss
d = super().__getstate__()
Expand Down
Loading

0 comments on commit 421e21f

Please sign in to comment.