Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
feat(flow): add client to flow
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Oct 8, 2019
1 parent 43b9d01 commit 0fc8b3d
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 30 deletions.
2 changes: 1 addition & 1 deletion gnes/cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def _client_http(args):

def _client_cli(args):
from ..client.cli import CLIClient
CLIClient(args)
CLIClient(args).start()


def compose(args):
Expand Down
6 changes: 3 additions & 3 deletions gnes/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,14 @@ def __init__(self, args):
)
self.logger.info('waiting channel to be ready...')
grpc.channel_ready_future(self._channel).result()
self.logger.critical('gnes client ready!')

# create new stub
self.logger.info('create new stub...')
self._stub = gnes_pb2_grpc.GnesRPCStub(self._channel)

# attache response handler
self.handler._context = self
self.logger.critical('gnes client ready at %s:%d!' % (self.args.grpc_host, self.args.grpc_port))

def call(self, request):
resp = self._stub.call(request)
Expand All @@ -158,13 +158,13 @@ def _handler_response_default(self, msg: 'gnes_pb2.Response'):
pass

def __enter__(self):
self.open()
self.start()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

def open(self):
def start(self):
pass

def close(self):
Expand Down
14 changes: 11 additions & 3 deletions gnes/client/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,18 @@


class CLIClient(GrpcClient):
def __init__(self, args):
def __init__(self, args, start_at_init: bool = True):
super().__init__(args)
getattr(self, self.args.mode)()
self.close()
if start_at_init:
self.start()

def start(self):
try:
getattr(self, self.args.mode)()
except Exception as ex:
self.logger.error(ex)
finally:
self.close()

def train(self):
with ProgressBar(task_name=self.args.mode) as p_bar:
Expand Down
52 changes: 35 additions & 17 deletions gnes/flow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from collections import OrderedDict, defaultdict
from contextlib import ExitStack
from functools import wraps
from typing import Union, Tuple, List
from typing import Union, Tuple, List, Optional

from ..cli.parser import set_router_parser, set_indexer_parser, \
set_frontend_parser, set_preprocessor_parser, \
set_encoder_parser
set_encoder_parser, set_client_cli_parser
from ..client.cli import CLIClient
from ..helper import set_logger
from ..service.base import SocketType, BaseService, BetterEnum, ServiceManager
from ..service.encoder import EncoderService
Expand Down Expand Up @@ -62,6 +63,7 @@ class BuildLevel(BetterEnum):
EMPTY = 0
GRAPH = 1
RUNTIME = 2
RUNTIME_WITH_CLIENT = 3

def __init__(self, with_frontend: bool = True, **kwargs):
self.logger = set_logger(self.__class__.__name__)
Expand Down Expand Up @@ -125,14 +127,21 @@ def to_mermaid(self, left_right: bool = True):
'copy-paste the output and visualize it with: https://mermaidjs.github.io/mermaid-live-editor/')
return mermaid_str

def train(self):
pass
def train(self, **kwargs):
self._call_client(mode='train', **kwargs)

def index(self):
pass
def index(self, **kwargs):
self._call_client(mode='index', **kwargs)

def query(self):
pass
def query(self, **kwargs):
self._call_client(mode='query', **kwargs)

@_build_level(BuildLevel.RUNTIME)
def _call_client(self, **kwargs):
args, p_args = self._get_parsed_args(set_client_cli_parser(), kwargs)
p_args.grpc_port = self._service_nodes[self._frontend]['parsed_args'].grpc_port
p_args.grpc_host = self._service_nodes[self._frontend]['parsed_args'].grpc_host
CLIClient(p_args)

def add_frontend(self, *args, **kwargs) -> 'Flow':
"""Add a frontend to the current flow, a shortcut of add(Service.Frontend)
Expand Down Expand Up @@ -193,7 +202,7 @@ def add(self, service: 'Service',
service_in = self._parse_service_endpoints(name, service_in, connect_to_last_service=True)
service_out = self._parse_service_endpoints(name, service_out, connect_to_last_service=False)

args, p_args = self._get_parsed_args(service, kwargs)
args, p_args = self._get_parsed_args(Flow._service2parser[service], kwargs)

self._service_nodes[name] = {
'service': service,
Expand Down Expand Up @@ -233,7 +242,7 @@ def _parse_service_endpoints(self, cur_service_name, service_endpoint, connect_t
raise ValueError('service_in=%s is not parsable' % service_endpoint)
return set(service_endpoint)

def _get_parsed_args(self, service, kwargs):
def _get_parsed_args(self, service_arg_parser, kwargs):
kwargs.update(self._common_kwargs)
args = []
for k, v in kwargs.items():
Expand All @@ -251,12 +260,12 @@ def _get_parsed_args(self, service, kwargs):
else:
args.extend(['--%s' % k, str(v)])
try:
p_args, unknown_args = Flow._service2parser[service].parse_known_args(args)
p_args, unknown_args = service_arg_parser.parse_known_args(args)
if unknown_args:
self.logger.warning('not sure what these arguments are: %s' % unknown_args)
except SystemExit:
raise ValueError('bad arguments for service "%s", '
'you may want to double check your args "%s"' % (service, args))
'you may want to double check your args "%s"' % (service_arg_parser, args))
return args, p_args

def _build_graph(self) -> 'Flow':
Expand Down Expand Up @@ -292,6 +301,7 @@ def _build_graph(self) -> 'Flow':
#
# when a socket is BIND, then host must NOT be set, aka default host 0.0.0.0
# host_in and host_out is only set when corresponding socket is CONNECT
e_pargs.port_in = s_pargs.port_out

if len(edges_with_same_start) > 1 and len(edges_with_same_end) == 1:
s_pargs.socket_out = SocketType.PUB_BIND
Expand Down Expand Up @@ -336,17 +346,25 @@ def _build_graph(self) -> 'Flow':
self._build_level = Flow.BuildLevel.GRAPH
return self

def build(self, backend='thread', *args, **kwargs) -> 'Flow':
def build(self, backend: Optional[str] = 'thread', *args, **kwargs) -> 'Flow':
self._build_graph()
if backend in {'thread', 'process'}:
if not backend:
self.logger.warning('no specified backend, build_level stays at %s, '
'and you can not run this flow.' % self._build_level)
elif backend in {'thread', 'process'}:
self._service_contexts.clear()
for v in self._service_nodes.values():
v['parsed_args'].parallel_backend = backend
s = self._service2builder[v['service']](v['parsed_args'])
p_args = v['parsed_args']
p_args.parallel_backend = backend
# for thread and process backend which runs locally, host_in and host_out should not be set
p_args.host_in = BaseService.default_host
p_args.host_out = BaseService.default_host
s = self._service2builder[v['service']](p_args)
self._service_contexts.append(s)
self._build_level = Flow.BuildLevel.RUNTIME
else:
raise NotImplementedError('backend=%s is not supported yet' % backend)
self._build_level = Flow.BuildLevel.RUNTIME

return self

def __call__(self, *args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_client_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_cli(self):
'--port_out', str(args.port_in),
'--socket_in', str(SocketType.PULL_CONNECT),
'--socket_out', str(SocketType.PUSH_CONNECT),
'--yaml_path', 'BaseRouter'
'--yaml_path', 'BaseRouter',
])

with RouterService(p_args), FrontendService(args):
Expand Down
30 changes: 25 additions & 5 deletions tests/test_gnes_flow.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,42 @@
import os
import unittest

from gnes.cli.parser import set_client_cli_parser
from gnes.flow import Flow, Service as gfs


class TestGNESFlow(unittest.TestCase):

def setUp(self):
self.dirname = os.path.dirname(__file__)
self.test_file = os.path.join(self.dirname, 'sonnets_small.txt')
self.index_args = set_client_cli_parser().parse_args([
'--mode', 'index',
'--txt_file', self.test_file,
'--batch_size', '4'
])
os.unsetenv('http_proxy')
os.unsetenv('https_proxy')

def test_flow1(self):
f = (Flow(check_version=False, route_table=True)
.add(gfs.Router, yaml_path='BaseRouter').build())
print(f._service_edges)
print(f.to_mermaid())

def test_flow1_ctx(self):
def test_flow1_ctx_empty(self):
f = (Flow(check_version=False, route_table=True)
.add(gfs.Router, yaml_path='BaseRouter'))
with f(backend='process'):
pass

def test_flow1_ctx(self):
flow = (Flow(check_version=False, route_table=True)
.add(gfs.Router, yaml_path='BaseRouter'))
with flow(backend='process') as f:
# CLIClient(self.index_args)
f.index(txt_file=self.test_file, batch_size=4)

def test_flow2(self):
f = (Flow(check_version=False, route_table=True)
.add(gfs.Router, yaml_path='BaseRouter')
Expand All @@ -27,15 +47,15 @@ def test_flow2(self):
.add(gfs.Router, yaml_path='BaseRouter')
.add(gfs.Router, yaml_path='BaseRouter')
.add(gfs.Router, yaml_path='BaseRouter')
.build())
.build(backend=None))
print(f._service_edges)
print(f.to_mermaid())

def test_flow3(self):
f = (Flow(check_version=False, route_table=True)
.add(gfs.Router, name='r0', service_out=gfs.Frontend, yaml_path='BaseRouter')
.add(gfs.Router, name='r1', service_in=gfs.Frontend, yaml_path='BaseRouter')
.build())
.build(backend=None))
print(f._service_edges)
print(f.to_mermaid())

Expand All @@ -44,7 +64,7 @@ def test_flow4(self):
.add(gfs.Router, name='r0', yaml_path='BaseRouter')
.add(gfs.Router, name='r1', service_in=gfs.Frontend, yaml_path='BaseRouter')
.add(gfs.Router, name='reduce', service_in=['r0', 'r1'], yaml_path='BaseRouter')
.build())
.build(backend=None))
print(f._service_edges)
print(f.to_mermaid())

Expand All @@ -56,6 +76,6 @@ def test_flow5(self):
.add(gfs.Indexer, name='doc_idx', yaml_path='DictIndexer', service_in='prep')
.add(gfs.Router, name='sync_barrier', yaml_path='BaseReduceRouter',
num_part=2, service_in=['vec_idx', 'doc_idx'])
.build())
.build(backend=None))
print(f._service_edges)
print(f.to_mermaid())

0 comments on commit 0fc8b3d

Please sign in to comment.