From a2d55dda809d5ad82635cc599862b5e5d6f2f186 Mon Sep 17 00:00:00 2001 From: hanhxiao Date: Wed, 4 Sep 2019 17:35:37 +0800 Subject: [PATCH 1/3] feat(index): move sort logic out to base --- gnes/cli/parser.py | 5 ++++- gnes/indexer/base.py | 20 +++++++++++++++++-- gnes/indexer/chunk/bindexer/__init__.py | 19 ++++++++++++------ gnes/indexer/chunk/hbindexer/__init__.py | 14 ++++++++++--- gnes/indexer/chunk/numpy.py | 2 +- gnes/proto/gnes.proto | 3 ++- gnes/router/__init__.py | 1 - gnes/router/base.py | 9 +++------ gnes/router/map.py | 9 --------- gnes/router/reduce.py | 13 ++++++++++-- gnes/service/base.py | 12 ++++++++++++ gnes/service/indexer.py | 25 +++++++++++++++++++----- 12 files changed, 95 insertions(+), 37 deletions(-) diff --git a/gnes/cli/parser.py b/gnes/cli/parser.py index abc2a1dc..f3f6a5f2 100644 --- a/gnes/cli/parser.py +++ b/gnes/cli/parser.py @@ -203,6 +203,8 @@ def set_router_parser(parser=None): _set_loadable_service_parser(parser) parser.add_argument('--num_part', type=int, default=None, help='explicitly set the number of parts of message') + parser.add_argument('--sort_response', type=bool, default=True, + help='sort the response (if exist) by the score') parser.set_defaults(read_only=True) return parser @@ -213,7 +215,8 @@ def set_indexer_parser(parser=None): if not parser: parser = set_base_parser() _set_loadable_service_parser(parser) - + parser.add_argument('--sort_response', type=bool, default=True, + help='sort the response (if exist) by the score') # encoder's port_out is indexer's port_in parser.set_defaults(port_in=parser.get_default('port_out'), port_out=parser.get_default('port_out') + 2, diff --git a/gnes/indexer/base.py b/gnes/indexer/base.py index c00760e5..b1625dd0 100644 --- a/gnes/indexer/base.py +++ b/gnes/indexer/base.py @@ -24,10 +24,19 @@ class BaseIndexer(TrainableBase): def __init__(self, normalize_fn: 'BaseScoreFn' = ModifierScoreFn(), - score_fn: 'BaseScoreFn' = ModifierScoreFn(), *args, **kwargs): + score_fn: 'BaseScoreFn' = ModifierScoreFn(), + is_big_score_similar: bool = False, + *args, **kwargs): + """ + Base indexer, a valid indexer must implement `add` and `query` methods + :type score_fn: advanced score function + :type normalize_fn: normalizing score function + :type is_big_score_similar: when set to true, then larger score means more similar + """ super().__init__(*args, **kwargs) self.normalize_fn = normalize_fn self.score_fn = score_fn + self.is_big_score_similar = is_big_score_similar def add(self, keys: Any, docs: Any, weights: List[float], *args, **kwargs): pass @@ -59,7 +68,14 @@ def query_and_score(self, q_chunks: List['gnes_pb2.Chunk'], top_k: int, *args, * r.chunk.doc_id = _doc_id r.chunk.offset = _offset r.chunk.weight = _weight - _score = get_unary_score(value=_relevance, name=self.__class__.__name__) + _score = get_unary_score(value=_relevance, + name=self.__class__.__name__, + operands=[ + dict(name='doc_chunk', + doc_id=_doc_id, + offset=_offset), + dict(name='query_chunk', + offset=q_chunk.offset)]) _score = self.normalize_fn(_score) _score = self.score_fn(_score, q_chunk, r.chunk) r.score.CopyFrom(_score) diff --git a/gnes/indexer/chunk/bindexer/__init__.py b/gnes/indexer/chunk/bindexer/__init__.py index 9b25d985..5cde707d 100644 --- a/gnes/indexer/chunk/bindexer/__init__.py +++ b/gnes/indexer/chunk/bindexer/__init__.py @@ -38,9 +38,7 @@ def __init__(self, self.ef = ef self.insert_iterations = insert_iterations self.query_iterations = query_iterations - self.data_path = data_path - self._weight_norm = 2 ** 16 - 1 def post_init(self): self.bindexer = IndexCore(self.num_bytes, 4, self.ef, @@ -67,9 +65,18 @@ def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[fl keys, offsets = zip(*keys) keys = np.array(keys, dtype=np.uint32).tobytes() offsets = np.array(offsets, dtype=np.uint16).tobytes() - weights = np.array([w * self._weight_norm for w in weights], dtype=np.uint16).tobytes() + weights = self.float2uint_weight(weights).tobytes() self.bindexer.index_trie(vectors.tobytes(), num_rows, keys, offsets, weights) + @staticmethod + def float2uint_weight(weights: List[float], norm: int = 2 ** 16 - 1): + weights = norm * np.array(weights) + return np.array(weights, dtype=np.uint16) + + @staticmethod + def uint2float_weight(weight: int, norm: int = 2 ** 16 - 1): + return weight / norm + def query(self, keys: np.ndarray, top_k: int, @@ -91,7 +98,7 @@ def query(self, q_idx, doc_ids, offsets, weights = self.bindexer.find_batch_trie( keys, num_rows) for (i, q, o, w) in zip(doc_ids, q_idx, offsets, weights): - result[q].append((i, o, w / self._weight_norm, 1)) + result[q].append((i, o, self.uint2float_weight(w), 0)) # search the indexed items with similar value doc_ids, offsets, weights, dists, q_idx = self.bindexer.nsw_search( @@ -99,7 +106,7 @@ def query(self, for (i, o, w, d, q) in zip(doc_ids, offsets, weights, dists, q_idx): if d == 0: continue - result[q].append((i, o, w / self._weight_norm, d)) + result[q].append((i, o, self.uint2float_weight(w), d)) # get the top-k for q in range(num_rows): @@ -108,7 +115,7 @@ def query(self, doc_ids, offsets, weights, dists, q_idx = self.bindexer.force_search( keys, num_rows, top_k) for (i, o, w, d, q) in zip(doc_ids, offsets, weights, dists, q_idx): - result[q].append((i, o, w / self._weight_norm, d)) + result[q].append((i, o, self.uint2float_weight(w), d)) return result def __getstate__(self): diff --git a/gnes/indexer/chunk/hbindexer/__init__.py b/gnes/indexer/chunk/hbindexer/__init__.py index 377f6c3f..23b95e79 100644 --- a/gnes/indexer/chunk/hbindexer/__init__.py +++ b/gnes/indexer/chunk/hbindexer/__init__.py @@ -37,7 +37,6 @@ def __init__(self, self.n_clusters = num_clusters self.n_idx = n_idx self.data_path = data_path - self._weight_norm = 2 ** 16 - 1 if self.n_idx <= 0: raise ValueError('There should be at least 1 clustering slot') @@ -63,11 +62,20 @@ def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[fl keys, offsets = zip(*keys) keys = np.array(keys, dtype=np.uint32).tobytes() offsets = np.array(offsets, dtype=np.uint16).tobytes() - weights = np.array(weights * self._weight_norm, dtype=np.uint16).tobytes() + weights = self.float2uint_weight(weights).tobytes() clusters = vectors[:, :self.n_idx].tobytes() vectors = vectors[:, self.n_idx:].astype(np.uint8).tobytes() self.hbindexer.index_trie(vectors, clusters, keys, offsets, weights, n) + @staticmethod + def float2uint_weight(weights: List[float], norm: int = 2 ** 16 - 1): + weights = norm * np.array(weights) + return np.array(weights, dtype=np.uint16) + + @staticmethod + def uint2float_weight(weight: int, norm: int = 2 ** 16 - 1): + return weight / norm + def query(self, vectors: np.ndarray, top_k: int, @@ -87,7 +95,7 @@ def query(self, doc_ids, offsets, weights, dists, q_idx = self.hbindexer.query( vectors, clusters, n, top_k * self.n_idx) for (i, o, w, d, q) in zip(doc_ids, offsets, weights, dists, q_idx): - result[q][(i, o, w / self._weight_norm)] = d + result[q][(i, o, self.uint2float_weight(w))] = d return [list(ret.items()) for ret in result] diff --git a/gnes/indexer/chunk/numpy.py b/gnes/indexer/chunk/numpy.py index 72c07d3f..e700e72b 100644 --- a/gnes/indexer/chunk/numpy.py +++ b/gnes/indexer/chunk/numpy.py @@ -52,7 +52,7 @@ def query(self, keys: np.ndarray, top_k: int, *args, **kwargs ) -> List[List[Tuple]]: keys = np.expand_dims(keys, axis=1) dist = keys - np.expand_dims(self._vectors, axis=0) - score = 1 - np.sum(np.minimum(np.abs(dist), 1), -1) / self.num_bytes + score = np.sum(np.minimum(np.abs(dist), 1), -1) / self.num_bytes ret = [] for ids in score: diff --git a/gnes/proto/gnes.proto b/gnes/proto/gnes.proto index 86d77247..5758ec0b 100644 --- a/gnes/proto/gnes.proto +++ b/gnes/proto/gnes.proto @@ -186,7 +186,8 @@ message Response { Status status = 1; uint32 top_k = 2; repeated ScoredResult topk_results = 3; - + bool is_big_score_similar = 4; + bool is_sorted = 5; } } diff --git a/gnes/router/__init__.py b/gnes/router/__init__.py index d9e3d687..75de2cc1 100644 --- a/gnes/router/__init__.py +++ b/gnes/router/__init__.py @@ -27,7 +27,6 @@ 'DocFillReducer': 'reduce', 'PublishRouter': 'map', 'DocBatchRouter': 'map', - 'SortedTopkRouter': 'map', } register_all_class(_cls2file_map, 'router') diff --git a/gnes/router/base.py b/gnes/router/base.py index a9a84827..a9967595 100644 --- a/gnes/router/base.py +++ b/gnes/router/base.py @@ -59,10 +59,9 @@ def apply(self, msg: 'gnes_pb2.Message', accum_msgs: List['gnes_pb2.Message'], * class BaseTopkReduceRouter(BaseReduceRouter): - def __init__(self, reduce_op: str = 'sum', descending: bool = True, *args, **kwargs): + def __init__(self, reduce_op: str = 'sum', *args, **kwargs): super().__init__(*args, **kwargs) self._reduce_op = reduce_op - self.descending = descending def post_init(self): self.reduce_op = CombinedScoreFn(score_mode=self._reduce_op) @@ -80,16 +79,14 @@ def apply(self, msg: 'gnes_pb2.Message', accum_msgs: List['gnes_pb2.Message'], * # count score by iterating over chunks for c in all_scored_results: - k = self.get_key(c) - score_dict[k].append(c.score) + score_dict[self.get_key(c)].append(c.score) for k, v in score_dict.items(): score_dict[k] = self.reduce_op(*v) msg.response.search.ClearField('topk_results') - # sort and add docs - for k, v in sorted(score_dict.items(), key=lambda kv: kv[1].value, reverse=self.descending): + for k, v in score_dict.items(): r = msg.response.search.topk_results.add() r.score.CopyFrom(v) self.set_key(r, k) diff --git a/gnes/router/map.py b/gnes/router/map.py index b2216ca1..af2c7b16 100644 --- a/gnes/router/map.py +++ b/gnes/router/map.py @@ -20,15 +20,6 @@ from ..proto import gnes_pb2 -class SortedTopkRouter(BaseMapRouter): - def __init__(self, descending: bool = True, *args, **kwargs): - super().__init__(*args, **kwargs) - self.descending = descending - - def apply(self, msg: 'gnes_pb2.Message', *args, **kwargs): - msg.response.search.topk_results.sort(key=lambda x: x.score.value, reverse=self.descending) - - class PublishRouter(BaseMapRouter): def __init__(self, num_part: int, *args, **kwargs): diff --git a/gnes/router/reduce.py b/gnes/router/reduce.py index ea2dd115..65eccc49 100644 --- a/gnes/router/reduce.py +++ b/gnes/router/reduce.py @@ -19,6 +19,13 @@ class DocFillReducer(BaseReduceRouter): + """ + Gather all documents raw content from multiple shards. + This is only useful when you have + - multiple doc-indexer and docs are spreaded over multiple shards. + - require full-doc retrieval with the original content, not just an doc id + Ideally, only each doc can only belong to one shard. + """ def apply(self, msg: 'gnes_pb2.Message', accum_msgs: List['gnes_pb2.Message'], *args, **kwargs): final_docs = [] for idx in range(len(msg.response.search.topk_results)): @@ -45,7 +52,9 @@ def set_key(self, x: 'gnes_pb2.Response.QueryResponse.ScoredResult', k: str): class Chunk2DocTopkReducer(BaseTopkReduceRouter): """ - Gather all chunks by their doc_id, result in a topk doc list + Gather all chunks by their doc_id, result in a topk doc list. + This is almost always useful, as the final result should be group by doc_id + not chunk """ def get_key(self, x: 'gnes_pb2.Response.QueryResponse.ScoredResult') -> str: @@ -57,7 +66,7 @@ def set_key(self, x: 'gnes_pb2.Response.QueryResponse.ScoredResult', k: str): class ChunkTopkReducer(BaseTopkReduceRouter): """ - Gather all chunks by their chunk_id, aka doc_id-offset, result in a topk chunk list + Gather all chunks by their chunk_id from all shards, aka doc_id-offset, result in a topk chunk list """ def get_key(self, x: 'gnes_pb2.Response.QueryResponse.ScoredResult') -> str: diff --git a/gnes/service/base.py b/gnes/service/base.py index b649d482..1ee5d259 100644 --- a/gnes/service/base.py +++ b/gnes/service/base.py @@ -257,6 +257,16 @@ def dump(self): else: self.logger.info('no dumping as "read_only" set to true.') + def post_handler(self, msg: 'gnes_pb2.Message'): + if 'sort_result' in self.args.sort_result and self.args.sort_result and msg.response.search.topk_results: + msg.response.search.topk_results.sort(key=lambda x: x.score.value, + reverse=msg.response.search.is_big_score_similar) + + msg.response.search.is_sorted = True + self.logger.info('sorted %d results in %s order' % + (len(msg.response.search.topk_results), + 'descending' if msg.response.search.is_big_score_similar else 'ascending')) + def message_handler(self, msg: 'gnes_pb2.Message', out_sck, ctrl_sck): try: fn = self.handler.serve(msg) @@ -273,9 +283,11 @@ def message_handler(self, msg: 'gnes_pb2.Message', out_sck, ctrl_sck): ret = fn(self, msg) if ret is None: # assume 'msg' is modified inside fn() + self.post_handler(msg) send_message(out_sock, msg, timeout=self.args.timeout) elif isinstance(ret, types.GeneratorType): for r_msg in ret: + self.post_handler(r_msg) send_message(out_sock, r_msg, timeout=self.args.timeout) else: raise ServiceError('unknown return type from the handler: %s' % fn) diff --git a/gnes/service/indexer.py b/gnes/service/indexer.py index 627d2d12..1c6778cd 100644 --- a/gnes/service/indexer.py +++ b/gnes/service/indexer.py @@ -63,6 +63,12 @@ def _handler_doc_index(self, msg: 'gnes_pb2.Message'): [d for d in msg.request.index.docs], [d.weight for d in msg.request.index.docs]) + def _put_result_into_message(self, results, msg: 'gnes_pb2.Message'): + msg.response.search.ClearField('topk_results') + msg.response.search.topk_results.extend(results) + msg.response.search.top_k = len(results) + msg.response.search.is_big_score_similar = self._model.is_big_score_similar + @handler.register(gnes_pb2.Request.QueryRequest) def _handler_chunk_search(self, msg: 'gnes_pb2.Message'): from ..indexer.base import BaseChunkIndexer @@ -70,10 +76,10 @@ def _handler_chunk_search(self, msg: 'gnes_pb2.Message'): raise ServiceError( 'unsupported indexer, dont know how to use %s to handle this message' % self._model.__bases__) + # assume the chunk search will change the whatever sort order the message has + msg.response.search.is_sorted = False results = self._model.query_and_score(msg.request.search.query.chunks, top_k=msg.request.search.top_k) - msg.response.search.ClearField('topk_results') - msg.response.search.topk_results.extend(results) - msg.response.search.top_k = len(results) + self._put_result_into_message(results, msg) @handler.register(gnes_pb2.Response.QueryResponse) def _handler_doc_search(self, msg: 'gnes_pb2.Message'): @@ -82,6 +88,15 @@ def _handler_doc_search(self, msg: 'gnes_pb2.Message'): raise ServiceError( 'unsupported indexer, dont know how to use %s to handle this message' % self._model.__bases__) + # check if chunk_indexer and doc_indexer has the same sorting order + if msg.response.search.is_big_score_similar is not None and \ + msg.response.search.is_big_score_similar != self._model.is_big_score_similar: + raise ServiceError( + 'is_big_score_similar is inconsistent. last topk-list: is_big_score_similar=%s, but ' + 'this indexer: is_big_score_similar=%s' % ( + msg.response.search.is_big_score_similar, self._model.is_big_score_similar)) + + # assume the doc search will change the whatever sort order the message has + msg.response.search.is_sorted = False results = self._model.query_and_score(msg.response.search.topk_results) - msg.response.search.ClearField('topk_results') - msg.response.search.topk_results.extend(results) + self._put_result_into_message(results, msg) From 81b210930e3d4248ad077728f4d983568ab51934 Mon Sep 17 00:00:00 2001 From: hanhxiao Date: Wed, 4 Sep 2019 17:36:04 +0800 Subject: [PATCH 2/3] feat(index): move sort logic to base --- gnes/proto/gnes_pb2.py | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/gnes/proto/gnes_pb2.py b/gnes/proto/gnes_pb2.py index 642f8ea7..c6a644f1 100644 --- a/gnes/proto/gnes_pb2.py +++ b/gnes/proto/gnes_pb2.py @@ -21,7 +21,7 @@ package='gnes', syntax='proto3', serialized_options=None, - serialized_pb=_b('\n\ngnes.proto\x12\x04gnes\x1a\x1fgoogle/protobuf/timestamp.proto\"9\n\x07NdArray\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x11\n\x05shape\x18\x02 \x03(\rB\x02\x10\x01\x12\r\n\x05\x64type\x18\x03 \x01(\t\"\xb9\x01\n\x05\x43hunk\x12\x0e\n\x06\x64oc_id\x18\x01 \x01(\x04\x12\x0e\n\x04text\x18\x02 \x01(\tH\x00\x12\x1d\n\x04\x62lob\x18\x03 \x01(\x0b\x32\r.gnes.NdArrayH\x00\x12\r\n\x03raw\x18\x07 \x01(\x0cH\x00\x12\x0e\n\x06offset\x18\x04 \x01(\r\x12\x15\n\toffset_nd\x18\x05 \x03(\rB\x02\x10\x01\x12\x0e\n\x06weight\x18\x06 \x01(\x02\x12 \n\tembedding\x18\x08 \x01(\x0b\x32\r.gnes.NdArrayB\t\n\x07\x63ontent\"\xc4\x02\n\x08\x44ocument\x12\x0e\n\x06\x64oc_id\x18\x01 \x01(\x04\x12\x1b\n\x06\x63hunks\x18\x02 \x03(\x0b\x32\x0b.gnes.Chunk\x12(\n\x08\x64oc_type\x18\x03 \x01(\x0e\x32\x16.gnes.Document.DocType\x12\x11\n\tmeta_info\x18\x04 \x01(\x0c\x12\x12\n\x08raw_text\x18\x05 \x01(\tH\x00\x12\"\n\traw_image\x18\x06 \x01(\x0b\x32\r.gnes.NdArrayH\x00\x12\"\n\traw_video\x18\x07 \x01(\x0b\x32\r.gnes.NdArrayH\x00\x12\x13\n\traw_bytes\x18\x08 \x01(\x0cH\x00\x12\x0e\n\x06weight\x18\n \x01(\x02\"A\n\x07\x44ocType\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x08\n\x04TEXT\x10\x01\x12\t\n\x05IMAGE\x10\x02\x12\t\n\x05VIDEO\x10\x03\x12\t\n\x05\x41UDIO\x10\x04\x42\n\n\x08raw_data\"\xd4\x01\n\x08\x45nvelope\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12\x12\n\nrequest_id\x18\x02 \x01(\r\x12\x0f\n\x07part_id\x18\x03 \x01(\r\x12\x10\n\x08num_part\x18\x04 \x03(\r\x12\x0f\n\x07timeout\x18\x05 \x01(\r\x12$\n\x06routes\x18\x06 \x03(\x0b\x32\x14.gnes.Envelope.route\x1aG\n\x05route\x12\x0f\n\x07service\x18\x01 \x01(\t\x12-\n\ttimestamp\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"y\n\x07Message\x12 \n\x08\x65nvelope\x18\x01 \x01(\x0b\x32\x0e.gnes.Envelope\x12 \n\x07request\x18\x02 \x01(\x0b\x32\r.gnes.RequestH\x00\x12\"\n\x08response\x18\x03 \x01(\x0b\x32\x0e.gnes.ResponseH\x00\x42\x06\n\x04\x62ody\"\xf6\x03\n\x07Request\x12\x12\n\nrequest_id\x18\x01 \x01(\r\x12+\n\x05train\x18\x02 \x01(\x0b\x32\x1a.gnes.Request.TrainRequestH\x00\x12+\n\x05index\x18\x03 \x01(\x0b\x32\x1a.gnes.Request.IndexRequestH\x00\x12,\n\x06search\x18\x04 \x01(\x0b\x32\x1a.gnes.Request.QueryRequestH\x00\x12/\n\x07\x63ontrol\x18\x05 \x01(\x0b\x32\x1c.gnes.Request.ControlRequestH\x00\x1a;\n\x0cTrainRequest\x12\x1c\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x0e.gnes.Document\x12\r\n\x05\x66lush\x18\x02 \x01(\x08\x1a,\n\x0cIndexRequest\x12\x1c\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x0e.gnes.Document\x1a<\n\x0cQueryRequest\x12\x1d\n\x05query\x18\x01 \x01(\x0b\x32\x0e.gnes.Document\x12\r\n\x05top_k\x18\x02 \x01(\r\x1am\n\x0e\x43ontrolRequest\x12\x35\n\x07\x63ommand\x18\x01 \x01(\x0e\x32$.gnes.Request.ControlRequest.Command\"$\n\x07\x43ommand\x12\r\n\tTERMINATE\x10\x00\x12\n\n\x06STATUS\x10\x01\x42\x06\n\x04\x62ody\"\x8a\x06\n\x08Response\x12\x12\n\nrequest_id\x18\x01 \x01(\r\x12-\n\x05train\x18\x02 \x01(\x0b\x32\x1c.gnes.Response.TrainResponseH\x00\x12-\n\x05index\x18\x03 \x01(\x0b\x32\x1c.gnes.Response.IndexResponseH\x00\x12.\n\x06search\x18\x04 \x01(\x0b\x32\x1c.gnes.Response.QueryResponseH\x00\x12\x31\n\x07\x63ontrol\x18\x05 \x01(\x0b\x32\x1e.gnes.Response.ControlResponseH\x00\x1a\x36\n\rTrainResponse\x12%\n\x06status\x18\x01 \x01(\x0e\x32\x15.gnes.Response.Status\x1a\x36\n\rIndexResponse\x12%\n\x06status\x18\x01 \x01(\x0e\x32\x15.gnes.Response.Status\x1a\x38\n\x0f\x43ontrolResponse\x12%\n\x06status\x18\x01 \x01(\x0e\x32\x15.gnes.Response.Status\x1a\xc7\x02\n\rQueryResponse\x12%\n\x06status\x18\x01 \x01(\x0e\x32\x15.gnes.Response.Status\x12\r\n\x05top_k\x18\x02 \x01(\r\x12?\n\x0ctopk_results\x18\x03 \x03(\x0b\x32).gnes.Response.QueryResponse.ScoredResult\x1a\xbe\x01\n\x0cScoredResult\x12\x1c\n\x05\x63hunk\x18\x01 \x01(\x0b\x32\x0b.gnes.ChunkH\x00\x12\x1d\n\x03\x64oc\x18\x02 \x01(\x0b\x32\x0e.gnes.DocumentH\x00\x12>\n\x05score\x18\x03 \x01(\x0b\x32/.gnes.Response.QueryResponse.ScoredResult.Score\x1a)\n\x05Score\x12\r\n\x05value\x18\x01 \x01(\x02\x12\x11\n\texplained\x18\x02 \x01(\tB\x06\n\x04\x62ody\"-\n\x06Status\x12\x0b\n\x07SUCCESS\x10\x00\x12\t\n\x05\x45RROR\x10\x01\x12\x0b\n\x07PENDING\x10\x02\x42\x06\n\x04\x62ody2\xe3\x01\n\x07GnesRPC\x12(\n\x05Train\x12\r.gnes.Request\x1a\x0e.gnes.Response\"\x00\x12(\n\x05Index\x12\r.gnes.Request\x1a\x0e.gnes.Response\"\x00\x12(\n\x05Query\x12\r.gnes.Request\x1a\x0e.gnes.Response\"\x00\x12\'\n\x04\x43\x61ll\x12\r.gnes.Request\x1a\x0e.gnes.Response\"\x00\x12\x31\n\nStreamCall\x12\r.gnes.Request\x1a\x0e.gnes.Response\"\x00(\x01\x30\x01\x62\x06proto3') + serialized_pb=_b('\n\ngnes.proto\x12\x04gnes\x1a\x1fgoogle/protobuf/timestamp.proto\"9\n\x07NdArray\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x11\n\x05shape\x18\x02 \x03(\rB\x02\x10\x01\x12\r\n\x05\x64type\x18\x03 \x01(\t\"\xb9\x01\n\x05\x43hunk\x12\x0e\n\x06\x64oc_id\x18\x01 \x01(\x04\x12\x0e\n\x04text\x18\x02 \x01(\tH\x00\x12\x1d\n\x04\x62lob\x18\x03 \x01(\x0b\x32\r.gnes.NdArrayH\x00\x12\r\n\x03raw\x18\x07 \x01(\x0cH\x00\x12\x0e\n\x06offset\x18\x04 \x01(\r\x12\x15\n\toffset_nd\x18\x05 \x03(\rB\x02\x10\x01\x12\x0e\n\x06weight\x18\x06 \x01(\x02\x12 \n\tembedding\x18\x08 \x01(\x0b\x32\r.gnes.NdArrayB\t\n\x07\x63ontent\"\xc4\x02\n\x08\x44ocument\x12\x0e\n\x06\x64oc_id\x18\x01 \x01(\x04\x12\x1b\n\x06\x63hunks\x18\x02 \x03(\x0b\x32\x0b.gnes.Chunk\x12(\n\x08\x64oc_type\x18\x03 \x01(\x0e\x32\x16.gnes.Document.DocType\x12\x11\n\tmeta_info\x18\x04 \x01(\x0c\x12\x12\n\x08raw_text\x18\x05 \x01(\tH\x00\x12\"\n\traw_image\x18\x06 \x01(\x0b\x32\r.gnes.NdArrayH\x00\x12\"\n\traw_video\x18\x07 \x01(\x0b\x32\r.gnes.NdArrayH\x00\x12\x13\n\traw_bytes\x18\x08 \x01(\x0cH\x00\x12\x0e\n\x06weight\x18\n \x01(\x02\"A\n\x07\x44ocType\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x08\n\x04TEXT\x10\x01\x12\t\n\x05IMAGE\x10\x02\x12\t\n\x05VIDEO\x10\x03\x12\t\n\x05\x41UDIO\x10\x04\x42\n\n\x08raw_data\"\xd4\x01\n\x08\x45nvelope\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12\x12\n\nrequest_id\x18\x02 \x01(\r\x12\x0f\n\x07part_id\x18\x03 \x01(\r\x12\x10\n\x08num_part\x18\x04 \x03(\r\x12\x0f\n\x07timeout\x18\x05 \x01(\r\x12$\n\x06routes\x18\x06 \x03(\x0b\x32\x14.gnes.Envelope.route\x1aG\n\x05route\x12\x0f\n\x07service\x18\x01 \x01(\t\x12-\n\ttimestamp\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"y\n\x07Message\x12 \n\x08\x65nvelope\x18\x01 \x01(\x0b\x32\x0e.gnes.Envelope\x12 \n\x07request\x18\x02 \x01(\x0b\x32\r.gnes.RequestH\x00\x12\"\n\x08response\x18\x03 \x01(\x0b\x32\x0e.gnes.ResponseH\x00\x42\x06\n\x04\x62ody\"\xf6\x03\n\x07Request\x12\x12\n\nrequest_id\x18\x01 \x01(\r\x12+\n\x05train\x18\x02 \x01(\x0b\x32\x1a.gnes.Request.TrainRequestH\x00\x12+\n\x05index\x18\x03 \x01(\x0b\x32\x1a.gnes.Request.IndexRequestH\x00\x12,\n\x06search\x18\x04 \x01(\x0b\x32\x1a.gnes.Request.QueryRequestH\x00\x12/\n\x07\x63ontrol\x18\x05 \x01(\x0b\x32\x1c.gnes.Request.ControlRequestH\x00\x1a;\n\x0cTrainRequest\x12\x1c\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x0e.gnes.Document\x12\r\n\x05\x66lush\x18\x02 \x01(\x08\x1a,\n\x0cIndexRequest\x12\x1c\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x0e.gnes.Document\x1a<\n\x0cQueryRequest\x12\x1d\n\x05query\x18\x01 \x01(\x0b\x32\x0e.gnes.Document\x12\r\n\x05top_k\x18\x02 \x01(\r\x1am\n\x0e\x43ontrolRequest\x12\x35\n\x07\x63ommand\x18\x01 \x01(\x0e\x32$.gnes.Request.ControlRequest.Command\"$\n\x07\x43ommand\x12\r\n\tTERMINATE\x10\x00\x12\n\n\x06STATUS\x10\x01\x42\x06\n\x04\x62ody\"\xbb\x06\n\x08Response\x12\x12\n\nrequest_id\x18\x01 \x01(\r\x12-\n\x05train\x18\x02 \x01(\x0b\x32\x1c.gnes.Response.TrainResponseH\x00\x12-\n\x05index\x18\x03 \x01(\x0b\x32\x1c.gnes.Response.IndexResponseH\x00\x12.\n\x06search\x18\x04 \x01(\x0b\x32\x1c.gnes.Response.QueryResponseH\x00\x12\x31\n\x07\x63ontrol\x18\x05 \x01(\x0b\x32\x1e.gnes.Response.ControlResponseH\x00\x1a\x36\n\rTrainResponse\x12%\n\x06status\x18\x01 \x01(\x0e\x32\x15.gnes.Response.Status\x1a\x36\n\rIndexResponse\x12%\n\x06status\x18\x01 \x01(\x0e\x32\x15.gnes.Response.Status\x1a\x38\n\x0f\x43ontrolResponse\x12%\n\x06status\x18\x01 \x01(\x0e\x32\x15.gnes.Response.Status\x1a\xf8\x02\n\rQueryResponse\x12%\n\x06status\x18\x01 \x01(\x0e\x32\x15.gnes.Response.Status\x12\r\n\x05top_k\x18\x02 \x01(\r\x12?\n\x0ctopk_results\x18\x03 \x03(\x0b\x32).gnes.Response.QueryResponse.ScoredResult\x12\x1c\n\x14is_big_score_similar\x18\x04 \x01(\x08\x12\x11\n\tis_sorted\x18\x05 \x01(\x08\x1a\xbe\x01\n\x0cScoredResult\x12\x1c\n\x05\x63hunk\x18\x01 \x01(\x0b\x32\x0b.gnes.ChunkH\x00\x12\x1d\n\x03\x64oc\x18\x02 \x01(\x0b\x32\x0e.gnes.DocumentH\x00\x12>\n\x05score\x18\x03 \x01(\x0b\x32/.gnes.Response.QueryResponse.ScoredResult.Score\x1a)\n\x05Score\x12\r\n\x05value\x18\x01 \x01(\x02\x12\x11\n\texplained\x18\x02 \x01(\tB\x06\n\x04\x62ody\"-\n\x06Status\x12\x0b\n\x07SUCCESS\x10\x00\x12\t\n\x05\x45RROR\x10\x01\x12\x0b\n\x07PENDING\x10\x02\x42\x06\n\x04\x62ody2\xe3\x01\n\x07GnesRPC\x12(\n\x05Train\x12\r.gnes.Request\x1a\x0e.gnes.Response\"\x00\x12(\n\x05Index\x12\r.gnes.Request\x1a\x0e.gnes.Response\"\x00\x12(\n\x05Query\x12\r.gnes.Request\x1a\x0e.gnes.Response\"\x00\x12\'\n\x04\x43\x61ll\x12\r.gnes.Request\x1a\x0e.gnes.Response\"\x00\x12\x31\n\nStreamCall\x12\r.gnes.Request\x1a\x0e.gnes.Response\"\x00(\x01\x30\x01\x62\x06proto3') , dependencies=[google_dot_protobuf_dot_timestamp__pb2.DESCRIPTOR,]) @@ -104,8 +104,8 @@ ], containing_type=None, serialized_options=None, - serialized_start=2196, - serialized_end=2241, + serialized_start=2245, + serialized_end=2290, ) _sym_db.RegisterEnumDescriptor(_RESPONSE_STATUS) @@ -800,8 +800,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=2145, - serialized_end=2186, + serialized_start=2194, + serialized_end=2235, ) _RESPONSE_QUERYRESPONSE_SCOREDRESULT = _descriptor.Descriptor( @@ -847,8 +847,8 @@ name='body', full_name='gnes.Response.QueryResponse.ScoredResult.body', index=0, containing_type=None, fields=[]), ], - serialized_start=2004, - serialized_end=2194, + serialized_start=2053, + serialized_end=2243, ) _RESPONSE_QUERYRESPONSE = _descriptor.Descriptor( @@ -879,6 +879,20 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='is_big_score_similar', full_name='gnes.Response.QueryResponse.is_big_score_similar', index=3, + number=4, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='is_sorted', full_name='gnes.Response.QueryResponse.is_sorted', index=4, + number=5, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), ], extensions=[ ], @@ -892,7 +906,7 @@ oneofs=[ ], serialized_start=1867, - serialized_end=2194, + serialized_end=2243, ) _RESPONSE = _descriptor.Descriptor( @@ -954,7 +968,7 @@ index=0, containing_type=None, fields=[]), ], serialized_start=1471, - serialized_end=2249, + serialized_end=2298, ) _CHUNK.fields_by_name['blob'].message_type = _NDARRAY @@ -1215,8 +1229,8 @@ file=DESCRIPTOR, index=0, serialized_options=None, - serialized_start=2252, - serialized_end=2479, + serialized_start=2301, + serialized_end=2528, methods=[ _descriptor.MethodDescriptor( name='Train', From 59fce1469877b1c781588dc626e8bbef039d24b2 Mon Sep 17 00:00:00 2001 From: hanhxiao Date: Wed, 4 Sep 2019 17:59:09 +0800 Subject: [PATCH 3/3] feat(cli): add --sorted_response as cli argument --- gnes/cli/parser.py | 27 ++++++++++++++------------- gnes/service/base.py | 2 +- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/gnes/cli/parser.py b/gnes/cli/parser.py index f3f6a5f2..6330ded5 100644 --- a/gnes/cli/parser.py +++ b/gnes/cli/parser.py @@ -185,6 +185,16 @@ def _set_loadable_service_parser(parser=None): return parser +def _set_sortable_service_parser(parser=None): + if not parser: + parser = set_base_parser() + _set_loadable_service_parser(parser) + + parser.add_argument('--sorted_response', action='store_true', default=False, + help='sort the response (if exist) by the score') + return parser + + # shortcut to keep consistent set_encoder_parser = _set_loadable_service_parser @@ -200,28 +210,19 @@ def set_preprocessor_parser(parser=None): def set_router_parser(parser=None): if not parser: parser = set_base_parser() - _set_loadable_service_parser(parser) + _set_sortable_service_parser(parser) + parser.add_argument('--num_part', type=int, default=None, help='explicitly set the number of parts of message') - parser.add_argument('--sort_response', type=bool, default=True, - help='sort the response (if exist) by the score') parser.set_defaults(read_only=True) return parser def set_indexer_parser(parser=None): - from ..service.base import SocketType - if not parser: parser = set_base_parser() - _set_loadable_service_parser(parser) - parser.add_argument('--sort_response', type=bool, default=True, - help='sort the response (if exist) by the score') - # encoder's port_out is indexer's port_in - parser.set_defaults(port_in=parser.get_default('port_out'), - port_out=parser.get_default('port_out') + 2, - socket_in=SocketType.PULL_CONNECT, - socket_out=SocketType.PUB_BIND) + _set_sortable_service_parser(parser) + return parser diff --git a/gnes/service/base.py b/gnes/service/base.py index 1ee5d259..7f1d9e81 100644 --- a/gnes/service/base.py +++ b/gnes/service/base.py @@ -258,7 +258,7 @@ def dump(self): self.logger.info('no dumping as "read_only" set to true.') def post_handler(self, msg: 'gnes_pb2.Message'): - if 'sort_result' in self.args.sort_result and self.args.sort_result and msg.response.search.topk_results: + if 'sorted_response' in self.args and self.args.sorted_response and msg.response.search.topk_results: msg.response.search.topk_results.sort(key=lambda x: x.score.value, reverse=msg.response.search.is_big_score_similar)