diff --git a/gnes/flow/__init__.py b/gnes/flow/__init__.py
index 4dd41bf9..cf4f09a3 100644
--- a/gnes/flow/__init__.py
+++ b/gnes/flow/__init__.py
@@ -33,6 +33,10 @@ class FlowTopologyError(ValueError):
     """Exception when the topology is ambiguous"""
 
 
+class FlowMissingNode(ValueError):
+    """Exception when a referenced service node can not be found in the flow"""
+
+
 class FlowBuildLevelMismatch(ValueError):
     """Exception when required level is higher than the current build level"""
 
@@ -113,7 +117,7 @@ def __init__(self, with_frontend: bool = True, **kwargs):
         self._service_edges = {}
         self._service_name_counter = {k: 0 for k in Flow._service2parser.keys()}
         self._service_contexts = []
-        self._last_add_service = None
+        self._last_changed_service = []
         self._common_kwargs = kwargs
         self._frontend = None
         self._client = None
@@ -320,6 +324,140 @@ def add_router(self, *args, **kwargs) -> 'Flow':
         """Add a router to the current flow, a shortcut of :py:meth:`add(Service.Router)`"""
         return self.add(Service.Router, *args, **kwargs)
 
+    def set_last_service(self, name: str, copy_flow: bool = True) -> 'Flow':
+        """
+        Set a service as the last service in the flow, useful when modifying the flow.
+
+        :param name: the name of the existing service
+        :param copy_flow: when set to true, then always copy the current flow and do the modification on top of it then return, otherwise, do in-line modification
+        :return: a (new) flow object with modification
+        """
+        op_flow = copy.deepcopy(self) if copy_flow else self
+
+        if name not in op_flow._service_nodes:
+            raise FlowMissingNode('service: %s can not be found in this Flow' % name)
+
+        if op_flow._last_changed_service and name == op_flow._last_changed_service[-1]:
+            pass
+        else:
+            op_flow._last_changed_service.append(name)
+
+        # graph is now changed so we need to
+        # reset the build level to the lowest
+        op_flow._build_level = Flow.BuildLevel.EMPTY
+
+        return op_flow
+
+    def set(self, name: str, service_in: Union[str, Tuple[str], List[str], 'Service'] = None,
+            service_out: Union[str, Tuple[str], List[str], 'Service'] = None,
+            copy_flow: bool = True,
+            clear_old_attr: bool = False,
+            as_last_service: bool = False,
+            **kwargs) -> 'Flow':
+        """
+        Set the attribute of an existing service (added by :py:meth:`add`) in the flow.
+        For the attributes or kwargs that aren't given, they will remain unchanged as before.
+
+        :param name: the name of the existing service
+        :param service_in: the name of the service(s) that this service receives data from.
+                           One can also use 'Service.Frontend' to indicate the connection with the frontend.
+        :param service_out: the name of the service(s) that this service sends data to.
+                            One can also use 'Service.Frontend' to indicate the connection with the frontend.
+        :param copy_flow: when set to true, then always copy the current flow and do the modification on top of it then return, otherwise, do in-line modification
+        :param clear_old_attr: replace the old attribute values instead of merging the new ones into them
+        :param as_last_service: whether setting the changed service as the last service in the graph
+        :param kwargs: other keyword-value arguments that the service CLI supports
+        :return: a (new) flow object with modification
+        """
+        op_flow = copy.deepcopy(self) if copy_flow else self
+
+        if name not in op_flow._service_nodes:
+            raise FlowMissingNode('service: %s can not be found in this Flow' % name)
+
+        node = op_flow._service_nodes[name]
+        service = node['service']
+
+        if service_in:
+            service_in = op_flow._parse_service_endpoints(op_flow, name, service_in, connect_to_last_service=True)
+
+            if clear_old_attr:
+                node['incomes'] = service_in
+                # remove all edges pointing to this service
+                for n in op_flow._service_nodes.values():
+                    if name in n['outgoings']:
+                        n['outgoings'].remove(name)
+            else:
+                node['incomes'] = node['incomes'].union(service_in)
+
+            # register this service as an outgoing of every (new) income
+            for s in service_in:
+                op_flow._service_nodes[s]['outgoings'].add(name)
+
+        if service_out:
+            service_out = op_flow._parse_service_endpoints(op_flow, name, service_out, connect_to_last_service=False)
+            if clear_old_attr:
+                node['outgoings'] = service_out
+                # remove all edges this service pointed to
+                for n in op_flow._service_nodes.values():
+                    if name in n['incomes']:
+                        n['incomes'].remove(name)
+            else:
+                node['outgoings'] = node['outgoings'].union(service_out)
+
+            for s in service_out:
+                op_flow._service_nodes[s]['incomes'].add(name)
+
+        if kwargs:
+            if not clear_old_attr:
+                kwargs = dict(node['kwargs'], **kwargs)
+            node['kwargs'] = kwargs  # always remember the effective kwargs for later set() calls
+            args, p_args = op_flow._get_parsed_args(op_flow, Flow._service2parser[service], kwargs)
+            node['args'] = args
+            node['parsed_args'] = p_args
+
+        if as_last_service:
+            op_flow.set_last_service(name, False)
+
+        # graph is now changed so we need to
+        # reset the build level to the lowest
+        op_flow._build_level = Flow.BuildLevel.EMPTY
+
+        return op_flow
+
+    def remove(self, name: str = None, copy_flow: bool = True) -> 'Flow':
+        """
+        Remove a service from the flow.
+
+        :param name: the name of the existing service
+        :param copy_flow: when set to true, then always copy the current flow and do the modification on top of it then return, otherwise, do in-line modification
+        :return: a (new) flow object with modification
+        """
+
+        op_flow = copy.deepcopy(self) if copy_flow else self
+
+        if name not in op_flow._service_nodes:
+            raise FlowMissingNode('service: %s can not be found in this Flow' % name)
+
+        op_flow._service_nodes.pop(name)
+
+        # remove all edges pointing to or coming from this service
+        for n in op_flow._service_nodes.values():
+            if name in n['outgoings']:
+                n['outgoings'].remove(name)
+            if name in n['incomes']:
+                n['incomes'].remove(name)
+
+        if op_flow._service_nodes:
+            op_flow._last_changed_service = [v for v in op_flow._last_changed_service if v != name]
+        else:
+            op_flow._last_changed_service = []
+
+        # graph is now changed so we need to
+        # reset the build level to the lowest
+        op_flow._build_level = Flow.BuildLevel.EMPTY
+
+        return op_flow
+
     def add(self, service: Union['Service', str],
             name: str = None,
             service_in: Union[str, Tuple[str], List[str], 'Service'] = None,
@@ -327,10 +465,12 @@ def add(self, service: Union['Service', str],
             copy_flow: bool = True,
             **kwargs) -> 'Flow':
         """
-        Add a service to the current flow object and return the new modified flow object
+        Add a service to the current flow object and return the new modified flow object.
+        The attribute of the service can be later changed with :py:meth:`set` or deleted with :py:meth:`remove`
 
         :param service: a 'Service' enum or string, possible choices: Encoder, Router, Preprocessor, Indexer, Frontend
-        :param name: the name indentifier of the service, useful in 'service_in' and 'service_out'
+        :param name: the name identifier of the service, can be used in 'service_in',
+                    'service_out', :py:meth:`set` and :py:meth:`remove`.
         :param service_in: the name of the service(s) that this service receives data from.
                            One can also use 'Service.Frontend' to indicate the connection with the frontend.
         :param service_out: the name of the service(s) that this service sends data to.
@@ -371,7 +511,8 @@ def add(self, service: Union['Service', str],
             'parsed_args': p_args,
             'args': args,
             'incomes': service_in,
-            'outgoings': service_out}
+            'outgoings': service_out,
+            'kwargs': kwargs}
 
         # direct all income services' output to the current service
         for s in service_in:
@@ -379,7 +520,7 @@ def add(self, service: Union['Service', str],
         for s in service_out:
             op_flow._service_nodes[s]['incomes'].add(name)
 
-        op_flow._last_add_service = name
+        op_flow.set_last_service(name, False)
 
         # graph is now changed so we need to
         # reset the build level to the lowest
@@ -395,8 +536,8 @@ def _parse_service_endpoints(op_flow, cur_service_name, service_endpoint, connec
         elif service_endpoint == Service.Frontend:
             service_endpoint = [op_flow._frontend]
         elif not service_endpoint:
-            if op_flow._last_add_service and connect_to_last_service:
-                service_endpoint = [op_flow._last_add_service]
+            if op_flow._last_changed_service and connect_to_last_service:
+                service_endpoint = [op_flow._last_changed_service[-1]]
             else:
                 service_endpoint = []
         if isinstance(service_endpoint, list) or isinstance(service_endpoint, tuple):
@@ -404,7 +545,7 @@ def _parse_service_endpoints(op_flow, cur_service_name, service_endpoint, connec
                 if s == cur_service_name:
                     raise FlowTopologyError('the income of a service can not be itself')
                 if s not in op_flow._service_nodes:
-                    raise FlowTopologyError('service_in: %s can not be found in this Flow' % s)
+                    raise FlowMissingNode('service_in: %s can not be found in this Flow' % s)
         else:
             raise ValueError('service_in=%s is not parsable' % service_endpoint)
         return set(service_endpoint)
@@ -444,11 +585,11 @@ def _build_graph(self, copy_flow: bool) -> 'Flow':
         if not op_flow._frontend:
             raise FlowImcompleteError('frontend does not exist, you may need to add_frontend()')
 
-        if not op_flow._last_add_service or not op_flow._service_nodes:
+        if not op_flow._last_changed_service or not op_flow._service_nodes:
             raise FlowTopologyError('flow is empty?')
 
         # close the loop
-        op_flow._service_nodes[op_flow._frontend]['incomes'].add(op_flow._last_add_service)
+        op_flow._service_nodes[op_flow._frontend]['incomes'].add(op_flow._last_changed_service[-1])
 
         # build all edges
         for k, v in op_flow._service_nodes.items():
diff --git a/tests/test_gnes_flow.py b/tests/test_gnes_flow.py
index c7d962a7..3f8f63ca 100644
--- a/tests/test_gnes_flow.py
+++ b/tests/test_gnes_flow.py
@@ -176,3 +176,39 @@ def test_query_flow_plot(self):
                 .add(gfs.Router, name='scorer', yaml_path=os.path.join(self.dirname, 'yaml/flow-score.yml'))
                 .add(gfs.Indexer, name='doc_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-dictindex.yml')))
         print(flow.build(backend=None).to_url())
+
+    def test_flow_add_set(self):
+        f = (Flow(check_version=False, route_table=True)
+             .add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor', replicas=4)
+             .add(gfs.Encoder, yaml_path='PyTorchTransformers', replicas=3)
+             .add(gfs.Indexer, name='vec_idx', yaml_path='NumpyIndexer', replicas=2)
+             .add(gfs.Indexer, name='doc_idx', yaml_path='DictIndexer', service_in='prep', replicas=2)
+             .add(gfs.Router, name='sync_barrier', yaml_path='BaseReduceRouter',
+                  num_part=2, service_in=['vec_idx', 'doc_idx'])
+             .build(backend=None))
+
+        print(f.to_url())
+        print(f.set('prep', replicas=1).build(backend=None).to_url())
+        # make it a query flow
+
+        f1 = (f
+              .remove('sync_barrier')
+              .remove('doc_idx')
+              .set_last_service('vec_idx')
+              .add_router('scorer', yaml_path=os.path.join(self.dirname, 'yaml/flow-score.yml'))
+              .add_indexer('doc_idx', yaml_path='DictIndexer', replicas=2)
+              .build(backend=None))
+
+        print(f1.to_url())
+
+        # another way to convert f to a query flow
+
+        f2 = (f
+              .set_last_service('vec_idx')
+              .add_router('scorer', yaml_path=os.path.join(self.dirname, 'yaml/flow-score.yml'))
+              .set('doc_idx', service_in='scorer', yaml_path='DictIndexer', replicas=2, clear_old_attr=True)
+              .remove('sync_barrier')
+              .set_last_service('doc_idx')
+              .build(backend=None))
+
+        print(f2.to_url())