Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
refactor(client): make query method as generator
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Oct 15, 2019
1 parent af7c885 commit f16c672
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 24 deletions.
24 changes: 6 additions & 18 deletions gnes/client/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
import sys
import time
import zipfile
from typing import Iterator
from typing import Iterator, Tuple

from termcolor import colored

from .base import GrpcClient
from ..proto import RequestGenerator, gnes_pb2
from ..proto import RequestGenerator


class CLIClient(GrpcClient):
Expand Down Expand Up @@ -54,37 +54,25 @@ def start(self):
finally:
self.close()

def train(self):
def train(self) -> None:
with ProgressBar(task_name=self.args.mode) as p_bar:
for _ in self._stub.StreamCall(RequestGenerator.train(self.bytes_generator,
doc_id_start=self.args.start_doc_id,
batch_size=self.args.batch_size)):
p_bar.update()

def index(self):
def index(self) -> None:
with ProgressBar(task_name=self.args.mode) as p_bar:
for _ in self._stub.StreamCall(RequestGenerator.index(self.bytes_generator,
doc_id_start=self.args.start_doc_id,
batch_size=self.args.batch_size)):
p_bar.update()

def query(self):
def query(self) -> Iterator[Tuple]:
for idx, q in enumerate(self.bytes_generator):
for req in RequestGenerator.query(q, request_id_start=idx, top_k=self.args.top_k):
resp = self._stub.Call(req)
self.query_callback(req, resp)

def query_callback(self, req: 'gnes_pb2.Request', resp: 'gnes_pb2.Response'):
"""
callback after get the query result
override this method to customize query behavior
:param resp: response
:param req: query
:return:
"""
print(req)
print(resp)
yield (req, resp)

@property
def bytes_generator(self) -> Iterator[bytes]:
Expand Down
11 changes: 6 additions & 5 deletions gnes/flow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,8 @@ def train(self, bytes_gen: Iterator[bytes] = None, **kwargs):
:param bytes_gen: An iterator of bytes. If not given, then you have to specify it in `kwargs`.
:param kwargs: accepts all keyword arguments of `gnes client` CLI
"""
self._call_client(bytes_gen, mode='train', **kwargs)
self._get_client(bytes_gen, mode='train', **kwargs).start()


def index(self, bytes_gen: Iterator[bytes] = None, **kwargs):
"""Do indexing on the current flow
Expand All @@ -309,7 +310,7 @@ def index(self, bytes_gen: Iterator[bytes] = None, **kwargs):
:param bytes_gen: An iterator of bytes. If not given, then you have to specify it in `kwargs`.
:param kwargs: accepts all keyword arguments of `gnes client` CLI
"""
self._call_client(bytes_gen, mode='index', **kwargs)
self._get_client(bytes_gen, mode='index', **kwargs).start()

def query(self, bytes_gen: Iterator[bytes] = None, **kwargs):
"""Do indexing on the current flow
Expand All @@ -319,10 +320,10 @@ def query(self, bytes_gen: Iterator[bytes] = None, **kwargs):
:param bytes_gen: An iterator of bytes. If not given, then you have to specify it in `kwargs`.
:param kwargs: accepts all keyword arguments of `gnes client` CLI
"""
self._call_client(bytes_gen, mode='query', **kwargs)
yield from self._get_client(bytes_gen, mode='query', **kwargs).query()

@build_required(BuildLevel.RUNTIME)
def _call_client(self, bytes_gen: Iterator[bytes] = None, **kwargs):
def _get_client(self, bytes_gen: Iterator[bytes] = None, **kwargs):
from ..cli.parser import set_client_cli_parser
from ..client.cli import CLIClient

Expand All @@ -332,7 +333,7 @@ def _call_client(self, bytes_gen: Iterator[bytes] = None, **kwargs):
c = CLIClient(p_args, start_at_init=False)
if bytes_gen:
c.bytes_generator = bytes_gen
c.start()
return c

def add_frontend(self, *args, **kwargs) -> 'Flow':
"""Add a frontend to the current flow, a shortcut of :py:meth:`add(Service.Frontend)`.
Expand Down
2 changes: 1 addition & 1 deletion gnes/proto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def train(data: Iterator[bytes], batch_size: int = 0, doc_type: int = gnes_pb2.D
request_id_start += 1

@staticmethod
def query(query: bytes, top_k: int, request_id_start: int = 0, doc_type: int = gnes_pb2.Document.TEXT, *args,
def query(query: bytes, top_k: int, doc_type: int = gnes_pb2.Document.TEXT, request_id_start: int = 0, *args,
**kwargs):
if top_k <= 0:
raise ValueError('"top_k: %d" is not a valid number' % top_k)
Expand Down

0 comments on commit f16c672

Please sign in to comment.