Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix(service): fix bug in req Generator add doc_type
Browse files Browse the repository at this point in the history
  • Loading branch information
Larryjianfeng committed Jul 26, 2019
1 parent 5743e25 commit 80e234e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
9 changes: 6 additions & 3 deletions gnes/proto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

class RequestGenerator:
@staticmethod
def index(data: List[bytes], batch_size: int = 0, start_id: int = 0, *args, **kwargs):
def index(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: str = 'TEXT', *args, **kwargs):

for pi in batch_iterator(data, batch_size):
req = gnes_pb2.Request()
Expand All @@ -37,17 +37,19 @@ def index(data: List[bytes], batch_size: int = 0, start_id: int = 0, *args, **kw
d = req.index.docs.add()
d.raw_bytes = raw_bytes
d.weight = 1.0
d.doc_type = doc_type
yield req
start_id += 1

@staticmethod
def train(data: List[bytes], batch_size: int = 0, start_id: int = 0, *args, **kwargs):
def train(data: List[bytes], batch_size: int = 0, start_id: int = 0, doc_type: str = 'TEXT', *args, **kwargs):
for pi in batch_iterator(data, batch_size):
req = gnes_pb2.Request()
req.request_id = str(start_id)
for raw_bytes in pi:
d = req.train.docs.add()
d.raw_bytes = raw_bytes
d.doc_type = doc_type
yield req
start_id += 1
req = gnes_pb2.Request()
Expand All @@ -57,13 +59,14 @@ def train(data: List[bytes], batch_size: int = 0, start_id: int = 0, *args, **kw
start_id += 1

@staticmethod
def query(query: bytes, top_k: int, start_id: int = 0, *args, **kwargs):
def query(query: bytes, top_k: int, start_id: int = 0, doc_type: str = 'TEXT', *args, **kwargs):
if top_k <= 0:
raise ValueError('"top_k: %d" is not a valid number' % top_k)

req = gnes_pb2.Request()
req.request_id = str(start_id)
req.search.query.raw_bytes = query
req.search.query.doc_type = doc_type
req.search.top_k = top_k
yield req

Expand Down
2 changes: 1 addition & 1 deletion gnes/service/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _handler_index(self, msg: 'gnes_pb2.Message'):
weights = []

for d in msg.request.index.docs:
if d.chunks:
if len(d.chunks):
all_vecs.append(blob2array(d.chunk_embeddings))
doc_ids += [d.doc_id] * len(d.chunks)
if d.doc_type == 'TEXT':
Expand Down

0 comments on commit 80e234e

Please sign in to comment.