
refactor(indexer): separate score logic and index logic
hanhxiao committed Aug 29, 2019
1 parent bae75b8 commit 42e7c13
Showing 3 changed files with 54 additions and 40 deletions.
44 changes: 43 additions & 1 deletion gnes/indexer/base.py
@@ -19,6 +19,7 @@
import numpy as np

from ..base import TrainableBase, CompositionalTrainableBase
from ..proto import gnes_pb2, blob2array


class BaseIndexer(TrainableBase):
@@ -32,6 +33,13 @@ def query(self, keys: Any, *args, **kwargs) -> List[Any]:
    def normalize_score(self, *args, **kwargs):
        pass

    def query_and_score(self, q_chunks: List[Union['gnes_pb2.Chunk', 'gnes_pb2.Document']], top_k: int) -> List[
            'gnes_pb2.Response.QueryResponse.ScoredResult']:
        raise NotImplementedError

    def score(self, *args, **kwargs) -> 'gnes_pb2.Response.QueryResponse.ScoredResult.Score':
        raise NotImplementedError


class BaseChunkIndexer(BaseIndexer):

@@ -41,15 +49,49 @@ def add(self, keys: List[Tuple[int, int]], vectors: np.ndarray, weights: List[fl
    def query(self, keys: np.ndarray, top_k: int, *args, **kwargs) -> List[List[Tuple]]:
        pass

    def query_and_score(self, q_chunks: List['gnes_pb2.Chunk'], top_k: int, *args, **kwargs) -> List[
            'gnes_pb2.Response.QueryResponse.ScoredResult']:
        vecs = [blob2array(c.embedding) for c in q_chunks]
        queried_results = self.query(np.concatenate(vecs, 0), top_k=top_k)
        results = []
        for q_chunk, topk_chunks in zip(q_chunks, queried_results):
            for _doc_id, _offset, _weight, _relevance in topk_chunks:
                r = gnes_pb2.Response.QueryResponse.ScoredResult()
                r.chunk.doc_id = _doc_id
                r.chunk.offset = _offset
                r.chunk.weight = _weight
                r.score.CopyFrom(self.score(q_chunk, r.chunk, _relevance))
                results.append(r)
        return results

    def score(self, q_chunk: 'gnes_pb2.Chunk', d_chunk: 'gnes_pb2.Chunk',
              relevance) -> 'gnes_pb2.Response.QueryResponse.ScoredResult.Score':
        raise NotImplementedError


class BaseDocIndexer(BaseIndexer):

    def add(self, keys: List[int], docs: Any, weights: List[float], *args, **kwargs):
        pass

    def query(self, keys: List[int], *args, **kwargs) -> List[Any]:
    def query(self, keys: List[int], *args, **kwargs) -> List['gnes_pb2.Document']:
        pass

    def query_and_score(self, keys: List[int], *args, **kwargs) -> List[
            'gnes_pb2.Response.QueryResponse.ScoredResult']:
        results = []
        queried_results = self.query(keys, *args, **kwargs)
        for d in queried_results:
            r = gnes_pb2.Response.QueryResponse.ScoredResult()
            if d:
                r.doc.CopyFrom(d)
                r.score.CopyFrom(self.score(d))
            results.append(r)
        return results

    def score(self, d: 'gnes_pb2.Document', *args, **kwargs) -> 'gnes_pb2.Response.QueryResponse.ScoredResult.Score':
        raise NotImplementedError


class BaseKeyIndexer(BaseIndexer):

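For context, a minimal sketch of how a concrete subclass could plug into the new hooks in BaseChunkIndexer above. DummyChunkIndexer, the brute-force cosine search, and the Score.value and Chunk.weight field names are illustrative assumptions, not part of this commit; only the add/query/score signatures mirror the base class, and query_and_score() is simply inherited from it.

```python
# Sketch only: an in-memory chunk indexer wired into the new score hook.
from typing import List, Tuple

import numpy as np

from gnes.indexer.base import BaseChunkIndexer
from gnes.proto import gnes_pb2


class DummyChunkIndexer(BaseChunkIndexer):
    """Store vectors in memory; score a hit by weight * weight * relevance."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._keys = []      # [(doc_id, offset, weight), ...]
        self._vecs = None    # (N, D) matrix of indexed embeddings

    def add(self, keys: List[Tuple[int, int]], vectors: np.ndarray,
            weights: List[float], *args, **kwargs):
        self._keys += [(d, o, w) for (d, o), w in zip(keys, weights)]
        self._vecs = vectors if self._vecs is None else np.vstack([self._vecs, vectors])

    def query(self, keys: np.ndarray, top_k: int, *args, **kwargs) -> List[List[Tuple]]:
        # cosine similarity of every query vector against every indexed vector
        q = keys / np.linalg.norm(keys, axis=1, keepdims=True)
        m = self._vecs / np.linalg.norm(self._vecs, axis=1, keepdims=True)
        sim = q @ m.T
        results = []
        for row in sim:
            idx = np.argsort(-row)[:top_k]
            results.append([(*self._keys[i], float(row[i])) for i in idx])
        return results  # [[(doc_id, offset, weight, relevance), ...], ...]

    def score(self, q_chunk, d_chunk, relevance):
        s = gnes_pb2.Response.QueryResponse.ScoredResult.Score()
        # assumed fields: Score.value and Chunk.weight
        s.value = q_chunk.weight * d_chunk.weight * relevance
        return s
```

With score() defined, the inherited query_and_score() takes care of running the vector search and packing ScoredResult messages, which is exactly the logic the service used to carry.
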
8 changes: 4 additions & 4 deletions gnes/service/encoder.py
@@ -28,8 +28,8 @@ def post_init(self):
        self._model = self.load_model(BaseEncoder)
        self.train_data = []

    def embed_chunks_from_docs(self, docs: Union[List['gnes_pb2.Document'], 'gnes_pb2.Document']):
        if getattr(docs, 'doc_type', None) is not None:
    def embed_chunks_in_docs(self, docs: Union[List['gnes_pb2.Document'], 'gnes_pb2.Document']):
        if not isinstance(docs, list):
            docs = [docs]

        contents = []
@@ -56,7 +56,7 @@ def embed_chunks_from_docs(self, docs: Union[List['gnes_pb2.Document'], 'gnes_pb

    @handler.register(gnes_pb2.Request.IndexRequest)
    def _handler_index(self, msg: 'gnes_pb2.Message'):
        self.embed_chunks_from_docs(msg.request.index.docs)
        self.embed_chunks_in_docs(msg.request.index.docs)

    @handler.register(gnes_pb2.Request.TrainRequest)
    def _handler_train(self, msg: 'gnes_pb2.Message'):
@@ -74,4 +74,4 @@ def _handler_train(self, msg: 'gnes_pb2.Message'):

    @handler.register(gnes_pb2.Request.QueryRequest)
    def _handler_search(self, msg: 'gnes_pb2.Message'):
        self.embed_chunks_from_docs(msg.request.search.query)
        self.embed_chunks_in_docs(msg.request.search.query)
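
A standalone illustration (not part of this commit) of the guard that replaces the old doc_type check: a bare gnes_pb2.Document is wrapped into a one-element list, while a list of documents passes through unchanged, so embed_chunks_in_docs can treat both inputs uniformly.

```python
# Illustration only: the single-vs-list normalisation used by embed_chunks_in_docs.
from gnes.proto import gnes_pb2

single = gnes_pb2.Document()
many = [gnes_pb2.Document(), gnes_pb2.Document()]

for docs in (single, many):
    if not isinstance(docs, list):  # same check as in the diff above
        docs = [docs]
    print(len(docs))                # 1 for the bare document, 2 for the list
```
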
42 changes: 7 additions & 35 deletions gnes/service/indexer.py
@@ -26,8 +26,6 @@ class IndexerService(BS):
    def post_init(self):
        from ..indexer.base import BaseIndexer
        self._model = self.load_model(BaseIndexer)
        from ..scorer.base import BaseScorer
        self._scorer = self.load_model(BaseScorer, yaml_path=self.args.scorer_yaml_path)

    @handler.register(gnes_pb2.Request.IndexRequest)
    def _handler_index(self, msg: 'gnes_pb2.Message'):
@@ -68,45 +66,19 @@ def _handler_chunk_search(self, msg: 'gnes_pb2.Message'):
            raise ServiceError(
                'unsupported indexer, dont know how to use %s to handle this message' % self._model.__bases__)

        from ..scorer.base import BaseChunkScorer
        if not isinstance(self._scorer, BaseChunkScorer):
            raise ServiceError(
                'unsupported scorer, dont know how to use %s to handle this message' % self._scorer.__bases__)

        vecs = [blob2array(c.embedding) for c in msg.request.search.query.chunks]
        topk = msg.request.search.top_k
        results = self._model.query(np.concatenate(vecs, 0), top_k=msg.request.search.top_k)

        for q_chunk, topk_chunks in zip(msg.request.search.query.chunks, results):
            for _doc_id, _offset, _weight, _relevance in topk_chunks:
                r = msg.response.search.topk_results.add()
                r.chunk.doc_id = _doc_id
                r.chunk.offset = _offset
                r.chunk.weight = _weight
                r.score = self._scorer.compute(q_chunk, r.chunk, _relevance)

        msg.response.search.top_k = topk
        results = self._model.query_and_score(msg.request.search.query.chunks, top_k=msg.request.search.top_k)
        msg.response.search.ClearField('topk_results')
        msg.response.search.topk_results.extend(results)
        msg.response.search.top_k = len(results)

    @handler.register(gnes_pb2.Response.QueryResponse)
    def _handler_doc_search(self, msg: 'gnes_pb2.Message'):
        # if msg.response.search.level != gnes_pb2.Response.QueryResponse.DOCUMENT_NOT_FILLED:
        # raise ServiceError('dont know how to handle QueryResponse at %s level' % msg.response.search.level)
        from ..indexer.base import BaseDocIndexer
        if not isinstance(self._model, BaseDocIndexer):
            raise ServiceError(
                'unsupported indexer, dont know how to use %s to handle this message' % self._model.__bases__)

        from ..scorer.base import BaseDocScorer
        if not isinstance(self._scorer, BaseDocScorer):
            raise ServiceError(
                'unsupported scorer, dont know how to use %s to handle this message' % self._scorer.__bases__)

        doc_ids = [r.doc.doc_id for r in msg.response.search.topk_results]
        docs = self._model.query(doc_ids)
        for r, d in zip(msg.response.search.topk_results, docs):
            if d is not None:
                # fill in the doc if this shard returns non-empty
                r.doc.CopyFrom(d)
                r.score = self._scorer.compute(d)

        # msg.response.search.level = gnes_pb2.Response.QueryResponse.DOCUMENT
        results = self._model.query_and_score(doc_ids)
        msg.response.search.ClearField('topk_results')
        msg.response.search.topk_results.extend(results)

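A companion sketch for the document side, with the same caveats: InMemoryDocIndexer is a made-up name, and the Document.weight and Score.value field names are assumptions. Its inherited query_and_score() is exactly what _handler_doc_search above now delegates to.

```python
# Sketch only: a dict-backed doc indexer that fills in the new doc-level score hook.
from typing import Any, List

from gnes.indexer.base import BaseDocIndexer
from gnes.proto import gnes_pb2


class InMemoryDocIndexer(BaseDocIndexer):
    """Keep full documents in a dict keyed by doc_id."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._docs = {}  # doc_id -> gnes_pb2.Document

    def add(self, keys: List[int], docs: Any, weights: List[float], *args, **kwargs):
        for k, d, w in zip(keys, docs, weights):
            d.weight = w          # assumed Document.weight field
            self._docs[k] = d

    def query(self, keys: List[int], *args, **kwargs) -> List['gnes_pb2.Document']:
        # missing ids come back as None so that query_and_score() can skip them
        return [self._docs.get(k) for k in keys]

    def score(self, d, *args, **kwargs):
        s = gnes_pb2.Response.QueryResponse.ScoredResult.Score()
        s.value = d.weight        # score a filled doc by its stored weight
        return s
```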