diff --git a/gnes/cli/parser.py b/gnes/cli/parser.py index 342206ed..14bb238f 100644 --- a/gnes/cli/parser.py +++ b/gnes/cli/parser.py @@ -288,6 +288,9 @@ def set_indexer_parser(parser=None): if not parser: parser = set_base_parser() _set_sortable_service_parser(parser) + parser.add_argument('--as_response', type=ActionNoYes, default=True, + help='convert the message type from request to response after indexing. ' + 'turn it off if you want to chain other services after this index service.') return parser diff --git a/gnes/service/indexer.py b/gnes/service/indexer.py index f20edfc1..e9e9f7e9 100644 --- a/gnes/service/indexer.py +++ b/gnes/service/indexer.py @@ -36,16 +36,20 @@ def _handler_index(self, msg: 'gnes_pb2.Message'): # print('!!! tid: %s, tmp_a: %r %r' % (threading.get_ident(), self._tmp_a, self._handler_index)) from ..indexer.base import BaseChunkIndexer, BaseDocIndexer if isinstance(self._model, BaseChunkIndexer): - self._handler_chunk_index(msg) + is_changed = self._handler_chunk_index(msg) elif isinstance(self._model, BaseDocIndexer): - self._handler_doc_index(msg) + is_changed = self._handler_doc_index(msg) else: raise ServiceError( 'unsupported indexer, dont know how to use %s to handle this message' % self._model.__bases__) - msg.response.index.status = gnes_pb2.Response.SUCCESS - self.is_model_changed.set() - def _handler_chunk_index(self, msg: 'gnes_pb2.Message'): + if self.args.as_response: + msg.response.index.status = gnes_pb2.Response.SUCCESS + + if is_changed: + self.is_model_changed.set() + + def _handler_chunk_index(self, msg: 'gnes_pb2.Message') -> bool: embed_info = [] for d in msg.request.index.docs: @@ -59,13 +63,19 @@ def _handler_chunk_index(self, msg: 'gnes_pb2.Message'): if embed_info: vecs, doc_ids, offsets, weights = zip(*embed_info) self._model.add(list(zip(doc_ids, offsets)), np.stack(vecs), weights) + return True else: self.logger.warning('chunks contain no embedded vectors, the indexer will do nothing') - - def _handler_doc_index(self, msg: 'gnes_pb2.Message'): - self._model.add([d.doc_id for d in msg.request.index.docs], - [d for d in msg.request.index.docs], - [d.weight for d in msg.request.index.docs]) + return False + + def _handler_doc_index(self, msg: 'gnes_pb2.Message') -> bool: + if msg.request.index.docs: + self._model.add([d.doc_id for d in msg.request.index.docs], + [d for d in msg.request.index.docs], + [d.weight for d in msg.request.index.docs]) + return True + else: + return False def _put_result_into_message(self, results, msg: 'gnes_pb2.Message'): msg.response.search.ClearField('topk_results')