diff --git a/gnes/base/__init__.py b/gnes/base/__init__.py index bc92c313..a0043420 100644 --- a/gnes/base/__init__.py +++ b/gnes/base/__init__.py @@ -107,7 +107,7 @@ def register_class(cls): # print('reg class: %s' % cls.__name__) cls.__init__ = TrainableType._store_init_kwargs(cls.__init__) if os.environ.get('GNES_PROFILING', False): - for f_name in ['train', 'encode', 'add', 'query']: + for f_name in ['train', 'encode', 'add', 'query', 'index']: if getattr(cls, f_name, None): setattr(cls, f_name, profiling(getattr(cls, f_name))) diff --git a/gnes/flow/__init__.py b/gnes/flow/__init__.py index 1f5d793a..4dceff59 100644 --- a/gnes/flow/__init__.py +++ b/gnes/flow/__init__.py @@ -1,66 +1,15 @@ import copy from collections import OrderedDict, defaultdict from contextlib import ExitStack -from functools import wraps from typing import Union, Tuple, List, Optional, Iterator -from ..cli.parser import set_router_parser, set_indexer_parser, \ - set_frontend_parser, set_preprocessor_parser, \ - set_encoder_parser, set_client_cli_parser -from ..client.cli import CLIClient +from .helper import * +from ..base import TrainableBase from ..helper import set_logger -from ..service.base import SocketType, BaseService, BetterEnum, ServiceManager -from ..service.encoder import EncoderService -from ..service.frontend import FrontendService -from ..service.indexer import IndexerService -from ..service.preprocessor import PreprocessorService -from ..service.router import RouterService +from ..service.base import SocketType, BaseService -class Service(BetterEnum): - Frontend = 0 - Encoder = 1 - Router = 2 - Indexer = 3 - Preprocessor = 4 - - -class FlowImcompleteError(ValueError): - """Exception when the flow missing some important component to run""" - - -class FlowTopologyError(ValueError): - """Exception when the topology is ambiguous""" - - -class FlowMissingNode(ValueError): - """Exception when the topology is ambiguous""" - - -class FlowBuildLevelMismatch(ValueError): - """Exception when required level is higher than the current build level""" - - -def _build_level(required_level: 'Flow.BuildLevel'): - def __build_level(func): - @wraps(func) - def arg_wrapper(self, *args, **kwargs): - if hasattr(self, '_build_level'): - if self._build_level.value >= required_level.value: - return func(self, *args, **kwargs) - else: - raise FlowBuildLevelMismatch( - 'build_level check failed for %r, required level: %s, actual level: %s' % ( - func, required_level, self._build_level)) - else: - raise AttributeError('%r has no attribute "_build_level"' % self) - - return arg_wrapper - - return __build_level - - -class Flow: +class Flow(TrainableBase): """ GNES Flow: an intuitive way to build workflow for GNES. @@ -91,60 +40,85 @@ class Flow: """ - _service2parser = { - Service.Encoder: set_encoder_parser, - Service.Router: set_router_parser, - Service.Indexer: set_indexer_parser, - Service.Frontend: set_frontend_parser, - Service.Preprocessor: set_preprocessor_parser, - } - _service2builder = { - Service.Encoder: lambda x: ServiceManager(EncoderService, x), - Service.Router: lambda x: ServiceManager(RouterService, x), - Service.Indexer: lambda x: ServiceManager(IndexerService, x), - Service.Preprocessor: lambda x: ServiceManager(PreprocessorService, x), - Service.Frontend: FrontendService, - } - - class BuildLevel(BetterEnum): - EMPTY = 0 - GRAPH = 1 - RUNTIME = 2 - - def __init__(self, with_frontend: bool = True, **kwargs): + def __init__(self, with_frontend: bool = True, is_trained: bool = True, *args, **kwargs): """ Create a new Flow object. :param with_frontend: adding frontend service to the flow + :param is_trained: indicating whether this flow is trained or not. if set to False then :py:meth:`index` + and :py:meth:`query` can not be called before :py:meth:`train` :param kwargs: keyword-value arguments that will be shared by all services """ + super().__init__(*args, **kwargs) self.logger = set_logger(self.__class__.__name__) self._service_nodes = OrderedDict() self._service_edges = {} - self._service_name_counter = {k: 0 for k in Flow._service2parser.keys()} + self._service_name_counter = {k: 0 for k in service_map.keys()} self._service_contexts = [] self._last_changed_service = [] self._common_kwargs = kwargs self._frontend = None self._client = None - self._build_level = Flow.BuildLevel.EMPTY + self._build_level = BuildLevel.EMPTY self._backend = None self._init_with_frontend = False + self.is_trained = is_trained if with_frontend: self.add_frontend(copy_flow=False) self._init_with_frontend = True else: self.logger.warning('with_frontend is set to False, you need to add_frontend() by yourself') - @_build_level(BuildLevel.GRAPH) - def to_swarm_yaml(self) -> str: - swarm_yml = '' - return swarm_yml + @build_required(BuildLevel.GRAPH) + def to_k8s_yaml(self) -> str: + raise NotImplementedError + + @build_required(BuildLevel.GRAPH) + def to_shell_script(self) -> str: + raise NotImplementedError + + @build_required(BuildLevel.GRAPH) + def to_swarm_yaml(self, image: str = 'gnes/gnes:latest-alpine') -> str: + """ + Generate the docker swarm YAML compose file + + :param image: the default GNES docker image + :return: the generated YAML compose file + """ + from ruamel.yaml import YAML, StringIO + _yaml = YAML() + swarm_yml = {'version': '3.4', + 'services': {}} + + for k, v in self._service_nodes.items(): + defaults_kwargs, _ = service_map[v['service']]['parser']().parse_known_args( + ['--yaml_path', 'TrainableBase']) + non_default_kwargs = {k: v for k, v in vars(v['parsed_args']).items() if getattr(defaults_kwargs, k) != v} + if not isinstance(non_default_kwargs.get('yaml_path', ''), str): + non_default_kwargs['yaml_path'] = v['kwargs']['yaml_path'] + + num_replicas = None + if 'num_parallel' in non_default_kwargs: + num_replicas = non_default_kwargs.pop('num_parallel') + + swarm_yml['services'][k] = { + 'image': v['kwargs'].get('image', image), + 'command': '%s %s' % ( + service_map[v['service']]['cmd'], + ' '.join(['--%s %s' % (k, v) for k, v in non_default_kwargs.items()])) + } + if num_replicas and num_replicas > 1: + swarm_yml['services'][k]['deploy'] = {'replicas': num_replicas} + + stream = StringIO() + _yaml.dump(swarm_yml, stream) + return stream.getvalue().strip() def to_python_code(self, indent: int = 4) -> str: """ Generate the python code of this flow + :param indent: the number of whitespaces of indent :return: the generated python code """ py_code = ['from gnes.flow import Flow', ''] @@ -162,11 +136,11 @@ def to_python_code(self, indent: int = 4) -> str: kwargs = OrderedDict() kwargs['service'] = str(v['service']) kwargs['name'] = k - kwargs['service_in'] = '[%s]' % ( + kwargs['recv_from'] = '[%s]' % ( ','.join({'\'%s\'' % k for k in v['incomes'] if k in known_service})) - if kwargs['service_in'] == '[\'%s\']' % last_add_name: - kwargs.pop('service_in') - kwargs['service_out'] = '[%s]' % (','.join({'\'%s\'' % k for k in v['outgoings'] if k in known_service})) + if kwargs['recv_from'] == '[\'%s\']' % last_add_name: + kwargs.pop('recv_from') + kwargs['send_to'] = '[%s]' % (','.join({'\'%s\'' % k for k in v['outgoings'] if k in known_service})) known_service.add(k) last_add_name = k @@ -196,7 +170,7 @@ def to_python_code(self, indent: int = 4) -> str: return '\n'.join(py_code) - @_build_level(BuildLevel.GRAPH) + @build_required(BuildLevel.GRAPH) def to_mermaid(self, left_right: bool = True) -> str: """ Output the mermaid graph for visualization @@ -284,7 +258,7 @@ def to_mermaid(self, left_right: bool = True) -> str: return mermaid_str - @_build_level(BuildLevel.GRAPH) + @build_required(BuildLevel.GRAPH) def to_url(self, **kwargs) -> str: """ Rendering the current flow as a url points to a SVG, it needs internet connection @@ -297,7 +271,7 @@ def to_url(self, **kwargs) -> str: encoded_str = base64.b64encode(bytes(mermaid_str, 'utf-8')).decode('utf-8') return 'https://mermaidjs.github.io/mermaid-live-editor/#/view/%s' % encoded_str - @_build_level(BuildLevel.GRAPH) + @build_required(BuildLevel.GRAPH) def to_jpg(self, path: str = 'flow.jpg', **kwargs) -> None: """ Rendering the current flow as a jpg image, this will call :py:meth:`to_mermaid` and it needs internet connection @@ -347,9 +321,12 @@ def query(self, bytes_gen: Iterator[bytes] = None, **kwargs): """ self._call_client(bytes_gen, mode='query', **kwargs) - @_build_level(BuildLevel.RUNTIME) + @build_required(BuildLevel.RUNTIME) def _call_client(self, bytes_gen: Iterator[bytes] = None, **kwargs): - args, p_args = self._get_parsed_args(self, set_client_cli_parser, kwargs) + from ..cli.parser import set_client_cli_parser + from ..client.cli import CLIClient + + args, p_args, unk_args = self._get_parsed_args(self, set_client_cli_parser, kwargs) p_args.grpc_port = self._service_nodes[self._frontend]['parsed_args'].grpc_port p_args.grpc_host = self._service_nodes[self._frontend]['parsed_args'].grpc_host c = CLIClient(p_args, start_at_init=False) @@ -391,7 +368,7 @@ def set_last_service(self, name: str, copy_flow: bool = True) -> 'Flow': op_flow = copy.deepcopy(self) if copy_flow else self if name not in op_flow._service_nodes: - raise FlowMissingNode('service_in: %s can not be found in this Flow' % name) + raise FlowMissingNode('recv_from: %s can not be found in this Flow' % name) if op_flow._last_changed_service and name == op_flow._last_changed_service[-1]: pass @@ -400,12 +377,12 @@ def set_last_service(self, name: str, copy_flow: bool = True) -> 'Flow': # graph is now changed so we need to # reset the build level to the lowest - op_flow._build_level = Flow.BuildLevel.EMPTY + op_flow._build_level = BuildLevel.EMPTY return op_flow - def set(self, name: str, service_in: Union[str, Tuple[str], List[str], 'Service'] = None, - service_out: Union[str, Tuple[str], List[str], 'Service'] = None, + def set(self, name: str, recv_from: Union[str, Tuple[str], List[str], 'Service'] = None, + send_to: Union[str, Tuple[str], List[str], 'Service'] = None, copy_flow: bool = True, clear_old_attr: bool = False, as_last_service: bool = False, @@ -415,9 +392,9 @@ def set(self, name: str, service_in: Union[str, Tuple[str], List[str], 'Service' For the attributes or kwargs that aren't given, they will remain unchanged as before. :param name: the name of the existing service - :param service_in: the name of the service(s) that this service receives data from. + :param recv_from: the name of the service(s) that this service receives data from. One can also use 'Service.Frontend' to indicate the connection with the frontend. - :param service_out: the name of the service(s) that this service sends data to. + :param send_to: the name of the service(s) that this service sends data to. One can also use 'Service.Frontend' to indicate the connection with the frontend. :param copy_flow: when set to true, then always copy the current flow and do the modification on top of it then return, otherwise, do in-line modification :param clear_old_attr: remove old attribute value before setting the new one @@ -428,56 +405,59 @@ def set(self, name: str, service_in: Union[str, Tuple[str], List[str], 'Service' op_flow = copy.deepcopy(self) if copy_flow else self if name not in op_flow._service_nodes: - raise FlowMissingNode('service_in: %s can not be found in this Flow' % name) + raise FlowMissingNode('recv_from: %s can not be found in this Flow' % name) node = op_flow._service_nodes[name] service = node['service'] - if service_in: - service_in = op_flow._parse_service_endpoints(op_flow, name, service_in, connect_to_last_service=True) + if recv_from: + recv_from = op_flow._parse_service_endpoints(op_flow, name, recv_from, connect_to_last_service=True) if clear_old_attr: - node['incomes'] = service_in + node['incomes'] = recv_from # remove all edges point to this service for n in op_flow._service_nodes.values(): if name in n['outgoings']: n['outgoings'].remove(name) else: - node['incomes'] = node['incomes'].union(service_in) + node['incomes'] = node['incomes'].union(recv_from) # add it the new edge back - for s in service_in: + for s in recv_from: op_flow._service_nodes[s]['outgoings'].add(name) - if service_out: - service_out = op_flow._parse_service_endpoints(op_flow, name, service_out, connect_to_last_service=False) - node['outgoings'] = service_out + if send_to: + send_to = op_flow._parse_service_endpoints(op_flow, name, send_to, connect_to_last_service=False) + node['outgoings'] = send_to if clear_old_attr: # remove all edges this service point to for n in op_flow._service_nodes.values(): if name in n['incomes']: n['incomes'].remove(name) else: - node['outgoings'] = node['outgoings'].union(service_out) + node['outgoings'] = node['outgoings'].union(send_to) - for s in service_out: + for s in send_to: op_flow._service_nodes[s]['incomes'].add(name) if kwargs: if not clear_old_attr: node['kwargs'].update(kwargs) kwargs = node['kwargs'] - args, p_args = op_flow._get_parsed_args(op_flow, Flow._service2parser[service], kwargs) - node['args'] = args - node['parsed_args'] = p_args - node['kwargs'] = kwargs + args, p_args, unk_args = op_flow._get_parsed_args(op_flow, service_map[service]['parser'], kwargs) + node.update({ + 'args': args, + 'parsed_args': p_args, + 'kwargs': kwargs, + 'unk_args': unk_args + }) if as_last_service: op_flow.set_last_service(name, False) # graph is now changed so we need to # reset the build level to the lowest - op_flow._build_level = Flow.BuildLevel.EMPTY + op_flow._build_level = BuildLevel.EMPTY return op_flow @@ -493,7 +473,7 @@ def remove(self, name: str = None, copy_flow: bool = True) -> 'Flow': op_flow = copy.deepcopy(self) if copy_flow else self if name not in op_flow._service_nodes: - raise FlowMissingNode('service_in: %s can not be found in this Flow' % name) + raise FlowMissingNode('recv_from: %s can not be found in this Flow' % name) op_flow._service_nodes.pop(name) @@ -511,14 +491,14 @@ def remove(self, name: str = None, copy_flow: bool = True) -> 'Flow': # graph is now changed so we need to # reset the build level to the lowest - op_flow._build_level = Flow.BuildLevel.EMPTY + op_flow._build_level = BuildLevel.EMPTY return op_flow def add(self, service: Union['Service', str], name: str = None, - service_in: Union[str, Tuple[str], List[str], 'Service'] = None, - service_out: Union[str, Tuple[str], List[str], 'Service'] = None, + recv_from: Union[str, Tuple[str], List[str], 'Service'] = None, + send_to: Union[str, Tuple[str], List[str], 'Service'] = None, copy_flow: bool = True, **kwargs) -> 'Flow': """ @@ -526,11 +506,11 @@ def add(self, service: Union['Service', str], The attribute of the service can be later changed with :py:meth:`set` or deleted with :py:meth:`remove` :param service: a 'Service' enum or string, possible choices: Encoder, Router, Preprocessor, Indexer, Frontend - :param name: the name identifier of the service, can be used in 'service_in', - 'service_out', :py:meth:`set` and :py:meth:`remove`. - :param service_in: the name of the service(s) that this service receives data from. + :param name: the name identifier of the service, can be used in 'recv_from', + 'send_to', :py:meth:`set` and :py:meth:`remove`. + :param recv_from: the name of the service(s) that this service receives data from. One can also use 'Service.Frontend' to indicate the connection with the frontend. - :param service_out: the name of the service(s) that this service sends data to. + :param send_to: the name of the service(s) that this service sends data to. One can also use 'Service.Frontend' to indicate the connection with the frontend. :param copy_flow: when set to true, then always copy the current flow and do the modification on top of it then return, otherwise, do in-line modification :param kwargs: other keyword-value arguments that the service CLI supports @@ -542,8 +522,8 @@ def add(self, service: Union['Service', str], if isinstance(service, str): service = Service.from_string(service) - if service not in Flow._service2parser: - raise ValueError('service: %s is not supported, should be one of %s' % (service, Flow._service2parser)) + if service not in service_map: + raise ValueError('service: %s is not supported, should be one of %s' % (service, service_map.keys())) if name in op_flow._service_nodes: raise FlowTopologyError('name: %s is used in this Flow already!' % name) @@ -558,36 +538,38 @@ def add(self, service: Union['Service', str], raise FlowTopologyError('frontend is already in this Flow') op_flow._frontend = name - service_in = op_flow._parse_service_endpoints(op_flow, name, service_in, connect_to_last_service=True) - service_out = op_flow._parse_service_endpoints(op_flow, name, service_out, connect_to_last_service=False) + recv_from = op_flow._parse_service_endpoints(op_flow, name, recv_from, connect_to_last_service=True) + send_to = op_flow._parse_service_endpoints(op_flow, name, send_to, connect_to_last_service=False) - args, p_args = op_flow._get_parsed_args(op_flow, Flow._service2parser[service], kwargs) + args, p_args, unk_args = op_flow._get_parsed_args(op_flow, service_map[service]['parser'], kwargs) op_flow._service_nodes[name] = { 'service': service, 'parsed_args': p_args, 'args': args, - 'incomes': service_in, - 'outgoings': service_out, - 'kwargs': kwargs} + 'incomes': recv_from, + 'outgoings': send_to, + 'kwargs': kwargs, + 'unk_args': unk_args + } # direct all income services' output to the current service - for s in service_in: + for s in recv_from: op_flow._service_nodes[s]['outgoings'].add(name) - for s in service_out: + for s in send_to: op_flow._service_nodes[s]['incomes'].add(name) op_flow.set_last_service(name, False) # graph is now changed so we need to # reset the build level to the lowest - op_flow._build_level = Flow.BuildLevel.EMPTY + op_flow._build_level = BuildLevel.EMPTY return op_flow @staticmethod def _parse_service_endpoints(op_flow, cur_service_name, service_endpoint, connect_to_last_service=False): - # parsing service_in + # parsing recv_from if isinstance(service_endpoint, str): service_endpoint = [service_endpoint] elif service_endpoint == Service.Frontend: @@ -602,9 +584,9 @@ def _parse_service_endpoints(op_flow, cur_service_name, service_endpoint, connec if s == cur_service_name: raise FlowTopologyError('the income of a service can not be itself') if s not in op_flow._service_nodes: - raise FlowMissingNode('service_in: %s can not be found in this Flow' % s) + raise FlowMissingNode('recv_from: %s can not be found in this Flow' % s) else: - raise ValueError('service_in=%s is not parsable' % service_endpoint) + raise ValueError('recv_from=%s is not parsable' % service_endpoint) return set(service_endpoint) @staticmethod @@ -632,7 +614,7 @@ def _get_parsed_args(op_flow, service_arg_parser, kwargs): except SystemExit: raise ValueError('bad arguments for service "%s", ' 'you may want to double check your args "%s"' % (service_arg_parser, args)) - return args, p_args + return args, p_args, unknown_args def _build_graph(self, copy_flow: bool) -> 'Flow': op_flow = copy.deepcopy(self) if copy_flow else self @@ -640,7 +622,7 @@ def _build_graph(self, copy_flow: bool) -> 'Flow': op_flow._service_edges.clear() if not op_flow._frontend: - raise FlowImcompleteError('frontend does not exist, you may need to add_frontend()') + raise FlowIncompleteError('frontend does not exist, you may need to add_frontend()') if not op_flow._last_changed_service or not op_flow._service_nodes: raise FlowTopologyError('flow is empty?') @@ -717,7 +699,7 @@ def _build_graph(self, copy_flow: bool) -> 'Flow': 'i can not determine the socket type' % ( len(edges_with_same_start), start_node, len(edges_with_same_end), end_node)) - op_flow._build_level = Flow.BuildLevel.GRAPH + op_flow._build_level = BuildLevel.GRAPH return op_flow def build(self, backend: Optional[str] = 'thread', copy_flow: bool = False, *args, **kwargs) -> 'Flow': @@ -742,8 +724,8 @@ def build(self, backend: Optional[str] = 'thread', copy_flow: bool = False, *arg # for thread and process backend which runs locally, host_in and host_out should not be set p_args.host_in = BaseService.default_host p_args.host_out = BaseService.default_host - op_flow._service_contexts.append((Flow._service2builder[v['service']], p_args)) - op_flow._build_level = Flow.BuildLevel.RUNTIME + op_flow._service_contexts.append((service_map[v['service']]['builder'], p_args)) + op_flow._build_level = BuildLevel.RUNTIME else: raise NotImplementedError('backend=%s is not supported yet' % backend) @@ -753,7 +735,7 @@ def __call__(self, *args, **kwargs): return self.build(*args, **kwargs) def __enter__(self): - if self._build_level.value < Flow.BuildLevel.RUNTIME.value: + if self._build_level.value < BuildLevel.RUNTIME.value: self.logger.warning( 'current build_level=%s, lower than required. ' 'build the flow now via build() with default parameters' % self._build_level) @@ -765,25 +747,13 @@ def __enter__(self): self.logger.critical('flow is built and ready, current build level is %s' % self._build_level) return self - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - def close(self): if hasattr(self, '_service_stack'): self._service_stack.close() - self._build_level = Flow.BuildLevel.EMPTY + self._build_level = BuildLevel.EMPTY self.logger.critical( 'flow is closed and all resources should be released already, current build level is %s' % self._build_level) - def __getstate__(self): - d = dict(self.__dict__) - del d['logger'] - return d - - def __setstate__(self, d): - self.__dict__.update(d) - self.logger = set_logger(self.__class__.__name__) - def __eq__(self, other): """ Comparing the topology of a flow with another flow. @@ -793,12 +763,12 @@ def __eq__(self, other): :return: """ - if self._build_level.value < Flow.BuildLevel.GRAPH.value: + if self._build_level.value < BuildLevel.GRAPH.value: a = self.build(backend=None, copy_flow=True) else: a = self - if other._build_level.value < Flow.BuildLevel.GRAPH.value: + if other._build_level.value < BuildLevel.GRAPH.value: b = other.build(backend=None, copy_flow=True) else: b = other diff --git a/gnes/flow/common.py b/gnes/flow/common.py new file mode 100644 index 00000000..0daaecd1 --- /dev/null +++ b/gnes/flow/common.py @@ -0,0 +1,19 @@ +from . import Flow + +# these base flow define common topology in real world +# however, they are not usable directly unless you set the correct `yaml_path` + +BaseIndexFlow = (Flow() + .add_preprocessor(name='prep', yaml_path='BasePreprocessor') + .add_encoder(name='enc', yaml_path='BaseEncoder') + .add_indexer(name='vec_idx', yaml_path='BaseIndexer') + .add_indexer(name='doc_idx', yaml_path='BaseIndexer', recv_from='prep') + .add_router(name='sync', yaml_path='BaseReduceRouter', + num_part=2, recv_from=['vec_idx', 'doc_idx'])) + +BaseQueryFlow = (Flow() + .add_preprocessor(name='prep', yaml_path='BasePreprocessor') + .add_encoder(name='enc', yaml_path='BaseEncoder') + .add_indexer(name='vec_idx', yaml_path='BaseIndexer') + .add_router(name='scorer', yaml_path='Chunk2DocTopkReducer') + .add_indexer(name='doc_idx', yaml_path='BaseIndexer')) diff --git a/gnes/flow/helper.py b/gnes/flow/helper.py new file mode 100644 index 00000000..3bef25e2 --- /dev/null +++ b/gnes/flow/helper.py @@ -0,0 +1,88 @@ +from functools import wraps + +from ..cli.parser import set_router_parser, set_indexer_parser, \ + set_frontend_parser, set_preprocessor_parser, \ + set_encoder_parser +from ..service.base import BetterEnum, ServiceManager +from ..service.encoder import EncoderService +from ..service.frontend import FrontendService +from ..service.indexer import IndexerService +from ..service.preprocessor import PreprocessorService +from ..service.router import RouterService + + +class BuildLevel(BetterEnum): + EMPTY = 0 + GRAPH = 1 + RUNTIME = 2 + + +class Service(BetterEnum): + Frontend = 0 + Encoder = 1 + Router = 2 + Indexer = 3 + Preprocessor = 4 + + +class FlowIncompleteError(ValueError): + """Exception when the flow missing some important component to run""" + + +class FlowTopologyError(ValueError): + """Exception when the topology is ambiguous""" + + +class FlowMissingNode(ValueError): + """Exception when the topology is ambiguous""" + + +class FlowBuildLevelMismatch(ValueError): + """Exception when required level is higher than the current build level""" + + +def build_required(required_level: 'BuildLevel'): + def __build_level(func): + @wraps(func) + def arg_wrapper(self, *args, **kwargs): + if hasattr(self, '_build_level'): + if self._build_level.value >= required_level.value: + return func(self, *args, **kwargs) + else: + raise FlowBuildLevelMismatch( + 'build_level check failed for %r, required level: %s, actual level: %s' % ( + func, required_level, self._build_level)) + else: + raise AttributeError('%r has no attribute "_build_level"' % self) + + return arg_wrapper + + return __build_level + + +service_map = { + Service.Encoder: { + 'parser': set_encoder_parser, + 'builder': lambda x: ServiceManager(EncoderService, x), + 'cmd': 'encode'}, + Service.Router: { + 'parser': set_router_parser, + 'builder': lambda x: ServiceManager(RouterService, x), + 'cmd': 'route', + }, + Service.Indexer: { + 'parser': set_indexer_parser, + 'builder': lambda x: ServiceManager(IndexerService, x), + 'cmd': 'index' + }, + Service.Frontend: { + 'parser': set_frontend_parser, + 'builder': FrontendService, + 'cmd': 'frontend' + }, + Service.Preprocessor: { + 'parser': set_preprocessor_parser, + 'builder': lambda x: ServiceManager(PreprocessorService, x), + 'cmd': 'preprocess' + } +} diff --git a/tests/test_gnes_flow.py b/tests/test_gnes_flow.py index d98ec37f..320662a4 100644 --- a/tests/test_gnes_flow.py +++ b/tests/test_gnes_flow.py @@ -3,6 +3,7 @@ from gnes.cli.parser import set_client_cli_parser from gnes.flow import Flow, Service as gfs, FlowBuildLevelMismatch +from gnes.flow.common import BaseIndexFlow, BaseQueryFlow class TestGNESFlow(unittest.TestCase): @@ -11,6 +12,7 @@ def setUp(self): self.dirname = os.path.dirname(__file__) self.test_file = os.path.join(self.dirname, 'sonnets_small.txt') self.yamldir = os.path.join(self.dirname, 'yaml') + self.dump_flow_path = os.path.join(self.dirname, 'test-flow.bin') self.index_args = set_client_cli_parser().parse_args([ '--mode', 'index', '--txt_file', self.test_file, @@ -29,7 +31,7 @@ def setUp(self): os.environ['TEST_WORKDIR'] = self.test_dir def tearDown(self): - for k in [self.indexer1_bin, self.indexer2_bin, self.encoder_bin]: + for k in [self.indexer1_bin, self.indexer2_bin, self.encoder_bin, self.dump_flow_path]: if os.path.exists(k): os.remove(k) os.rmdir(self.test_dir) @@ -87,8 +89,8 @@ def test_flow2(self): def test_flow3(self): f = (Flow(check_version=False, route_table=True) - .add(gfs.Router, name='r0', service_out=gfs.Frontend, yaml_path='BaseRouter') - .add(gfs.Router, name='r1', service_in=gfs.Frontend, yaml_path='BaseRouter') + .add(gfs.Router, name='r0', send_to=gfs.Frontend, yaml_path='BaseRouter') + .add(gfs.Router, name='r1', recv_from=gfs.Frontend, yaml_path='BaseRouter') .build(backend=None)) print(f._service_edges) print(f.to_mermaid()) @@ -96,8 +98,8 @@ def test_flow3(self): def test_flow4(self): f = (Flow(check_version=False, route_table=True) .add(gfs.Router, name='r0', yaml_path='BaseRouter') - .add(gfs.Router, name='r1', service_in=gfs.Frontend, yaml_path='BaseRouter') - .add(gfs.Router, name='reduce', service_in=['r0', 'r1'], yaml_path='BaseRouter') + .add(gfs.Router, name='r1', recv_from=gfs.Frontend, yaml_path='BaseRouter') + .add(gfs.Router, name='reduce', recv_from=['r0', 'r1'], yaml_path='BaseRouter') .build(backend=None)) print(f._service_edges) print(f.to_mermaid()) @@ -107,22 +109,22 @@ def test_flow5(self): .add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor') .add(gfs.Encoder, yaml_path='PyTorchTransformers') .add(gfs.Indexer, name='vec_idx', yaml_path='NumpyIndexer') - .add(gfs.Indexer, name='doc_idx', yaml_path='DictIndexer', service_in='prep') + .add(gfs.Indexer, name='doc_idx', yaml_path='DictIndexer', recv_from='prep') .add(gfs.Router, name='sync_barrier', yaml_path='BaseReduceRouter', - num_part=2, service_in=['vec_idx', 'doc_idx']) + num_part=2, recv_from=['vec_idx', 'doc_idx']) .build(backend=None)) print(f._service_edges) print(f.to_mermaid()) - f.to_jpg() + # f.to_jpg() def test_flow_replica_pot(self): f = (Flow(check_version=False, route_table=True) .add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor', replicas=4) .add(gfs.Encoder, yaml_path='PyTorchTransformers', replicas=3) .add(gfs.Indexer, name='vec_idx', yaml_path='NumpyIndexer', replicas=2) - .add(gfs.Indexer, name='doc_idx', yaml_path='DictIndexer', service_in='prep', replicas=2) + .add(gfs.Indexer, name='doc_idx', yaml_path='DictIndexer', recv_from='prep', replicas=2) .add(gfs.Router, name='sync_barrier', yaml_path='BaseReduceRouter', - num_part=2, service_in=['vec_idx', 'doc_idx']) + num_part=2, recv_from=['vec_idx', 'doc_idx']) .build(backend=None)) print(f.to_mermaid()) print(f.to_url(left_right=False)) @@ -137,9 +139,9 @@ def _test_index_flow(self, backend): .add(gfs.Encoder, yaml_path=os.path.join(self.dirname, 'yaml/flow-transformer.yml'), replicas=3) .add(gfs.Indexer, name='vec_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-vecindex.yml')) .add(gfs.Indexer, name='doc_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-dictindex.yml'), - service_in='prep') + recv_from='prep') .add(gfs.Router, name='sync_barrier', yaml_path='BaseReduceRouter', - num_part=2, service_in=['vec_idx', 'doc_idx'])) + num_part=2, recv_from=['vec_idx', 'doc_idx'])) with flow.build(backend=backend) as f: f.index(txt_file=self.test_file, batch_size=20) @@ -182,9 +184,9 @@ def test_flow_add_set(self): .add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor', replicas=4) .add(gfs.Encoder, yaml_path='PyTorchTransformers', replicas=3) .add(gfs.Indexer, name='vec_idx', yaml_path='NumpyIndexer', replicas=2) - .add(gfs.Indexer, name='doc_idx', yaml_path='DictIndexer', service_in='prep', replicas=2) + .add(gfs.Indexer, name='doc_idx', yaml_path='DictIndexer', recv_from='prep', replicas=2) .add(gfs.Router, name='sync_barrier', yaml_path='BaseReduceRouter', - num_part=2, service_in=['vec_idx', 'doc_idx']) + num_part=2, recv_from=['vec_idx', 'doc_idx']) .build(backend=None)) print(f.to_url()) @@ -206,7 +208,7 @@ def test_flow_add_set(self): f2 = (f .set_last_service('vec_idx') .add_router('scorer', yaml_path=os.path.join(self.dirname, 'yaml/flow-score.yml')) - .set('doc_idx', service_in='scorer', yaml_path='DictIndexer', replicas=2, clear_old_attr=True) + .set('doc_idx', recv_from='scorer', yaml_path='DictIndexer', replicas=2, clear_old_attr=True) .remove('sync_barrier') .set_last_service('doc_idx') .build(backend=None)) @@ -219,3 +221,13 @@ def test_flow_add_set(self): print(f1.to_python_code()) print(f.to_python_code()) + + f1.dump(self.dump_flow_path) + f3 = Flow.load(self.dump_flow_path) + self.assertEqual(f1, f3) + + print(f1.to_swarm_yaml()) + + def test_common_flow(self): + print(BaseIndexFlow.build(backend=None).to_url()) + print(BaseQueryFlow.build(backend=None).to_url())