diff --git a/gnes/flow/__init__.py b/gnes/flow/__init__.py index b4d584b5..98764ccb 100644 --- a/gnes/flow/__init__.py +++ b/gnes/flow/__init__.py @@ -19,12 +19,12 @@ class Flow(TrainableBase): .. highlight:: python .. code-block:: python - from gnes.flow import Flow, Service as gfs + from gnes.flow import Flow f = (Flow(check_version=False, route_table=True) - .add(gfs.Preprocessor, yaml_path='BasePreprocessor') - .add(gfs.Encoder, yaml_path='BaseEncoder') - .add(gfs.Router, yaml_path='BaseRouter')) + .add_preprocessor(yaml_path='BasePreprocessor') + .add_encoder(yaml_path='BaseEncoder') + .add_router(yaml_path='BaseRouter')) with f.build(backend='thread') as flow: flow.index() @@ -40,6 +40,9 @@ class Flow(TrainableBase): """ + # a shortcut to the service frontend, removing one extra import + Frontend = Service.Frontend + def __init__(self, with_frontend: bool = True, is_trained: bool = True, *args, **kwargs): """ Create a new Flow object. @@ -506,6 +509,10 @@ def add(self, service: Union['Service', str], Add a service to the current flow object and return the new modified flow object. The attribute of the service can be later changed with :py:meth:`set` or deleted with :py:meth:`remove` + Note there are shortcut versions of this method. + Recommend to use :py:meth:`add_encoder`, :py:meth:`add_preprocessor`, + :py:meth:`add_router`, :py:meth:`add_indexer` whenever possible. + :param service: a 'Service' enum or string, possible choices: Encoder, Router, Preprocessor, Indexer, Frontend :param name: the name identifier of the service, can be used in 'recv_from', 'send_to', :py:meth:`set` and :py:meth:`remove`. diff --git a/gnes/flow/base.py b/gnes/flow/base.py new file mode 100644 index 00000000..7ee49d56 --- /dev/null +++ b/gnes/flow/base.py @@ -0,0 +1,36 @@ +from . import Flow + + +class BaseIndexFlow(Flow): + """ + BaseIndexFlow defines a common service pipeline when indexing. + + It can not be directly used as all services are using the base module by default. 
+ You have to use :py:meth:`set` to change the `yaml_path` of each service. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + (self.add_preprocessor(name='prep', yaml_path='BasePreprocessor', copy_flow=False) + .add_encoder(name='enc', yaml_path='BaseEncoder', copy_flow=False) + .add_indexer(name='vec_idx', yaml_path='BaseIndexer', copy_flow=False) + .add_indexer(name='doc_idx', yaml_path='BaseIndexer', recv_from='prep', copy_flow=False) + .add_router(name='sync', yaml_path='BaseReduceRouter', + num_part=2, recv_from=['vec_idx', 'doc_idx'], copy_flow=False)) + + +class BaseQueryFlow(Flow): + """ + BaseQueryFlow defines a common service pipeline when querying. + + It can not be directly used as all services are using the base module by default. + You have to use :py:meth:`set` to change the `yaml_path` of each service. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + (self.add_preprocessor(name='prep', yaml_path='BasePreprocessor', copy_flow=False) + .add_encoder(name='enc', yaml_path='BaseEncoder', copy_flow=False) + .add_indexer(name='vec_idx', yaml_path='BaseIndexer', copy_flow=False) + .add_router(name='scorer', yaml_path='Chunk2DocTopkReducer', copy_flow=False) + .add_indexer(name='doc_idx', yaml_path='BaseIndexer', copy_flow=False)) diff --git a/gnes/flow/common.py b/gnes/flow/common.py deleted file mode 100644 index 0daaecd1..00000000 --- a/gnes/flow/common.py +++ /dev/null @@ -1,19 +0,0 @@ -from . 
import Flow - -# these base flow define common topology in real world -# however, they are not usable directly unless you set the correct `yaml_path` - -BaseIndexFlow = (Flow() - .add_preprocessor(name='prep', yaml_path='BasePreprocessor') - .add_encoder(name='enc', yaml_path='BaseEncoder') - .add_indexer(name='vec_idx', yaml_path='BaseIndexer') - .add_indexer(name='doc_idx', yaml_path='BaseIndexer', recv_from='prep') - .add_router(name='sync', yaml_path='BaseReduceRouter', - num_part=2, recv_from=['vec_idx', 'doc_idx'])) - -BaseQueryFlow = (Flow() - .add_preprocessor(name='prep', yaml_path='BasePreprocessor') - .add_encoder(name='enc', yaml_path='BaseEncoder') - .add_indexer(name='vec_idx', yaml_path='BaseIndexer') - .add_router(name='scorer', yaml_path='Chunk2DocTopkReducer') - .add_indexer(name='doc_idx', yaml_path='BaseIndexer')) diff --git a/tests/test_gnes_flow.py b/tests/test_gnes_flow.py index 320662a4..0701edd2 100644 --- a/tests/test_gnes_flow.py +++ b/tests/test_gnes_flow.py @@ -2,8 +2,8 @@ import unittest from gnes.cli.parser import set_client_cli_parser -from gnes.flow import Flow, Service as gfs, FlowBuildLevelMismatch -from gnes.flow.common import BaseIndexFlow, BaseQueryFlow +from gnes.flow import Flow, FlowBuildLevelMismatch +from gnes.flow.base import BaseIndexFlow, BaseQueryFlow class TestGNESFlow(unittest.TestCase): @@ -38,15 +38,15 @@ def tearDown(self): def test_flow1(self): f = (Flow(check_version=False, route_table=True) - .add(gfs.Router, yaml_path='BaseRouter')) - g = f.add(gfs.Router, yaml_path='BaseRouter') + .add_router(yaml_path='BaseRouter')) + g = f.add_router(yaml_path='BaseRouter') print('f: %r g: %r' % (f, g)) g.build() print(g.to_mermaid()) - f = f.add(gfs.Router, yaml_path='BaseRouter') - g = g.add(gfs.Router, yaml_path='BaseRouter') + f = f.add_router(yaml_path='BaseRouter') + g = g.add_router(yaml_path='BaseRouter') print('f: %r g: %r' % (f, g)) f.build() @@ -55,63 +55,63 @@ def test_flow1(self): def 
test_flow1_ctx_empty(self): f = (Flow(check_version=False, route_table=True) - .add(gfs.Router, yaml_path='BaseRouter')) + .add_router(yaml_path='BaseRouter')) with f(backend='process'): pass def test_flow1_ctx(self): flow = (Flow(check_version=False, route_table=False) - .add(gfs.Router, yaml_path='BaseRouter')) + .add_router(yaml_path='BaseRouter')) with flow(backend='process', copy_flow=True) as f, open(self.test_file) as fp: f.index(txt_file=self.test_file, batch_size=4) f.train(txt_file=self.test_file, batch_size=4) with flow(backend='process', copy_flow=True) as f: # change the flow inside build shall fail - f = f.add(gfs.Router, yaml_path='BaseRouter') + f = f.add_router(yaml_path='BaseRouter') self.assertRaises(FlowBuildLevelMismatch, f.index, txt_file=self.test_file, batch_size=4) print(flow.build(backend=None).to_mermaid()) def test_flow2(self): f = (Flow(check_version=False, route_table=True) - .add(gfs.Router, yaml_path='BaseRouter') - .add(gfs.Router, yaml_path='BaseRouter') - .add(gfs.Router, yaml_path='BaseRouter') - .add(gfs.Router, yaml_path='BaseRouter') - .add(gfs.Router, yaml_path='BaseRouter') - .add(gfs.Router, yaml_path='BaseRouter') - .add(gfs.Router, yaml_path='BaseRouter') - .add(gfs.Router, yaml_path='BaseRouter') + .add_router(yaml_path='BaseRouter') + .add_router(yaml_path='BaseRouter') + .add_router(yaml_path='BaseRouter') + .add_router(yaml_path='BaseRouter') + .add_router(yaml_path='BaseRouter') + .add_router(yaml_path='BaseRouter') + .add_router(yaml_path='BaseRouter') + .add_router(yaml_path='BaseRouter') .build(backend=None)) print(f._service_edges) print(f.to_mermaid()) def test_flow3(self): f = (Flow(check_version=False, route_table=True) - .add(gfs.Router, name='r0', send_to=gfs.Frontend, yaml_path='BaseRouter') - .add(gfs.Router, name='r1', recv_from=gfs.Frontend, yaml_path='BaseRouter') + .add_router(name='r0', send_to=Flow.Frontend, yaml_path='BaseRouter') + .add_router(name='r1', recv_from=Flow.Frontend, 
yaml_path='BaseRouter') .build(backend=None)) print(f._service_edges) print(f.to_mermaid()) def test_flow4(self): f = (Flow(check_version=False, route_table=True) - .add(gfs.Router, name='r0', yaml_path='BaseRouter') - .add(gfs.Router, name='r1', recv_from=gfs.Frontend, yaml_path='BaseRouter') - .add(gfs.Router, name='reduce', recv_from=['r0', 'r1'], yaml_path='BaseRouter') + .add_router(name='r0', yaml_path='BaseRouter') + .add_router(name='r1', recv_from=Flow.Frontend, yaml_path='BaseRouter') + .add_router(name='reduce', recv_from=['r0', 'r1'], yaml_path='BaseRouter') .build(backend=None)) print(f._service_edges) print(f.to_mermaid()) def test_flow5(self): f = (Flow(check_version=False, route_table=True) - .add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor') - .add(gfs.Encoder, yaml_path='PyTorchTransformers') - .add(gfs.Indexer, name='vec_idx', yaml_path='NumpyIndexer') - .add(gfs.Indexer, name='doc_idx', yaml_path='DictIndexer', recv_from='prep') - .add(gfs.Router, name='sync_barrier', yaml_path='BaseReduceRouter', - num_part=2, recv_from=['vec_idx', 'doc_idx']) + .add_preprocessor(name='prep', yaml_path='SentSplitPreprocessor') + .add_encoder(yaml_path='PyTorchTransformers') + .add_indexer(name='vec_idx', yaml_path='NumpyIndexer') + .add_indexer(name='doc_idx', yaml_path='DictIndexer', recv_from='prep') + .add_router(name='sync_barrier', yaml_path='BaseReduceRouter', + num_part=2, recv_from=['vec_idx', 'doc_idx']) .build(backend=None)) print(f._service_edges) print(f.to_mermaid()) @@ -119,12 +119,12 @@ def test_flow5(self): def test_flow_replica_pot(self): f = (Flow(check_version=False, route_table=True) - .add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor', replicas=4) - .add(gfs.Encoder, yaml_path='PyTorchTransformers', replicas=3) - .add(gfs.Indexer, name='vec_idx', yaml_path='NumpyIndexer', replicas=2) - .add(gfs.Indexer, name='doc_idx', yaml_path='DictIndexer', recv_from='prep', replicas=2) - .add(gfs.Router, 
name='sync_barrier', yaml_path='BaseReduceRouter', - num_part=2, recv_from=['vec_idx', 'doc_idx']) + .add_preprocessor(name='prep', yaml_path='SentSplitPreprocessor', replicas=4) + .add_encoder(yaml_path='PyTorchTransformers', replicas=3) + .add_indexer(name='vec_idx', yaml_path='NumpyIndexer', replicas=2) + .add_indexer(name='doc_idx', yaml_path='DictIndexer', recv_from='prep', replicas=2) + .add_router(name='sync_barrier', yaml_path='BaseReduceRouter', + num_part=2, recv_from=['vec_idx', 'doc_idx']) .build(backend=None)) print(f.to_mermaid()) print(f.to_url(left_right=False)) @@ -135,13 +135,13 @@ def _test_index_flow(self, backend): self.assertFalse(os.path.exists(k)) flow = (Flow(check_version=False, route_table=False) - .add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor') - .add(gfs.Encoder, yaml_path=os.path.join(self.dirname, 'yaml/flow-transformer.yml'), replicas=3) - .add(gfs.Indexer, name='vec_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-vecindex.yml')) - .add(gfs.Indexer, name='doc_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-dictindex.yml'), - recv_from='prep') - .add(gfs.Router, name='sync_barrier', yaml_path='BaseReduceRouter', - num_part=2, recv_from=['vec_idx', 'doc_idx'])) + .add_preprocessor(name='prep', yaml_path='SentSplitPreprocessor') + .add_encoder(yaml_path=os.path.join(self.dirname, 'yaml/flow-transformer.yml'), replicas=3) + .add_indexer(name='vec_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-vecindex.yml')) + .add_indexer(name='doc_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-dictindex.yml'), + recv_from='prep') + .add_router(name='sync_barrier', yaml_path='BaseReduceRouter', + num_part=2, recv_from=['vec_idx', 'doc_idx'])) with flow.build(backend=backend) as f: f.index(txt_file=self.test_file, batch_size=20) @@ -151,11 +151,11 @@ def _test_index_flow(self, backend): def _test_query_flow(self, backend): flow = (Flow(check_version=False, route_table=False) - .add(gfs.Preprocessor, 
name='prep', yaml_path='SentSplitPreprocessor') - .add(gfs.Encoder, yaml_path=os.path.join(self.dirname, 'yaml/flow-transformer.yml'), replicas=3) - .add(gfs.Indexer, name='vec_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-vecindex.yml')) - .add(gfs.Router, name='scorer', yaml_path=os.path.join(self.dirname, 'yaml/flow-score.yml')) - .add(gfs.Indexer, name='doc_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-dictindex.yml'))) + .add_preprocessor(name='prep', yaml_path='SentSplitPreprocessor') + .add_encoder(yaml_path=os.path.join(self.dirname, 'yaml/flow-transformer.yml'), replicas=3) + .add_indexer(name='vec_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-vecindex.yml')) + .add_router(name='scorer', yaml_path=os.path.join(self.dirname, 'yaml/flow-score.yml')) + .add_indexer(name='doc_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-dictindex.yml'))) with flow.build(backend=backend) as f, open(self.test_file, encoding='utf8') as fp: f.query(bytes_gen=[v.encode() for v in fp][:3]) @@ -171,22 +171,22 @@ def test_indexe_query_flow_proc(self): def test_query_flow_plot(self): flow = (Flow(check_version=False, route_table=False) - .add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor', replicas=2) - .add(gfs.Encoder, yaml_path=os.path.join(self.dirname, 'yaml/flow-transformer.yml'), replicas=3) - .add(gfs.Indexer, name='vec_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-vecindex.yml'), - replicas=4) - .add(gfs.Router, name='scorer', yaml_path=os.path.join(self.dirname, 'yaml/flow-score.yml')) - .add(gfs.Indexer, name='doc_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-dictindex.yml'))) + .add_preprocessor(name='prep', yaml_path='SentSplitPreprocessor', replicas=2) + .add_encoder(yaml_path=os.path.join(self.dirname, 'yaml/flow-transformer.yml'), replicas=3) + .add_indexer(name='vec_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-vecindex.yml'), + replicas=4) + .add_router(name='scorer', 
yaml_path=os.path.join(self.dirname, 'yaml/flow-score.yml')) + .add_indexer(name='doc_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-dictindex.yml'))) print(flow.build(backend=None).to_url()) def test_flow_add_set(self): f = (Flow(check_version=False, route_table=True) - .add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor', replicas=4) - .add(gfs.Encoder, yaml_path='PyTorchTransformers', replicas=3) - .add(gfs.Indexer, name='vec_idx', yaml_path='NumpyIndexer', replicas=2) - .add(gfs.Indexer, name='doc_idx', yaml_path='DictIndexer', recv_from='prep', replicas=2) - .add(gfs.Router, name='sync_barrier', yaml_path='BaseReduceRouter', - num_part=2, recv_from=['vec_idx', 'doc_idx']) + .add_preprocessor(name='prep', yaml_path='SentSplitPreprocessor', replicas=4) + .add_encoder(yaml_path='PyTorchTransformers', replicas=3) + .add_indexer(name='vec_idx', yaml_path='NumpyIndexer', replicas=2) + .add_indexer(name='doc_idx', yaml_path='DictIndexer', recv_from='prep', replicas=2) + .add_router(name='sync_barrier', yaml_path='BaseReduceRouter', + num_part=2, recv_from=['vec_idx', 'doc_idx']) .build(backend=None)) print(f.to_url()) @@ -229,5 +229,5 @@ def test_flow_add_set(self): print(f1.to_swarm_yaml()) def test_common_flow(self): - print(BaseIndexFlow.build(backend=None).to_url()) - print(BaseQueryFlow.build(backend=None).to_url()) + print(BaseIndexFlow().build(backend=None).to_url()) + print(BaseQueryFlow().build(backend=None).to_url())