diff --git a/gnes/client/cli.py b/gnes/client/cli.py index ebee67d7..6e701416 100644 --- a/gnes/client/cli.py +++ b/gnes/client/cli.py @@ -28,9 +28,24 @@ class CLIClient(GrpcClient): def __init__(self, args, start_at_init: bool = True): super().__init__(args) + self._bytes_generator = self._get_bytes_generator_from_args(args) if start_at_init: self.start() + @staticmethod + def _get_bytes_generator_from_args(args): + if args.txt_file: + all_bytes = (v.encode() for v in args.txt_file) + elif args.image_zip_file: + zipfile_ = zipfile.ZipFile(args.image_zip_file) + all_bytes = (zipfile_.open(v).read() for v in zipfile_.namelist()) + elif args.video_zip_file: + zipfile_ = zipfile.ZipFile(args.video_zip_file) + all_bytes = (zipfile_.open(v).read() for v in zipfile_.namelist()) + else: + all_bytes = None + return all_bytes + def start(self): try: getattr(self, self.args.mode)() @@ -72,18 +87,16 @@ def query_callback(self, req: 'gnes_pb2.Request', resp: 'gnes_pb2.Response'): @property def bytes_generator(self) -> Generator[bytes, None, None]: - if self.args.txt_file: - all_bytes = (v.encode() for v in self.args.txt_file) - elif self.args.image_zip_file: - zipfile_ = zipfile.ZipFile(self.args.image_zip_file) - all_bytes = (zipfile_.open(v).read() for v in zipfile_.namelist()) - elif self.args.video_zip_file: - zipfile_ = zipfile.ZipFile(self.args.video_zip_file) - all_bytes = (zipfile_.open(v).read() for v in zipfile_.namelist()) + if self.bytes_generator: + return self._bytes_generator else: - raise AttributeError('--txt_file, --image_zip_file, --video_zip_file one must be given') + raise ValueError('bytes_generator is empty or not set') - return all_bytes + @bytes_generator.setter + def bytes_generator(self, bytes_gen: Generator[bytes, None, None]): + if self._bytes_generator: + self.logger.warning('bytes_generator is not empty, overrided') + self._bytes_generator = bytes_gen class ProgressBar: diff --git a/gnes/flow/__init__.py b/gnes/flow/__init__.py index e9d8ff70..5c15f5ce 100644 --- a/gnes/flow/__init__.py +++ b/gnes/flow/__init__.py @@ -1,7 +1,7 @@ from collections import OrderedDict, defaultdict from contextlib import ExitStack from functools import wraps -from typing import Union, Tuple, List, Optional +from typing import Union, Tuple, List, Optional, Generator from ..cli.parser import set_router_parser, set_indexer_parser, \ set_frontend_parser, set_preprocessor_parser, \ @@ -63,7 +63,6 @@ class BuildLevel(BetterEnum): EMPTY = 0 GRAPH = 1 RUNTIME = 2 - RUNTIME_WITH_CLIENT = 3 def __init__(self, with_frontend: bool = True, **kwargs): self.logger = set_logger(self.__class__.__name__) @@ -127,21 +126,25 @@ def to_mermaid(self, left_right: bool = True): 'copy-paste the output and visualize it with: https://mermaidjs.github.io/mermaid-live-editor/') return mermaid_str - def train(self, **kwargs): - self._call_client(mode='train', **kwargs) + def train(self, bytes_gen: Generator[bytes, None, None] = None, **kwargs): + self._call_client(bytes_gen, mode='train', **kwargs) - def index(self, **kwargs): - self._call_client(mode='index', **kwargs) + def index(self, bytes_gen: Generator[bytes, None, None] = None, **kwargs): + self._call_client(bytes_gen, mode='index', **kwargs) - def query(self, **kwargs): - self._call_client(mode='query', **kwargs) + def query(self, bytes_gen: Generator[bytes, None, None] = None, **kwargs): + self._call_client(bytes_gen, mode='query', **kwargs) @_build_level(BuildLevel.RUNTIME) - def _call_client(self, **kwargs): + def _call_client(self, bytes_gen: Generator[bytes, None, None] = None, **kwargs): args, p_args = self._get_parsed_args(set_client_cli_parser(), kwargs) p_args.grpc_port = self._service_nodes[self._frontend]['parsed_args'].grpc_port p_args.grpc_host = self._service_nodes[self._frontend]['parsed_args'].grpc_host - CLIClient(p_args) + c = CLIClient(p_args, start_at_init=False) + if bytes_gen: + c.bytes_generator = bytes_gen + with c: + pass def add_frontend(self, *args, **kwargs) -> 'Flow': """Add a frontend to the current flow, a shortcut of add(Service.Frontend) @@ -383,4 +386,4 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): if hasattr(self, '_service_stack'): - self._service_stack.close() + self._service_stack.close() \ No newline at end of file diff --git a/tests/test_gnes_flow.py b/tests/test_gnes_flow.py index e1546bbd..ab797fbf 100644 --- a/tests/test_gnes_flow.py +++ b/tests/test_gnes_flow.py @@ -31,11 +31,12 @@ def test_flow1_ctx_empty(self): pass def test_flow1_ctx(self): - flow = (Flow(check_version=False, route_table=True) + flow = (Flow(check_version=False, route_table=False) .add(gfs.Router, yaml_path='BaseRouter')) - with flow(backend='process') as f: - # CLIClient(self.index_args) + with flow(backend='process') as f, open(self.test_file) as fp: f.index(txt_file=self.test_file, batch_size=4) + f.index(bytes_gen=(v.encode() for v in fp), batch_size=4) + f.train(txt_file=self.test_file, batch_size=4) def test_flow2(self): f = (Flow(check_version=False, route_table=True)