From f16c672e5eb7b3a88e7e4795d135af8d22de9c44 Mon Sep 17 00:00:00 2001 From: hanhxiao Date: Tue, 15 Oct 2019 18:02:50 +0800 Subject: [PATCH] refactor(client): make query method as generator --- gnes/client/cli.py | 24 ++++++------------------ gnes/flow/__init__.py | 11 ++++++----- gnes/proto/__init__.py | 2 +- 3 files changed, 13 insertions(+), 24 deletions(-) diff --git a/gnes/client/cli.py b/gnes/client/cli.py index 999ed215..23d86159 100644 --- a/gnes/client/cli.py +++ b/gnes/client/cli.py @@ -17,12 +17,12 @@ import sys import time import zipfile -from typing import Iterator +from typing import Iterator, Tuple from termcolor import colored from .base import GrpcClient -from ..proto import RequestGenerator, gnes_pb2 +from ..proto import RequestGenerator class CLIClient(GrpcClient): @@ -54,37 +54,25 @@ def start(self): finally: self.close() - def train(self): + def train(self) -> None: with ProgressBar(task_name=self.args.mode) as p_bar: for _ in self._stub.StreamCall(RequestGenerator.train(self.bytes_generator, doc_id_start=self.args.start_doc_id, batch_size=self.args.batch_size)): p_bar.update() - def index(self): + def index(self) -> None: with ProgressBar(task_name=self.args.mode) as p_bar: for _ in self._stub.StreamCall(RequestGenerator.index(self.bytes_generator, doc_id_start=self.args.start_doc_id, batch_size=self.args.batch_size)): p_bar.update() - def query(self): + def query(self) -> Iterator[Tuple]: for idx, q in enumerate(self.bytes_generator): for req in RequestGenerator.query(q, request_id_start=idx, top_k=self.args.top_k): resp = self._stub.Call(req) - self.query_callback(req, resp) - - def query_callback(self, req: 'gnes_pb2.Request', resp: 'gnes_pb2.Response'): - """ - callback after get the query result - override this method to customize query behavior - - :param resp: response - :param req: query - :return: - """ - print(req) - print(resp) + yield (req, resp) @property def bytes_generator(self) -> Iterator[bytes]: diff --git a/gnes/flow/__init__.py b/gnes/flow/__init__.py index 4dceff59..eb2ef419 100644 --- a/gnes/flow/__init__.py +++ b/gnes/flow/__init__.py @@ -299,7 +299,8 @@ def train(self, bytes_gen: Iterator[bytes] = None, **kwargs): :param bytes_gen: An iterator of bytes. If not given, then you have to specify it in `kwargs`. :param kwargs: accepts all keyword arguments of `gnes client` CLI """ - self._call_client(bytes_gen, mode='train', **kwargs) + self._get_client(bytes_gen, mode='train', **kwargs).start() + def index(self, bytes_gen: Iterator[bytes] = None, **kwargs): """Do indexing on the current flow @@ -309,7 +310,7 @@ def index(self, bytes_gen: Iterator[bytes] = None, **kwargs): :param bytes_gen: An iterator of bytes. If not given, then you have to specify it in `kwargs`. :param kwargs: accepts all keyword arguments of `gnes client` CLI """ - self._call_client(bytes_gen, mode='index', **kwargs) + self._get_client(bytes_gen, mode='index', **kwargs).start() def query(self, bytes_gen: Iterator[bytes] = None, **kwargs): """Do indexing on the current flow @@ -319,10 +320,10 @@ def query(self, bytes_gen: Iterator[bytes] = None, **kwargs): :param bytes_gen: An iterator of bytes. If not given, then you have to specify it in `kwargs`. :param kwargs: accepts all keyword arguments of `gnes client` CLI """ - self._call_client(bytes_gen, mode='query', **kwargs) + yield from self._get_client(bytes_gen, mode='query', **kwargs).query() @build_required(BuildLevel.RUNTIME) - def _call_client(self, bytes_gen: Iterator[bytes] = None, **kwargs): + def _get_client(self, bytes_gen: Iterator[bytes] = None, **kwargs): from ..cli.parser import set_client_cli_parser from ..client.cli import CLIClient @@ -332,7 +333,7 @@ def _call_client(self, bytes_gen: Iterator[bytes] = None, **kwargs): c = CLIClient(p_args, start_at_init=False) if bytes_gen: c.bytes_generator = bytes_gen - c.start() + return c def add_frontend(self, *args, **kwargs) -> 'Flow': """Add a frontend to the current flow, a shortcut of :py:meth:`add(Service.Frontend)`. diff --git a/gnes/proto/__init__.py b/gnes/proto/__init__.py index 6f531846..312f4a99 100644 --- a/gnes/proto/__init__.py +++ b/gnes/proto/__init__.py @@ -74,7 +74,7 @@ def train(data: Iterator[bytes], batch_size: int = 0, doc_type: int = gnes_pb2.D request_id_start += 1 @staticmethod - def query(query: bytes, top_k: int, request_id_start: int = 0, doc_type: int = gnes_pb2.Document.TEXT, *args, + def query(query: bytes, top_k: int, doc_type: int = gnes_pb2.Document.TEXT, request_id_start: int = 0, *args, **kwargs): if top_k <= 0: raise ValueError('"top_k: %d" is not a valid number' % top_k)