From 5697441bca859a56d986b6c540ddb1fae0d3b258 Mon Sep 17 00:00:00 2001
From: Jem
Date: Thu, 25 Jul 2019 17:16:45 +0800
Subject: [PATCH] feat(indexer): consider offset relevance at query time
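
Chunk scores now factor in how close an indexed chunk's offset is to the
query chunk's offset. For image documents, whose chunks carry 2-d offsets
(offset_nd), the score is multiplied by a reciprocal Euclidean decay:
identical offsets contribute 1.0 and the contribution falls off smoothly
with distance. Text and video chunks keep a fixed offset relevance of 1,
so their scores are unchanged. A minimal standalone sketch of the decay
(illustration only, not code from this patch):

    import math

    def offset_relevance(q_offset, i_offset):
        # 1.0 when the offsets coincide, approaching 0 as they drift apart
        d = math.sqrt((q_offset[0] - i_offset[0]) ** 2 +
                      (q_offset[1] - i_offset[1]) ** 2)
        return 1 / (1 + d)

    print(offset_relevance([0, 0], [0, 0]))  # 1.0
    print(offset_relevance([0, 0], [3, 4]))  # 1 / (1 + 5) ~= 0.167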
---
 gnes/indexer/vector/annoy.py              |  8 +++---
 gnes/indexer/vector/bindexer/__init__.py  |  6 ++---
 gnes/indexer/vector/faiss.py              |  6 ++---
 gnes/indexer/vector/hbindexer/__init__.py |  6 ++---
 gnes/indexer/vector/numpy.py              |  4 +--
 gnes/service/indexer.py                   | 35 +++++++++++++++++++-----
 6 files changed, 44 insertions(+), 21 deletions(-)

diff --git a/gnes/indexer/vector/annoy.py b/gnes/indexer/vector/annoy.py
index 587ad067..1e97ffc1 100644
--- a/gnes/indexer/vector/annoy.py
+++ b/gnes/indexer/vector/annoy.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import os
-from typing import List, Tuple
+from typing import List, Tuple, Any
 
 import numpy as np
 
@@ -44,7 +44,7 @@ def post_init(self):
         except:
             self.logger.warning('fail to load model from %s, will create an empty one' % self.data_path)
 
-    def add(self, keys: List[Tuple[int, int]], vectors: np.ndarray, weights: List[float], *args, **kwargs):
+    def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[float], *args, **kwargs):
         last_idx = self._key_info_indexer.size
 
         if len(vectors) != len(keys):
@@ -70,7 +70,7 @@ def query(self, keys: 'np.ndarray', top_k: int, *args, **kwargs) -> List[List[Tu
             res.append([(*r, s) for r, s in zip(chunk_info, relevance_score)])
         return res
 
-    def normalize_score(self, score: List[float], metrics: str, *args) -> List[float]:
+    def normalize_score(self, score: List[float], metrics: str, *args, **kwargs) -> List[float]:
         if metrics == 'angular':
             return list(map(lambda x:1 / (1 + x), score))
         elif metrics == 'euclidean':
@@ -81,7 +81,7 @@ def normalize_score(self, score: List[float], metrics: str, *args) -> List[float
         elif metrics == 'hamming':
             return list(map(lambda x:1 / (1 + x), score))
         elif metrics == 'dot':
-            pass
+            raise NotImplementedError
 
     @property
     def size(self):
diff --git a/gnes/indexer/vector/bindexer/__init__.py b/gnes/indexer/vector/bindexer/__init__.py
index 26c47454..4e459648 100644
--- a/gnes/indexer/vector/bindexer/__init__.py
+++ b/gnes/indexer/vector/bindexer/__init__.py
@@ -16,7 +16,7 @@
 # pylint: disable=low-comment-ratio
 
 import os
-from typing import List, Tuple
+from typing import List, Tuple, Any
 
 import numpy as np
 
@@ -56,7 +56,7 @@ def post_init(self):
         except (FileNotFoundError, IsADirectoryError):
             self.logger.warning('fail to load model from %s, will create an empty one' % self.data_path)
 
-    def add(self, keys: List[Tuple[int, int]], vectors: np.ndarray, weights: List[float], *args,
+    def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[float], *args,
             **kwargs):
         if len(vectors) != len(keys):
             raise ValueError('vectors length should be equal to doc_ids')
@@ -112,7 +112,7 @@ def query(self,
                     result[q].append((i, o, w / self._weight_norm, self.normalize_score(d)))
         return result
 
-    def normalize_score(self, distance: int, *args) -> float:
+    def normalize_score(self, distance: int, *args, **kwargs) -> float:
         return 1. - distance / self.num_bytes
 
     def __getstate__(self):
diff --git a/gnes/indexer/vector/faiss.py b/gnes/indexer/vector/faiss.py
index 134dcb68..69b95079 100644
--- a/gnes/indexer/vector/faiss.py
+++ b/gnes/indexer/vector/faiss.py
@@ -17,7 +17,7 @@
 
 
 import os
-from typing import List, Tuple
+from typing import List, Tuple, Any
 
 import numpy as np
 
@@ -46,7 +46,7 @@ def post_init(self):
             self.logger.warning('fail to load model from %s, will init an empty one' % self.data_path)
             self._faiss_index = faiss.index_factory(self.num_dim, self.index_key)
 
-    def add(self, keys: List[Tuple[int, int]], vectors: np.ndarray, weights: List[float], *args, **kwargs):
+    def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[float], *args, **kwargs):
         if len(vectors) != len(keys):
             raise ValueError("vectors length should be equal to doc_ids")
 
@@ -72,7 +72,7 @@ def query(self, keys: np.ndarray, top_k: int, *args, **kwargs) -> List[List[Tupl
 
         return ret
 
-    def normalize_score(self, score: np.ndarray, *args) -> np.ndarray:
+    def normalize_score(self, score: np.ndarray, *args, **kwargs) -> np.ndarray:
         if 'HNSW' in self.index_key:
             return 1 / (1 + np.sqrt(score) / self.num_dim)
         elif 'PQ' or 'Flat' in self.index_key:
diff --git a/gnes/indexer/vector/hbindexer/__init__.py b/gnes/indexer/vector/hbindexer/__init__.py
index 73e6dac8..c4d3726d 100644
--- a/gnes/indexer/vector/hbindexer/__init__.py
+++ b/gnes/indexer/vector/hbindexer/__init__.py
@@ -16,7 +16,7 @@
 # pylint: disable=low-comment-ratio
 
 import os
-from typing import List, Tuple
+from typing import List, Tuple, Any
 
 import numpy as np
 
@@ -53,7 +53,7 @@ def post_init(self):
         except (FileNotFoundError, IsADirectoryError):
             self.logger.warning('fail to load model from %s, will create an empty one' % self.data_path)
 
-    def add(self, keys: List[Tuple[int, int]], vectors: np.ndarray, weights: List[float], *args, **kwargs):
+    def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[float], *args, **kwargs):
         if len(vectors) != len(keys):
             raise ValueError("vectors length should be equal to doc_ids")
 
@@ -92,7 +92,7 @@ def query(self,
 
         return [sorted(ret.items(), key=lambda x: -x[1])[:top_k] for ret in result]
 
-    def normalize_score(self, distance: int, *args) -> float:
+    def normalize_score(self, distance: int, *args, **kwargs) -> float:
         return 1. - distance / self.n_bytes * 8
 
     def __getstate__(self):
diff --git a/gnes/indexer/vector/numpy.py b/gnes/indexer/vector/numpy.py
index 38bbf1b8..b92c6326 100644
--- a/gnes/indexer/vector/numpy.py
+++ b/gnes/indexer/vector/numpy.py
@@ -15,7 +15,7 @@
 # pylint: disable=low-comment-ratio
 
 
-from typing import List, Tuple
+from typing import List, Tuple, Any
 
 import numpy as np
 
@@ -31,7 +31,7 @@ def __init__(self, num_bytes: int = None, *args, **kwargs):
         self._vectors = None  # type: np.ndarray
         self._key_info_indexer = ListKeyIndexer()
 
-    def add(self, keys: List[Tuple[int, int]], vectors: np.ndarray, weights: List[float], *args,
+    def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[float], *args,
             **kwargs):
         if len(vectors) % len(keys) != 0:
             raise ValueError('vectors bytes should be divided by doc_ids')
diff --git a/gnes/service/indexer.py b/gnes/service/indexer.py
index 1b5bd3da..2c134cb1 100644
--- a/gnes/service/indexer.py
+++ b/gnes/service/indexer.py
@@ -38,7 +38,12 @@ def _handler_index(self, msg: 'gnes_pb2.Message'):
         for d in msg.request.index.docs:
             all_vecs.append(blob2array(d.chunk_embeddings))
             doc_ids += [d.doc_id] * len(d.chunks)
-            offsets += [c.offset_1d for c in d.chunks]
+            if d.doc_type == 'TEXT':
+                offsets += [c.offset_1d for c in d.chunks]
+            elif d.doc_type == 'IMAGE':
+                offsets += [list(c.offset_nd) for c in d.chunks]
+            elif d.doc_type == 'VIDEO':
+                offsets += [c.offset_1d for c in d.chunks]
             weights += [c.weight for c in d.chunks]
 
         from ..indexer.base import BaseVectorIndexer, BaseTextIndexer
@@ -56,22 +61,40 @@ def _handler_index(self, msg: 'gnes_pb2.Message'):
 
     @handler.register(gnes_pb2.Request.QueryRequest)
     def _handler_chunk_search(self, msg: 'gnes_pb2.Message'):
+        def _cal_offset_relevance(q_offset, i_offset):
+            import math
+            if not (isinstance(q_offset, (list, tuple)) and isinstance(i_offset, (list, tuple))):
+                raise TypeError("q_offset and i_offset are supposed to be (list, list), "
+                                "but actually we got (%s, %s)" % (str(type(q_offset)), str(type(i_offset))))
+            if not (len(q_offset) == 2 and len(i_offset) == 2):
+                raise ValueError("lengths of q_offset and i_offset should be (2, 2), "
+                                 "but actually we got (%d, %d)" % (len(q_offset), len(i_offset)))
+            return 1 / (1 + math.sqrt((q_offset[0] - i_offset[0]) ** 2 + (q_offset[1] - i_offset[1]) ** 2))
+
         vecs = blob2array(msg.request.search.query.chunk_embeddings)
+        q_offset = [list(c.offset_nd) if msg.request.search.query.doc_type == 'IMAGE'
+                    else c.offset_1d for c in msg.request.search.query.chunks]
         topk = msg.request.search.top_k
         results = self._model.query(vecs, top_k=msg.request.search.top_k)
         q_weights = [qc.weight for qc in msg.request.search.query.chunks]
-        for all_topks, qc_weight in zip(results, q_weights):
+        for all_topks, qc_weight, qc_offset in zip(results, q_weights, q_offset):
             for _doc_id, _offset, _weight, _relevance in all_topks:
                 r = msg.response.search.topk_results.add()
                 r.chunk.doc_id = _doc_id
-                r.chunk.offset_1d = _offset
                 r.chunk.weight = _weight
-                r.score = _weight * qc_weight * _relevance
-                r.score_explained = '[chunk_score at doc: %d, offset: %d] = ' \
+                if msg.request.search.query.doc_type == 'IMAGE':
+                    r.chunk.offset_nd.extend(_offset)
+                    offset_relevance = _cal_offset_relevance(qc_offset, _offset)
+                else:
+                    r.chunk.offset_1d = _offset
+                    offset_relevance = 1
+                r.score = _weight * qc_weight * _relevance * offset_relevance
+                r.score_explained = '[chunk_score at doc: %d, offset: %s] = ' \
                                     '(doc_chunk_weight: %.6f) * ' \
                                     '(query_doc_chunk_relevance: %.6f) * ' \
+                                    '(query_doc_offset_relevance: %.6f) * ' \
                                     '(query_chunk_weight: %.6f)' % (
-                    _doc_id, _offset, _weight, _relevance, qc_weight)
+                    _doc_id, _offset, _weight, _relevance, offset_relevance, qc_weight)
         msg.response.search.top_k = topk
 
     @handler.register(gnes_pb2.Response.QueryResponse)
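
A quick, self-contained sketch of how the combined chunk score behaves after
this change; every number below is made up purely for illustration:

    # hypothetical values, for illustration only
    doc_chunk_weight = 0.8            # _weight, stored with the indexed chunk
    query_chunk_weight = 0.5          # qc_weight, from the query document
    embedding_relevance = 0.9         # _relevance, from the vector indexer
    offset_relevance = 1 / (1 + 5.0)  # query offset [0, 0] vs. result offset [3, 4]

    score = (doc_chunk_weight * query_chunk_weight
             * embedding_relevance * offset_relevance)
    print(round(score, 3))  # 0.06

For text and video queries offset_relevance stays 1, so the product reduces to
the old three-factor score.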