diff --git a/gnes/cli/api.py b/gnes/cli/api.py index 376e24dd..4a25a934 100644 --- a/gnes/cli/api.py +++ b/gnes/cli/api.py @@ -86,7 +86,7 @@ def _client_http(args): def _client_cli(args): from ..client.cli import CLIClient - CLIClient(args) + CLIClient(args).start() def compose(args): diff --git a/gnes/client/base.py b/gnes/client/base.py index 481b5636..ce4bbd66 100644 --- a/gnes/client/base.py +++ b/gnes/client/base.py @@ -130,7 +130,6 @@ def __init__(self, args): ) self.logger.info('waiting channel to be ready...') grpc.channel_ready_future(self._channel).result() - self.logger.critical('gnes client ready!') # create new stub self.logger.info('create new stub...') @@ -138,6 +137,7 @@ def __init__(self, args): # attache response handler self.handler._context = self + self.logger.critical('gnes client ready at %s:%d!' % (self.args.grpc_host, self.args.grpc_port)) def call(self, request): resp = self._stub.call(request) @@ -158,13 +158,13 @@ def _handler_response_default(self, msg: 'gnes_pb2.Response'): pass def __enter__(self): - self.open() + self.start() return self def __exit__(self, exc_type, exc_val, exc_tb): self.close() - def open(self): + def start(self): pass def close(self): diff --git a/gnes/client/cli.py b/gnes/client/cli.py index 1f7e8231..ebee67d7 100644 --- a/gnes/client/cli.py +++ b/gnes/client/cli.py @@ -26,10 +26,18 @@ class CLIClient(GrpcClient): - def __init__(self, args): + def __init__(self, args, start_at_init: bool = True): super().__init__(args) - getattr(self, self.args.mode)() - self.close() + if start_at_init: + self.start() + + def start(self): + try: + getattr(self, self.args.mode)() + except Exception as ex: + self.logger.error(ex) + finally: + self.close() def train(self): with ProgressBar(task_name=self.args.mode) as p_bar: diff --git a/gnes/flow/__init__.py b/gnes/flow/__init__.py index 59a8f392..e9d8ff70 100644 --- a/gnes/flow/__init__.py +++ b/gnes/flow/__init__.py @@ -1,11 +1,12 @@ from collections import OrderedDict, defaultdict from contextlib import ExitStack from functools import wraps -from typing import Union, Tuple, List +from typing import Union, Tuple, List, Optional from ..cli.parser import set_router_parser, set_indexer_parser, \ set_frontend_parser, set_preprocessor_parser, \ - set_encoder_parser + set_encoder_parser, set_client_cli_parser +from ..client.cli import CLIClient from ..helper import set_logger from ..service.base import SocketType, BaseService, BetterEnum, ServiceManager from ..service.encoder import EncoderService @@ -62,6 +63,7 @@ class BuildLevel(BetterEnum): EMPTY = 0 GRAPH = 1 RUNTIME = 2 + RUNTIME_WITH_CLIENT = 3 def __init__(self, with_frontend: bool = True, **kwargs): self.logger = set_logger(self.__class__.__name__) @@ -125,14 +127,21 @@ def to_mermaid(self, left_right: bool = True): 'copy-paste the output and visualize it with: https://mermaidjs.github.io/mermaid-live-editor/') return mermaid_str - def train(self): - pass + def train(self, **kwargs): + self._call_client(mode='train', **kwargs) - def index(self): - pass + def index(self, **kwargs): + self._call_client(mode='index', **kwargs) - def query(self): - pass + def query(self, **kwargs): + self._call_client(mode='query', **kwargs) + + @_build_level(BuildLevel.RUNTIME) + def _call_client(self, **kwargs): + args, p_args = self._get_parsed_args(set_client_cli_parser(), kwargs) + p_args.grpc_port = self._service_nodes[self._frontend]['parsed_args'].grpc_port + p_args.grpc_host = self._service_nodes[self._frontend]['parsed_args'].grpc_host + CLIClient(p_args) def add_frontend(self, *args, **kwargs) -> 'Flow': """Add a frontend to the current flow, a shortcut of add(Service.Frontend) @@ -193,7 +202,7 @@ def add(self, service: 'Service', service_in = self._parse_service_endpoints(name, service_in, connect_to_last_service=True) service_out = self._parse_service_endpoints(name, service_out, connect_to_last_service=False) - args, p_args = self._get_parsed_args(service, kwargs) + args, p_args = self._get_parsed_args(Flow._service2parser[service], kwargs) self._service_nodes[name] = { 'service': service, @@ -233,7 +242,7 @@ def _parse_service_endpoints(self, cur_service_name, service_endpoint, connect_t raise ValueError('service_in=%s is not parsable' % service_endpoint) return set(service_endpoint) - def _get_parsed_args(self, service, kwargs): + def _get_parsed_args(self, service_arg_parser, kwargs): kwargs.update(self._common_kwargs) args = [] for k, v in kwargs.items(): @@ -251,12 +260,12 @@ def _get_parsed_args(self, service, kwargs): else: args.extend(['--%s' % k, str(v)]) try: - p_args, unknown_args = Flow._service2parser[service].parse_known_args(args) + p_args, unknown_args = service_arg_parser.parse_known_args(args) if unknown_args: self.logger.warning('not sure what these arguments are: %s' % unknown_args) except SystemExit: raise ValueError('bad arguments for service "%s", ' - 'you may want to double check your args "%s"' % (service, args)) + 'you may want to double check your args "%s"' % (service_arg_parser, args)) return args, p_args def _build_graph(self) -> 'Flow': @@ -292,6 +301,7 @@ def _build_graph(self) -> 'Flow': # # when a socket is BIND, then host must NOT be set, aka default host 0.0.0.0 # host_in and host_out is only set when corresponding socket is CONNECT + e_pargs.port_in = s_pargs.port_out if len(edges_with_same_start) > 1 and len(edges_with_same_end) == 1: s_pargs.socket_out = SocketType.PUB_BIND @@ -336,17 +346,25 @@ def _build_graph(self) -> 'Flow': self._build_level = Flow.BuildLevel.GRAPH return self - def build(self, backend='thread', *args, **kwargs) -> 'Flow': + def build(self, backend: Optional[str] = 'thread', *args, **kwargs) -> 'Flow': self._build_graph() - if backend in {'thread', 'process'}: + if not backend: + self.logger.warning('no specified backend, build_level stays at %s, ' + 'and you can not run this flow.' % self._build_level) + elif backend in {'thread', 'process'}: self._service_contexts.clear() for v in self._service_nodes.values(): - v['parsed_args'].parallel_backend = backend - s = self._service2builder[v['service']](v['parsed_args']) + p_args = v['parsed_args'] + p_args.parallel_backend = backend + # for thread and process backend which runs locally, host_in and host_out should not be set + p_args.host_in = BaseService.default_host + p_args.host_out = BaseService.default_host + s = self._service2builder[v['service']](p_args) self._service_contexts.append(s) + self._build_level = Flow.BuildLevel.RUNTIME else: raise NotImplementedError('backend=%s is not supported yet' % backend) - self._build_level = Flow.BuildLevel.RUNTIME + return self def __call__(self, *args, **kwargs): diff --git a/tests/test_client_cli.py b/tests/test_client_cli.py index d7413442..357d0b3b 100644 --- a/tests/test_client_cli.py +++ b/tests/test_client_cli.py @@ -38,7 +38,7 @@ def test_cli(self): '--port_out', str(args.port_in), '--socket_in', str(SocketType.PULL_CONNECT), '--socket_out', str(SocketType.PUSH_CONNECT), - '--yaml_path', 'BaseRouter' + '--yaml_path', 'BaseRouter', ]) with RouterService(p_args), FrontendService(args): diff --git a/tests/test_gnes_flow.py b/tests/test_gnes_flow.py index 10180bf7..e1546bbd 100644 --- a/tests/test_gnes_flow.py +++ b/tests/test_gnes_flow.py @@ -1,22 +1,42 @@ +import os import unittest +from gnes.cli.parser import set_client_cli_parser from gnes.flow import Flow, Service as gfs class TestGNESFlow(unittest.TestCase): + def setUp(self): + self.dirname = os.path.dirname(__file__) + self.test_file = os.path.join(self.dirname, 'sonnets_small.txt') + self.index_args = set_client_cli_parser().parse_args([ + '--mode', 'index', + '--txt_file', self.test_file, + '--batch_size', '4' + ]) + os.unsetenv('http_proxy') + os.unsetenv('https_proxy') + def test_flow1(self): f = (Flow(check_version=False, route_table=True) .add(gfs.Router, yaml_path='BaseRouter').build()) print(f._service_edges) print(f.to_mermaid()) - def test_flow1_ctx(self): + def test_flow1_ctx_empty(self): f = (Flow(check_version=False, route_table=True) .add(gfs.Router, yaml_path='BaseRouter')) with f(backend='process'): pass + def test_flow1_ctx(self): + flow = (Flow(check_version=False, route_table=True) + .add(gfs.Router, yaml_path='BaseRouter')) + with flow(backend='process') as f: + # CLIClient(self.index_args) + f.index(txt_file=self.test_file, batch_size=4) + def test_flow2(self): f = (Flow(check_version=False, route_table=True) .add(gfs.Router, yaml_path='BaseRouter') @@ -27,7 +47,7 @@ def test_flow2(self): .add(gfs.Router, yaml_path='BaseRouter') .add(gfs.Router, yaml_path='BaseRouter') .add(gfs.Router, yaml_path='BaseRouter') - .build()) + .build(backend=None)) print(f._service_edges) print(f.to_mermaid()) @@ -35,7 +55,7 @@ def test_flow3(self): f = (Flow(check_version=False, route_table=True) .add(gfs.Router, name='r0', service_out=gfs.Frontend, yaml_path='BaseRouter') .add(gfs.Router, name='r1', service_in=gfs.Frontend, yaml_path='BaseRouter') - .build()) + .build(backend=None)) print(f._service_edges) print(f.to_mermaid()) @@ -44,7 +64,7 @@ def test_flow4(self): .add(gfs.Router, name='r0', yaml_path='BaseRouter') .add(gfs.Router, name='r1', service_in=gfs.Frontend, yaml_path='BaseRouter') .add(gfs.Router, name='reduce', service_in=['r0', 'r1'], yaml_path='BaseRouter') - .build()) + .build(backend=None)) print(f._service_edges) print(f.to_mermaid()) @@ -56,6 +76,6 @@ def test_flow5(self): .add(gfs.Indexer, name='doc_idx', yaml_path='DictIndexer', service_in='prep') .add(gfs.Router, name='sync_barrier', yaml_path='BaseReduceRouter', num_part=2, service_in=['vec_idx', 'doc_idx']) - .build()) + .build(backend=None)) print(f._service_edges) print(f.to_mermaid())