From b94490f17baf78c871478b2f7f68bffc28bf9393 Mon Sep 17 00:00:00 2001 From: hanhxiao Date: Mon, 14 Oct 2019 11:42:55 +0800 Subject: [PATCH] feat(flow): allow add service to be str --- gnes/flow/__init__.py | 7 +++++-- gnes/service/base.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/gnes/flow/__init__.py b/gnes/flow/__init__.py index 6572feb6..4dd41bf9 100644 --- a/gnes/flow/__init__.py +++ b/gnes/flow/__init__.py @@ -320,7 +320,7 @@ def add_router(self, *args, **kwargs) -> 'Flow': """Add a router to the current flow, a shortcut of :py:meth:`add(Service.Router)`""" return self.add(Service.Router, *args, **kwargs) - def add(self, service: 'Service', + def add(self, service: Union['Service', str], name: str = None, service_in: Union[str, Tuple[str], List[str], 'Service'] = None, service_out: Union[str, Tuple[str], List[str], 'Service'] = None, @@ -329,7 +329,7 @@ def add(self, service: 'Service', """ Add a service to the current flow object and return the new modified flow object - :param service: a 'Service' enum, possible choices: Encoder, Router, Preprocessor, Indexer, Frontend + :param service: a 'Service' enum or string, possible choices: Encoder, Router, Preprocessor, Indexer, Frontend :param name: the name indentifier of the service, useful in 'service_in' and 'service_out' :param service_in: the name of the service(s) that this service receives data from. One can also use 'Service.Frontend' to indicate the connection with the frontend. @@ -342,6 +342,9 @@ def add(self, service: 'Service', op_flow = copy.deepcopy(self) if copy_flow else self + if isinstance(service, str): + service = Service.from_string(service) + if service not in Flow._service2parser: raise ValueError('service: %s is not supported, should be one of %s' % (service, Flow._service2parser)) diff --git a/gnes/service/base.py b/gnes/service/base.py index 7da06d83..c9dcec75 100644 --- a/gnes/service/base.py +++ b/gnes/service/base.py @@ -42,7 +42,7 @@ def from_string(cls, s): try: return cls[s] except KeyError: - raise ValueError() + raise ValueError('%s is not a valid enum for %s' % (s, cls)) class ReduceOp(BetterEnum):