Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
feat(indexer): consider offset relevance at query time
Browse files Browse the repository at this point in the history
  • Loading branch information
jemmyshin committed Jul 25, 2019
1 parent c0bffe6 commit 5697441
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 20 deletions.
8 changes: 4 additions & 4 deletions gnes/indexer/vector/annoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

import os
from typing import List, Tuple
from typing import List, Tuple, Any

import numpy as np

Expand Down Expand Up @@ -44,7 +44,7 @@ def post_init(self):
except:
self.logger.warning('fail to load model from %s, will create an empty one' % self.data_path)

def add(self, keys: List[Tuple[int, int]], vectors: np.ndarray, weights: List[float], *args, **kwargs):
def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[float], *args, **kwargs):
last_idx = self._key_info_indexer.size

if len(vectors) != len(keys):
Expand All @@ -70,7 +70,7 @@ def query(self, keys: 'np.ndarray', top_k: int, *args, **kwargs) -> List[List[Tu
res.append([(*r, s) for r, s in zip(chunk_info, relevance_score)])
return res

def normalize_score(self, score: List[float], metrics: str, *args) -> List[float]:
def normalize_score(self, score: List[float], metrics: str, *args, **kwargs) -> List[float]:
if metrics == 'angular':
return list(map(lambda x:1 / (1 + x), score))
elif metrics == 'euclidean':
Expand All @@ -81,7 +81,7 @@ def normalize_score(self, score: List[float], metrics: str, *args) -> List[float
elif metrics == 'hamming':
return list(map(lambda x:1 / (1 + x), score))
elif metrics == 'dot':
pass
raise NotImplementedError

@property
def size(self):
Expand Down
6 changes: 3 additions & 3 deletions gnes/indexer/vector/bindexer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# pylint: disable=low-comment-ratio

import os
from typing import List, Tuple
from typing import List, Tuple, Any

import numpy as np

Expand Down Expand Up @@ -56,7 +56,7 @@ def post_init(self):
except (FileNotFoundError, IsADirectoryError):
self.logger.warning('fail to load model from %s, will create an empty one' % self.data_path)

def add(self, keys: List[Tuple[int, int]], vectors: np.ndarray, weights: List[float], *args,
def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[float], *args,
**kwargs):
if len(vectors) != len(keys):
raise ValueError('vectors length should be equal to doc_ids')
Expand Down Expand Up @@ -112,7 +112,7 @@ def query(self,
result[q].append((i, o, w / self._weight_norm, self.normalize_score(d)))
return result

def normalize_score(self, distance: int, *args) -> float:
def normalize_score(self, distance: int, *args, **kwargs) -> float:
return 1. - distance / self.num_bytes

def __getstate__(self):
Expand Down
6 changes: 3 additions & 3 deletions gnes/indexer/vector/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


import os
from typing import List, Tuple
from typing import List, Tuple, Any

import numpy as np

Expand Down Expand Up @@ -46,7 +46,7 @@ def post_init(self):
self.logger.warning('fail to load model from %s, will init an empty one' % self.data_path)
self._faiss_index = faiss.index_factory(self.num_dim, self.index_key)

def add(self, keys: List[Tuple[int, int]], vectors: np.ndarray, weights: List[float], *args, **kwargs):
def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[float], *args, **kwargs):
if len(vectors) != len(keys):
raise ValueError("vectors length should be equal to doc_ids")

Expand All @@ -72,7 +72,7 @@ def query(self, keys: np.ndarray, top_k: int, *args, **kwargs) -> List[List[Tupl

return ret

def normalize_score(self, score: np.ndarray, *args) -> np.ndarray:
def normalize_score(self, score: np.ndarray, *args, **kwargs) -> np.ndarray:
if 'HNSW' in self.index_key:
return 1 / (1 + np.sqrt(score) / self.num_dim)
elif 'PQ' or 'Flat' in self.index_key:
Expand Down
6 changes: 3 additions & 3 deletions gnes/indexer/vector/hbindexer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# pylint: disable=low-comment-ratio

import os
from typing import List, Tuple
from typing import List, Tuple, Any

import numpy as np

Expand Down Expand Up @@ -53,7 +53,7 @@ def post_init(self):
except (FileNotFoundError, IsADirectoryError):
self.logger.warning('fail to load model from %s, will create an empty one' % self.data_path)

def add(self, keys: List[Tuple[int, int]], vectors: np.ndarray, weights: List[float], *args, **kwargs):
def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[float], *args, **kwargs):
if len(vectors) != len(keys):
raise ValueError("vectors length should be equal to doc_ids")

Expand Down Expand Up @@ -92,7 +92,7 @@ def query(self,

return [sorted(ret.items(), key=lambda x: -x[1])[:top_k] for ret in result]

def normalize_score(self, distance: int, *args) -> float:
def normalize_score(self, distance: int, *args, **kwargs) -> float:
return 1. - distance / self.n_bytes * 8

def __getstate__(self):
Expand Down
4 changes: 2 additions & 2 deletions gnes/indexer/vector/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# pylint: disable=low-comment-ratio

from typing import List, Tuple
from typing import List, Tuple, Any

import numpy as np

Expand All @@ -31,7 +31,7 @@ def __init__(self, num_bytes: int = None, *args, **kwargs):
self._vectors = None # type: np.ndarray
self._key_info_indexer = ListKeyIndexer()

def add(self, keys: List[Tuple[int, int]], vectors: np.ndarray, weights: List[float], *args,
def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[float], *args,
**kwargs):
if len(vectors) % len(keys) != 0:
raise ValueError('vectors bytes should be divided by doc_ids')
Expand Down
33 changes: 28 additions & 5 deletions gnes/service/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@ def _handler_index(self, msg: 'gnes_pb2.Message'):
for d in msg.request.index.docs:
all_vecs.append(blob2array(d.chunk_embeddings))
doc_ids += [d.doc_id] * len(d.chunks)
offsets += [c.offset_1d for c in d.chunks]
if msg.request.index.docs.doc_type == 'TEXT':
offsets += [c.offset_1d for c in d.chunks]
elif msg.request.index.docs.doc_type == 'IMAGE':
offsets += [c.offset_nd for c in d.chunks]
elif msg.request.index.docs.doc_type == 'VIDEO':
offsets += [c.offset_1d for c in d.chunks]
weights += [c.weight for c in d.chunks]

from ..indexer.base import BaseVectorIndexer, BaseTextIndexer
Expand All @@ -56,22 +61,40 @@ def _handler_index(self, msg: 'gnes_pb2.Message'):

@handler.register(gnes_pb2.Request.QueryRequest)
def _handler_chunk_search(self, msg: 'gnes_pb2.Message'):
def _cal_offset_relevance(q_offset, i_offset):
import math
if not isinstance(q_offset, list) and isinstance(i_offset, list):
raise TypeError("Type of qc_offset and offset is supposed to be (list, list), "
"but actually we got (%s, %s)" % (str(type(q_offset)), str(type(i_offset))))
if not len(q_offset) == 2 and len(i_offset) == 2:
raise ValueError("Length of qc_offset and offset should be (2, 2), "
"but actually we got (%d, %d)" % (len(q_offset), len(i_offset)))
return 1 / (1 + math.sqrt((q_offset[0] - i_offset[0])**2 + (q_offset[1] - i_offset[1])**2))

vecs = blob2array(msg.request.search.query.chunk_embeddings)
q_offset = [c.offset_nd if msg.request.search.query.doc_type == 'IMAGE'
else c.offset_1d for c in msg.request.search.query.chunks]
topk = msg.request.search.top_k
results = self._model.query(vecs, top_k=msg.request.search.top_k)
q_weights = [qc.weight for qc in msg.request.search.query.chunks]
for all_topks, qc_weight in zip(results, q_weights):
for all_topks, qc_weight, qc_offset in zip(results, q_weights, q_offset):
for _doc_id, _offset, _weight, _relevance in all_topks:
r = msg.response.search.topk_results.add()
r.chunk.doc_id = _doc_id
r.chunk.offset_1d = _offset
r.chunk.weight = _weight
r.score = _weight * qc_weight * _relevance
if msg.request.search.query.doc_type == 'IMAGE':
r.chunk.offset_nd = _offset
offset_relevance = _cal_offset_relevance(qc_offset, _offset)
else:
r.chunk.offset_1d = _offset
offset_relevance = 1
r.score = _weight * qc_weight * _relevance * offset_relevance
r.score_explained = '[chunk_score at doc: %d, offset: %d] = ' \
'(doc_chunk_weight: %.6f) * ' \
'(query_doc_chunk_relevance: %.6f) * ' \
'(query_doc_offset_relevance: %.6f) * ' \
'(query_chunk_weight: %.6f)' % (
_doc_id, _offset, _weight, _relevance, qc_weight)
_doc_id, _offset, _weight, _relevance, offset_relevance, qc_weight)
msg.response.search.top_k = topk

@handler.register(gnes_pb2.Response.QueryResponse)
Expand Down

0 comments on commit 5697441

Please sign in to comment.