Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
feat(flow): add client to flow
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Oct 9, 2019
1 parent 0fc8b3d commit 990593b
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 24 deletions.
33 changes: 23 additions & 10 deletions gnes/client/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,24 @@
class CLIClient(GrpcClient):
def __init__(self, args, start_at_init: bool = True):
super().__init__(args)
self._bytes_generator = self._get_bytes_generator_from_args(args)
if start_at_init:
self.start()

@staticmethod
def _get_bytes_generator_from_args(args):
if args.txt_file:
all_bytes = (v.encode() for v in args.txt_file)
elif args.image_zip_file:
zipfile_ = zipfile.ZipFile(args.image_zip_file)
all_bytes = (zipfile_.open(v).read() for v in zipfile_.namelist())
elif args.video_zip_file:
zipfile_ = zipfile.ZipFile(args.video_zip_file)
all_bytes = (zipfile_.open(v).read() for v in zipfile_.namelist())
else:
all_bytes = None
return all_bytes

def start(self):
try:
getattr(self, self.args.mode)()
Expand Down Expand Up @@ -72,18 +87,16 @@ def query_callback(self, req: 'gnes_pb2.Request', resp: 'gnes_pb2.Response'):

@property
def bytes_generator(self) -> Generator[bytes, None, None]:
if self.args.txt_file:
all_bytes = (v.encode() for v in self.args.txt_file)
elif self.args.image_zip_file:
zipfile_ = zipfile.ZipFile(self.args.image_zip_file)
all_bytes = (zipfile_.open(v).read() for v in zipfile_.namelist())
elif self.args.video_zip_file:
zipfile_ = zipfile.ZipFile(self.args.video_zip_file)
all_bytes = (zipfile_.open(v).read() for v in zipfile_.namelist())
if self.bytes_generator:
return self._bytes_generator
else:
raise AttributeError('--txt_file, --image_zip_file, --video_zip_file one must be given')
raise ValueError('bytes_generator is empty or not set')

return all_bytes
@bytes_generator.setter
def bytes_generator(self, bytes_gen: Generator[bytes, None, None]):
if self._bytes_generator:
self.logger.warning('bytes_generator is not empty, overrided')
self._bytes_generator = bytes_gen


class ProgressBar:
Expand Down
25 changes: 14 additions & 11 deletions gnes/flow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import OrderedDict, defaultdict
from contextlib import ExitStack
from functools import wraps
from typing import Union, Tuple, List, Optional
from typing import Union, Tuple, List, Optional, Generator

from ..cli.parser import set_router_parser, set_indexer_parser, \
set_frontend_parser, set_preprocessor_parser, \
Expand Down Expand Up @@ -63,7 +63,6 @@ class BuildLevel(BetterEnum):
EMPTY = 0
GRAPH = 1
RUNTIME = 2
RUNTIME_WITH_CLIENT = 3

def __init__(self, with_frontend: bool = True, **kwargs):
self.logger = set_logger(self.__class__.__name__)
Expand Down Expand Up @@ -127,21 +126,25 @@ def to_mermaid(self, left_right: bool = True):
'copy-paste the output and visualize it with: https://mermaidjs.github.io/mermaid-live-editor/')
return mermaid_str

def train(self, **kwargs):
self._call_client(mode='train', **kwargs)
def train(self, bytes_gen: Generator[bytes, None, None] = None, **kwargs):
self._call_client(bytes_gen, mode='train', **kwargs)

def index(self, **kwargs):
self._call_client(mode='index', **kwargs)
def index(self, bytes_gen: Generator[bytes, None, None] = None, **kwargs):
self._call_client(bytes_gen, mode='index', **kwargs)

def query(self, **kwargs):
self._call_client(mode='query', **kwargs)
def query(self, bytes_gen: Generator[bytes, None, None] = None, **kwargs):
self._call_client(bytes_gen, mode='query', **kwargs)

@_build_level(BuildLevel.RUNTIME)
def _call_client(self, **kwargs):
def _call_client(self, bytes_gen: Generator[bytes, None, None] = None, **kwargs):
args, p_args = self._get_parsed_args(set_client_cli_parser(), kwargs)
p_args.grpc_port = self._service_nodes[self._frontend]['parsed_args'].grpc_port
p_args.grpc_host = self._service_nodes[self._frontend]['parsed_args'].grpc_host
CLIClient(p_args)
c = CLIClient(p_args, start_at_init=False)
if bytes_gen:
c.bytes_generator = bytes_gen
with c:
pass

def add_frontend(self, *args, **kwargs) -> 'Flow':
"""Add a frontend to the current flow, a shortcut of add(Service.Frontend)
Expand Down Expand Up @@ -383,4 +386,4 @@ def __enter__(self):

def __exit__(self, exc_type, exc_val, exc_tb):
if hasattr(self, '_service_stack'):
self._service_stack.close()
self._service_stack.close()
7 changes: 4 additions & 3 deletions tests/test_gnes_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@ def test_flow1_ctx_empty(self):
pass

def test_flow1_ctx(self):
flow = (Flow(check_version=False, route_table=True)
flow = (Flow(check_version=False, route_table=False)
.add(gfs.Router, yaml_path='BaseRouter'))
with flow(backend='process') as f:
# CLIClient(self.index_args)
with flow(backend='process') as f, open(self.test_file) as fp:
f.index(txt_file=self.test_file, batch_size=4)
f.index(bytes_gen=(v.encode() for v in fp), batch_size=4)
f.train(txt_file=self.test_file, batch_size=4)

def test_flow2(self):
f = (Flow(check_version=False, route_table=True)
Expand Down

0 comments on commit 990593b

Please sign in to comment.