This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Merge pull request #205 from gnes-ai/feat-sort-logic
feat(index): move sort logic out to base
mergify[bot] authored Sep 4, 2019
2 parents 62b3c4a + ac15f4d commit 15c1dd6
Showing 13 changed files with 129 additions and 56 deletions.
22 changes: 13 additions & 9 deletions gnes/cli/parser.py
@@ -185,6 +185,16 @@ def _set_loadable_service_parser(parser=None):
     return parser


+def _set_sortable_service_parser(parser=None):
+    if not parser:
+        parser = set_base_parser()
+    _set_loadable_service_parser(parser)
+
+    parser.add_argument('--sorted_response', action='store_true', default=False,
+                        help='sort the response (if exist) by the score')
+    return parser
+
+
 # shortcut to keep consistent
 set_encoder_parser = _set_loadable_service_parser
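
For context, a minimal sketch of how the new flag surfaces on the command line (assuming gnes is importable; the flag defaults to False and flips to True when passed):

    # hypothetical usage of the new --sorted_response flag
    from gnes.cli.parser import set_router_parser

    args = set_router_parser().parse_args(['--sorted_response'])
    assert args.sorted_response                                   # True when given
    assert not set_router_parser().parse_args([]).sorted_response # False by default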

@@ -200,25 +210,19 @@ def set_preprocessor_parser(parser=None):
 def set_router_parser(parser=None):
     if not parser:
         parser = set_base_parser()
-    _set_loadable_service_parser(parser)
+    _set_sortable_service_parser(parser)

     parser.add_argument('--num_part', type=int, default=None,
                         help='explicitly set the number of parts of message')
     parser.set_defaults(read_only=True)
     return parser


 def set_indexer_parser(parser=None):
     from ..service.base import SocketType

     if not parser:
         parser = set_base_parser()
-    _set_loadable_service_parser(parser)
+    _set_sortable_service_parser(parser)

     # encoder's port_out is indexer's port_in
     parser.set_defaults(port_in=parser.get_default('port_out'),
                         port_out=parser.get_default('port_out') + 2,
                         socket_in=SocketType.PULL_CONNECT,
                         socket_out=SocketType.PUB_BIND)
     return parser
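
A quick sanity check of the port wiring this sets up (a sketch; the concrete port numbers are whatever the default flags resolve to):

    from gnes.cli.parser import set_encoder_parser, set_indexer_parser

    idx_args = set_indexer_parser().parse_args([])
    enc_args = set_encoder_parser().parse_args([])
    # encoder's port_out is indexer's port_in; indexer's port_out sits two above
    assert idx_args.port_in == enc_args.port_out
    assert idx_args.port_out == idx_args.port_in + 2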


20 changes: 18 additions & 2 deletions gnes/indexer/base.py
@@ -24,10 +24,19 @@
 class BaseIndexer(TrainableBase):
     def __init__(self,
                  normalize_fn: 'BaseScoreFn' = ModifierScoreFn(),
-                 score_fn: 'BaseScoreFn' = ModifierScoreFn(), *args, **kwargs):
+                 score_fn: 'BaseScoreFn' = ModifierScoreFn(),
+                 is_big_score_similar: bool = False,
+                 *args, **kwargs):
+        """
+        Base indexer, a valid indexer must implement `add` and `query` methods
+        :type score_fn: advanced score function
+        :type normalize_fn: normalizing score function
+        :type is_big_score_similar: when set to true, then larger score means more similar
+        """
         super().__init__(*args, **kwargs)
         self.normalize_fn = normalize_fn
         self.score_fn = score_fn
+        self.is_big_score_similar = is_big_score_similar

     def add(self, keys: Any, docs: Any, weights: List[float], *args, **kwargs):
         pass
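
A toy illustration of what the new switch conveys (this subclass is hypothetical, not part of the codebase):

    from gnes.indexer.base import BaseIndexer

    class CosineChunkIndexer(BaseIndexer):
        """For cosine similarity, a larger score means more similar."""

    idx = CosineChunkIndexer(is_big_score_similar=True)
    # consumers can now read the sort direction off the indexer itself
    sort_descending = idx.is_big_score_similar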
@@ -59,7 +68,14 @@ def query_and_score(self, q_chunks: List['gnes_pb2.Chunk'], top_k: int, *args, **kwargs):
             r.chunk.doc_id = _doc_id
             r.chunk.offset = _offset
             r.chunk.weight = _weight
-            _score = get_unary_score(value=_relevance, name=self.__class__.__name__)
+            _score = get_unary_score(value=_relevance,
+                                     name=self.__class__.__name__,
+                                     operands=[
+                                         dict(name='doc_chunk',
+                                              doc_id=_doc_id,
+                                              offset=_offset),
+                                         dict(name='query_chunk',
+                                              offset=q_chunk.offset)])
             _score = self.normalize_fn(_score)
             _score = self.score_fn(_score, q_chunk, r.chunk)
             r.score.CopyFrom(_score)
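
Per the PR title, sorting now happens once at the base rather than in a dedicated router. A hedged sketch of what such base-level sorting can look like (the helper name and call site are assumptions, not taken from this diff):

    # hypothetical helper: sort scored results once, honoring the flag
    def sort_response(topk_results, is_big_score_similar: bool):
        # similarity-style scores sort descending, distance-style ascending
        return sorted(topk_results,
                      key=lambda r: r.score.value,
                      reverse=is_big_score_similar)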
19 changes: 13 additions & 6 deletions gnes/indexer/chunk/bindexer/__init__.py
@@ -38,9 +38,7 @@ def __init__(self,
         self.ef = ef
         self.insert_iterations = insert_iterations
         self.query_iterations = query_iterations
-
         self.data_path = data_path
-        self._weight_norm = 2 ** 16 - 1

     def post_init(self):
         self.bindexer = IndexCore(self.num_bytes, 4, self.ef,
@@ -67,9 +65,18 @@ def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[float], *args, **kwargs):
         keys, offsets = zip(*keys)
         keys = np.array(keys, dtype=np.uint32).tobytes()
         offsets = np.array(offsets, dtype=np.uint16).tobytes()
-        weights = np.array([w * self._weight_norm for w in weights], dtype=np.uint16).tobytes()
+        weights = self.float2uint_weight(weights).tobytes()
         self.bindexer.index_trie(vectors.tobytes(), num_rows, keys, offsets, weights)

+    @staticmethod
+    def float2uint_weight(weights: List[float], norm: int = 2 ** 16 - 1):
+        weights = norm * np.array(weights)
+        return np.array(weights, dtype=np.uint16)
+
+    @staticmethod
+    def uint2float_weight(weight: int, norm: int = 2 ** 16 - 1):
+        return weight / norm
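
The refactor gives the weight quantization an explicit name; a standalone round-trip mirroring the two static methods:

    import numpy as np

    norm = 2 ** 16 - 1
    q = np.array(norm * np.array([0.25, 1.0]), dtype=np.uint16)  # float in [0, 1] -> uint16
    print(q / norm)  # ~[0.25, 1.0], up to 1/65535 quantization error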

     def query(self,
               keys: np.ndarray,
               top_k: int,
@@ -91,15 +98,15 @@ def query(self,
         q_idx, doc_ids, offsets, weights = self.bindexer.find_batch_trie(
             keys, num_rows)
         for (i, q, o, w) in zip(doc_ids, q_idx, offsets, weights):
-            result[q].append((i, o, w / self._weight_norm, 1))
+            result[q].append((i, o, self.uint2float_weight(w), 0))

         # search the indexed items with similar value
         doc_ids, offsets, weights, dists, q_idx = self.bindexer.nsw_search(
             keys, num_rows, top_k)
         for (i, o, w, d, q) in zip(doc_ids, offsets, weights, dists, q_idx):
             if d == 0:
                 continue
-            result[q].append((i, o, w / self._weight_norm, d))
+            result[q].append((i, o, self.uint2float_weight(w), d))

         # get the top-k
         for q in range(num_rows):
@@ -108,7 +115,7 @@
         doc_ids, offsets, weights, dists, q_idx = self.bindexer.force_search(
             keys, num_rows, top_k)
         for (i, o, w, d, q) in zip(doc_ids, offsets, weights, dists, q_idx):
-            result[q].append((i, o, w / self._weight_norm, d))
+            result[q].append((i, o, self.uint2float_weight(w), d))
         return result

     def __getstate__(self):
14 changes: 11 additions & 3 deletions gnes/indexer/chunk/hbindexer/__init__.py
@@ -37,7 +37,6 @@ def __init__(self,
         self.n_clusters = num_clusters
         self.n_idx = n_idx
         self.data_path = data_path
-        self._weight_norm = 2 ** 16 - 1
         if self.n_idx <= 0:
             raise ValueError('There should be at least 1 clustering slot')
@@ -63,11 +62,20 @@ def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[float], *args, **kwargs):
         keys, offsets = zip(*keys)
         keys = np.array(keys, dtype=np.uint32).tobytes()
         offsets = np.array(offsets, dtype=np.uint16).tobytes()
-        weights = np.array(weights * self._weight_norm, dtype=np.uint16).tobytes()
+        weights = self.float2uint_weight(weights).tobytes()
         clusters = vectors[:, :self.n_idx].tobytes()
         vectors = vectors[:, self.n_idx:].astype(np.uint8).tobytes()
         self.hbindexer.index_trie(vectors, clusters, keys, offsets, weights, n)

+    @staticmethod
+    def float2uint_weight(weights: List[float], norm: int = 2 ** 16 - 1):
+        weights = norm * np.array(weights)
+        return np.array(weights, dtype=np.uint16)
+
+    @staticmethod
+    def uint2float_weight(weight: int, norm: int = 2 ** 16 - 1):
+        return weight / norm
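
Worth noting: the removed hbindexer line computed np.array(weights * self._weight_norm, ...) where weights is a Python list (per the signature), and list * int repeats the list rather than scaling it, so routing the conversion through np.array first also fixes a latent bug:

    >>> [0.5] * 3              # list * int repeats the list
    [0.5, 0.5, 0.5]
    >>> 3 * np.array([0.5])    # ndarray * int scales element-wise
    array([1.5])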

     def query(self,
               vectors: np.ndarray,
               top_k: int,
@@ -87,7 +95,7 @@
         doc_ids, offsets, weights, dists, q_idx = self.hbindexer.query(
             vectors, clusters, n, top_k * self.n_idx)
         for (i, o, w, d, q) in zip(doc_ids, offsets, weights, dists, q_idx):
-            result[q][(i, o, w / self._weight_norm)] = d
+            result[q][(i, o, self.uint2float_weight(w))] = d

         return [list(ret.items()) for ret in result]
2 changes: 1 addition & 1 deletion gnes/indexer/chunk/numpy.py
@@ -52,7 +52,7 @@ def query(self, keys: np.ndarray, top_k: int, *args, **kwargs
               ) -> List[List[Tuple]]:
         keys = np.expand_dims(keys, axis=1)
         dist = keys - np.expand_dims(self._vectors, axis=0)
-        score = 1 - np.sum(np.minimum(np.abs(dist), 1), -1) / self.num_bytes
+        score = np.sum(np.minimum(np.abs(dist), 1), -1) / self.num_bytes

         ret = []
         for ids in score:
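
Dropping the leading 1 - turns the value from a similarity into a normalized distance in [0, 1], which lines up with is_big_score_similar defaulting to False (smaller now means closer). A toy check of the new formula, with stand-ins for the indexer's state:

    import numpy as np

    num_bytes = 4
    keys = np.array([[0, 1, 2, 3]])                   # one query vector
    vectors = np.array([[0, 1, 2, 3], [9, 9, 9, 9]])  # two indexed items
    dist = np.expand_dims(keys, axis=1) - np.expand_dims(vectors, axis=0)
    score = np.sum(np.minimum(np.abs(dist), 1), -1) / num_bytes
    print(score)  # [[0. 1.]] -- the exact match scores 0, the far item 1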
3 changes: 2 additions & 1 deletion gnes/proto/gnes.proto
@@ -186,7 +186,8 @@ message Response {
         Status status = 1;
         uint32 top_k = 2;
         repeated ScoredResult topk_results = 3;
-
+        bool is_big_score_similar = 4;
+        bool is_sorted = 5;
     }
 }
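
Downstream, the two new fields let a consumer decide whether and how to (re-)sort. A hedged Python sketch against the generated classes (it assumes the nested query response is exposed as .search, which this diff does not show):

    from gnes.proto import gnes_pb2

    resp = gnes_pb2.Response()
    if not resp.search.is_sorted:
        ranked = sorted(resp.search.topk_results,
                        key=lambda r: r.score.value,
                        reverse=resp.search.is_big_score_similar)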

36 changes: 25 additions & 11 deletions gnes/proto/gnes_pb2.py

Some generated files are not rendered by default.

1 change: 0 additions & 1 deletion gnes/router/__init__.py
@@ -27,7 +27,6 @@
     'DocFillReducer': 'reduce',
     'PublishRouter': 'map',
     'DocBatchRouter': 'map',
-    'SortedTopkRouter': 'map',
 }

register_all_class(_cls2file_map, 'router')
(Diffs for the remaining 5 changed files are not shown.)
