
feat(score): improve score explain for better interpretability
hanhxiao committed Aug 29, 2019
1 parent 42e7c13 commit 07534f8
Showing 8 changed files with 79 additions and 64 deletions.
69 changes: 60 additions & 9 deletions gnes/indexer/base.py
@@ -12,8 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
from typing import List, Any, Union, Callable, Tuple

import numpy as np
@@ -66,7 +65,7 @@ def query_and_score(self, q_chunks: List['gnes_pb2.Chunk'], top_k: int, *args, *

    def score(self, q_chunk: 'gnes_pb2.Chunk', d_chunk: 'gnes_pb2.Chunk',
              relevance) -> 'gnes_pb2.Response.QueryResponse.ScoredResult.Score':
-        raise NotImplementedError
+        return ChunkScorer.eq1(q_chunk, d_chunk, relevance)


class BaseDocIndexer(BaseIndexer):
@@ -77,20 +76,21 @@ def add(self, keys: List[int], docs: Any, weights: List[float], *args, **kwargs)
    def query(self, keys: List[int], *args, **kwargs) -> List['gnes_pb2.Document']:
        pass

-    def query_and_score(self, keys: List[int], *args, **kwargs) -> List[
+    def query_and_score(self, docs: List['gnes_pb2.Response.QueryResponse.ScoredResult'], *args, **kwargs) -> List[
        'gnes_pb2.Response.QueryResponse.ScoredResult']:
+        keys = [r.doc.doc_id for r in docs]
        results = []
        queried_results = self.query(keys, *args, **kwargs)
-        for d in queried_results:
-            r = gnes_pb2.Response.QueryResponse.ScoredResult()
+        for d, r in zip(queried_results, docs):
            if d:
                r.doc.CopyFrom(d)
-                r.score.CopyFrom(self.score(d))
+                r.score.CopyFrom(self.score(d, r.score))
            results.append(r)
        return results

-    def score(self, d: 'gnes_pb2.Document', *args, **kwargs) -> 'gnes_pb2.Response.QueryResponse.ScoredResult.Score':
-        raise NotImplementedError
+    def score(self, d: 'gnes_pb2.Document', s: 'gnes_pb2.Response.QueryResponse.ScoredResult.Score', *args,
+              **kwargs) -> 'gnes_pb2.Response.QueryResponse.ScoredResult.Score':
+        return DocScorer.eq1(d, s)
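With this change a concrete doc indexer only has to implement add() and query(); query_and_score() now receives the chunk-level ScoredResults, looks the documents up by doc_id, and lets DocScorer.eq1 fold each document's weight into the score. Below is a minimal, hypothetical subclass sketching the new contract — it is not part of this commit, and it assumes the usual gnes.proto.gnes_pb2 import path:

```python
from typing import Any, List

from gnes.indexer.base import BaseDocIndexer
from gnes.proto import gnes_pb2


class DictDocIndexer(BaseDocIndexer):
    """Toy in-memory doc indexer, for illustration only."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._docs = {}

    def add(self, keys: List[int], docs: Any, weights: List[float], *args, **kwargs):
        for k, d in zip(keys, docs):
            self._docs[k] = d

    def query(self, keys: List[int], *args, **kwargs) -> List['gnes_pb2.Document']:
        # missing keys yield None, which query_and_score() skips via `if d:`
        return [self._docs.get(k) for k in keys]


# topk = indexer.query_and_score(msg.response.search.topk_results)
# each returned ScoredResult keeps its chunk-level score, now wrapped in a 'doc-eq1' explain node
```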


class BaseKeyIndexer(BaseIndexer):
@@ -102,6 +102,57 @@ def query(self, keys: List[int], *args, **kwargs) -> List[Tuple[int, int, float]
pass


class ChunkScorer:

    @staticmethod
    def eq1(q_chunk: 'gnes_pb2.Chunk', d_chunk: 'gnes_pb2.Chunk',
            relevance):
        """
        score = d_chunk.weight * relevance * q_chunk.weight
        """
        score = gnes_pb2.Response.QueryResponse.ScoredResult.Score()
        score.value = d_chunk.weight * relevance * q_chunk.weight
        score.explained = json.dumps({
            'name': 'chunk-eq1',
            'operand': [{'name': 'd_chunk_weight',
                         'value': d_chunk.weight,
                         'doc_id': d_chunk.doc_id,
                         'offset': d_chunk.offset},
                        {'name': 'q_chunk_weight',
                         'value': q_chunk.weight,
                         'offset': q_chunk.offset},
                        {'name': 'relevance',
                         'value': relevance}],
            'op': 'prod',
            'value': score.value
        })
        return score
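For reference, the explained field that eq1 fills is a small JSON tree. A hypothetical instance (weights, doc_id and offsets made up for illustration), as returned by json.loads(score.explained):

```python
chunk_explain = {
    'name': 'chunk-eq1',
    'operand': [
        {'name': 'd_chunk_weight', 'value': 0.8, 'doc_id': 7, 'offset': 2},
        {'name': 'q_chunk_weight', 'value': 0.5, 'offset': 0},
        {'name': 'relevance', 'value': 0.9},
    ],
    'op': 'prod',
    'value': 0.36,  # 0.8 * 0.5 * 0.9
}
```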


class DocScorer:

    @staticmethod
    def eq1(d: 'gnes_pb2.Document',
            s: 'gnes_pb2.Response.QueryResponse.ScoredResult.Score') -> 'gnes_pb2.Response.QueryResponse.ScoredResult.Score':
        """
        score *= d.weight
        :param d:
        :param s:
        :return:
        """
        s.value *= d.weight
        s.explained = json.dumps({
            'name': 'doc-eq1',
            'operand': [json.loads(s.explained),
                        {'name': 'doc_weight',
                         'value': d.weight,
                         'doc_id': d.doc_id}],
            'op': 'prod',
            'value': s.value
        })
        return s
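DocScorer.eq1 then wraps whatever explain the incoming score already carries as one operand of a 'doc-eq1' node. Continuing the hypothetical chunk-level example above with d.weight = 0.7:

```python
doc_explain = {
    'name': 'doc-eq1',
    'operand': [
        chunk_explain,  # the nested upstream node ('chunk-eq1' or 'topk-reduce')
        {'name': 'doc_weight', 'value': 0.7, 'doc_id': 7},
    ],
    'op': 'prod',
    'value': 0.252,  # 0.36 * 0.7
}
```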


class JointIndexer(CompositionalTrainableBase):

@property
9 changes: 7 additions & 2 deletions gnes/router/base.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from collections import defaultdict
from functools import reduce
from typing import List, Generator
@@ -102,7 +102,12 @@ def apply(self, msg: 'gnes_pb2.Message', accum_msgs: List['gnes_pb2.Message'], *
        for k, v in sorted(score_dict.items(), key=lambda kv: kv[1]['reduced_value'] * (-1 if self.descending else 1)):
            r = msg.response.search.topk_results.add()
            r.score.value = v['reduced_value']
-            r.score.explained = ','.join('{%s}' % v['explains'])
+            r.score.explained = json.dumps({
+                'name': 'topk-reduce',
+                'op': self._reduce_op,
+                'operand': [json.loads(vv) for vv in v['explains']],
+                'value': r.score.value
+            })
            self.set_key(r, k)

        super().apply(msg, accum_msgs)
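Since every node now records its name, op, operands and value, the final score can be audited programmatically. A small hypothetical helper (not part of this commit) that re-evaluates an explain tree bottom-up, assuming the reduce op is one of 'prod', 'sum', 'max' or 'min':

```python
import json
import operator
from functools import reduce

_OPS = {
    'prod': lambda vals: reduce(operator.mul, vals, 1.0),
    'sum': sum,
    'max': max,
    'min': min,
}


def recompute(node: dict) -> float:
    """Re-evaluate an explain tree recursively; leaf operands just carry a value."""
    if 'operand' not in node:
        return node['value']
    return _OPS[node['op']]([recompute(c) for c in node['operand']])


# r = msg.response.search.topk_results[0]
# assert abs(recompute(json.loads(r.score.explained)) - r.score.value) < 1e-6
```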
10 changes: 2 additions & 8 deletions gnes/router/map.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from typing import Generator

from .base import BaseMapRouter
@@ -22,17 +21,12 @@


class SortedTopkRouter(BaseMapRouter):
-    def __init__(self, descending: bool = True, top_k: int = None, *args, **kwargs):
+    def __init__(self, descending: bool = True, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.descending = descending
-        self.top_k = top_k

    def apply(self, msg: 'gnes_pb2.Message', *args, **kwargs):
-        # resort all doc result as the doc_weight has been applied
-        final_docs = copy.deepcopy(
-            sorted(msg.response.search.topk_results, key=lambda x: x.score, reverse=self.descending))
-        msg.response.search.ClearField('topk_results')
-        msg.response.search.topk_results.extend(final_docs[:(self.top_k or msg.response.search.top_k)])
+        msg.response.search.topk_results.sort(key=lambda x: x.score.value, reverse=self.descending)
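The rewritten apply() leans on the fact that protobuf's repeated message fields support an in-place sort(), so the former copy.deepcopy / ClearField / extend round-trip is unnecessary. A quick sketch of that behaviour (field names taken from the proto used above, values arbitrary):

```python
from gnes.proto import gnes_pb2

resp = gnes_pb2.Response()
for v in (0.25, 0.75, 0.5):
    r = resp.search.topk_results.add()
    r.score.value = v

# sort the repeated field in place, highest score first
resp.search.topk_results.sort(key=lambda x: x.score.value, reverse=True)
print([r.score.value for r in resp.search.topk_results])  # [0.75, 0.5, 0.25]
```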


class PublishRouter(BaseMapRouter):
11 changes: 7 additions & 4 deletions gnes/router/reduce.py
@@ -27,16 +27,16 @@ def apply(self, msg: 'gnes_pb2.Message', accum_msgs: List['gnes_pb2.Message'], *
            # get result from all shards, some may return None, we only take the first non-None doc
            final_docs.append([m.response.search.topk_results[idx] for m in accum_msgs if
                               m.response.search.topk_results[idx].doc.WhichOneof('raw_data') is not None][0])

-        # resort all doc result as the doc_weight has been applied
-        final_docs = sorted(final_docs, key=lambda x: x.score, reverse=True)
        msg.response.search.ClearField('topk_results')
-        msg.response.search.topk_results.extend(final_docs[:msg.response.search.top_k])
+        msg.response.search.topk_results.extend(final_docs)

        super().apply(msg, accum_msgs)


class DocTopkReducer(BaseTopkReduceRouter):
+    """
+    Gather all chunks by their doc_id, result in a topk doc list
+    """
    def get_key(self, x: 'gnes_pb2.Response.QueryResponse.ScoredResult') -> str:
        return x.doc.doc_id

@@ -45,6 +45,9 @@ def set_key(self, x: 'gnes_pb2.Response.QueryResponse.ScoredResult', k: str):


class ChunkTopkReducer(BaseTopkReduceRouter):
+    """
+    Gather all chunks by their chunk_id, aka doc_id-offset, result in a topk chunk list
+    """
    def get_key(self, x: 'gnes_pb2.Response.QueryResponse.ScoredResult') -> str:
        return '%d-%d' % (x.chunk.doc_id, x.chunk.offset)

Empty file removed gnes/scorer/__init__.py
38 changes: 0 additions & 38 deletions gnes/scorer/base.py

This file was deleted.

3 changes: 2 additions & 1 deletion gnes/service/encoder.py
@@ -53,6 +53,7 @@ def embed_chunks_in_docs(self, docs: Union[List['gnes_pb2.Document'], 'gnes_pb2.
                'the first dimension must be the same' % (len(chunks), embeds.shape))
        for idx, c in enumerate(chunks):
            c.embedding.CopyFrom(array2blob(embeds[idx]))
+        return embeds

    @handler.register(gnes_pb2.Request.IndexRequest)
    def _handler_index(self, msg: 'gnes_pb2.Message'):
@@ -61,7 +62,7 @@ def _handler_index(self, msg: 'gnes_pb2.Message'):
    @handler.register(gnes_pb2.Request.TrainRequest)
    def _handler_train(self, msg: 'gnes_pb2.Message'):
        if msg.request.train.docs:
-            _, contents = self.get_chunks_from_docs(msg.request.train.docs)
+            _, contents = self.embed_chunks_in_docs(msg.request.train.docs)
            self.train_data.extend(contents)
            msg.response.train.status = gnes_pb2.Response.PENDING
            # raise BlockMessage
3 changes: 1 addition & 2 deletions gnes/service/indexer.py
@@ -78,7 +78,6 @@ def _handler_doc_search(self, msg: 'gnes_pb2.Message'):
            raise ServiceError(
                'unsupported indexer, dont know how to use %s to handle this message' % self._model.__bases__)

-        doc_ids = [r.doc.doc_id for r in msg.response.search.topk_results]
-        results = self._model.query_and_score(doc_ids)
+        results = self._model.query_and_score(msg.response.search.topk_results)
        msg.response.search.ClearField('topk_results')
        msg.response.search.topk_results.extend(results)
