diff --git a/gnes/proto/__init__.py b/gnes/proto/__init__.py index 48798a7f..f2bc400e 100644 --- a/gnes/proto/__init__.py +++ b/gnes/proto/__init__.py @@ -28,7 +28,7 @@ class RequestGenerator: @staticmethod - def index(data: List[bytes], batch_size: int = 0, start_id: int = 0, *args, **kwargs): + def index(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: str = 'TEXT', *args, **kwargs): for pi in batch_iterator(data, batch_size): req = gnes_pb2.Request() @@ -37,17 +37,19 @@ def index(data: List[bytes], batch_size: int = 0, start_id: int = 0, *args, **kw d = req.index.docs.add() d.raw_bytes = raw_bytes d.weight = 1.0 + d.doc_type = doc_type yield req start_id += 1 @staticmethod - def train(data: List[bytes], batch_size: int = 0, start_id: int = 0, *args, **kwargs): + def train(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: str = 'TEXT', *args, **kwargs): for pi in batch_iterator(data, batch_size): req = gnes_pb2.Request() req.request_id = str(start_id) for raw_bytes in pi: d = req.train.docs.add() d.raw_bytes = raw_bytes + d.doc_type = doc_type yield req start_id += 1 req = gnes_pb2.Request() @@ -57,13 +59,14 @@ def train(data: List[bytes], batch_size: int = 0, start_id: int = 0, *args, **kw start_id += 1 @staticmethod - def query(query: bytes, top_k: int, start_id: int = 0, *args, **kwargs): + def query(query: bytes, top_k: int, start_id: int = 0, doc_type: str = 'TEXT', *args, **kwargs): if top_k <= 0: raise ValueError('"top_k: %d" is not a valid number' % top_k) req = gnes_pb2.Request() req.request_id = str(start_id) req.search.query.raw_bytes = query + req.search.query.doc_type = doc_type req.search.top_k = top_k yield req diff --git a/gnes/service/indexer.py b/gnes/service/indexer.py index bdf5f428..16d264ef 100644 --- a/gnes/service/indexer.py +++ b/gnes/service/indexer.py @@ -36,7 +36,7 @@ def _handler_index(self, msg: 'gnes_pb2.Message'): weights = [] for d in msg.request.index.docs: - if d.chunks: + if len(d.chunks): all_vecs.append(blob2array(d.chunk_embeddings)) doc_ids += [d.doc_id] * len(d.chunks) if d.doc_type == 'TEXT':