diff --git a/gnes/base/__init__.py b/gnes/base/__init__.py index 872201b0..5945fd23 100644 --- a/gnes/base/__init__.py +++ b/gnes/base/__init__.py @@ -22,13 +22,13 @@ import tempfile import uuid from functools import wraps -from typing import Dict, Any, Union, TextIO, TypeVar, Type +from typing import Dict, Any, Union, TextIO, TypeVar, Type, List, Callable import ruamel.yaml.constructor from ..helper import set_logger, profiling, yaml, parse_arg, load_contrib_module -__all__ = ['TrainableBase'] +__all__ = ['TrainableBase', 'CompositionalTrainableBase'] T = TypeVar('T', bound='TrainableBase') @@ -295,7 +295,7 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False): if stop_on_import_error: raise RuntimeError('Cannot import module, pip install may required') from ex - if node.tag in {'!PipelineEncoder', '!CompositionalEncoder'}: + if node.tag in {'!PipelineEncoder', '!CompositionalTrainableBase'}: os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '0' data = ruamel.yaml.constructor.SafeConstructor.construct_mapping( @@ -325,7 +325,7 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False): obj.logger.info('initialize %s from a yaml config' % cls.__name__) cls.init_from_yaml = False - if node.tag in {'!PipelineEncoder', '!CompositionalEncoder'}: + if node.tag in {'!PipelineEncoder', '!CompositionalTrainableBase'}: os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '1' return obj, data, load_from_dump @@ -355,3 +355,69 @@ def _dump_instance_to_yaml(data): if p: r['gnes_config'] = p return r + + def _copy_from(self, x: 'TrainableBase') -> None: + pass + + +class CompositionalTrainableBase(TrainableBase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._component = None # type: List[T] + + @property + def component(self) -> Union[List[T], Dict[str, T]]: + return self._component + + @property + def is_pipeline(self): + return isinstance(self.component, list) + + @component.setter + def component(self, comps: Callable[[], Union[list, dict]]): + if not callable(comps): + raise TypeError('component must be a callable function that returns ' + 'a List[BaseEncoder]') + if not getattr(self, 'init_from_yaml', False): + self._component = comps() + else: + self.logger.info('component is omitted from construction, ' + 'as it is initialized from yaml config') + + def close(self): + super().close() + # pipeline + if isinstance(self.component, list): + for be in self.component: + be.close() + # no typology + elif isinstance(self.component, dict): + for be in self.component.values(): + be.close() + elif self.component is None: + pass + else: + raise TypeError('component must be dict or list, received %s' % type(self.component)) + + def _copy_from(self, x: T): + if isinstance(self.component, list): + for be1, be2 in zip(self.component, x.component): + be1._copy_from(be2) + elif isinstance(self.component, dict): + for k, v in self.component.items(): + v._copy_from(x.component[k]) + else: + raise TypeError('component must be dict or list, received %s' % type(self.component)) + + @classmethod + def to_yaml(cls, representer, data): + tmp = super()._dump_instance_to_yaml(data) + tmp['component'] = data.component + return representer.represent_mapping('!' + cls.__name__, tmp) + + @classmethod + def from_yaml(cls, constructor, node): + obj, data, from_dump = super()._get_instance_from_yaml(constructor, node) + if not from_dump and 'component' in data: + obj.component = lambda: data['component'] + return obj diff --git a/gnes/cli/parser.py b/gnes/cli/parser.py index 7dab3078..154b70e8 100644 --- a/gnes/cli/parser.py +++ b/gnes/cli/parser.py @@ -54,7 +54,6 @@ def set_composer_parser(parser=None): 'gnes', '/'.join(('resources', 'config', 'compose', 'default.yml'))), help='yaml config of the service') parser.add_argument('--html_path', type=argparse.FileType('w', encoding='utf8'), - default='./gnes-board.html', help='output path of the HTML file, will contain all possible generations') parser.add_argument('--shell_path', type=argparse.FileType('w', encoding='utf8'), help='output path of the shell-based starting script') @@ -214,10 +213,11 @@ def set_grpc_frontend_parser(parser=None): from ..service.base import SocketType if not parser: parser = set_base_parser() - _set_client_parser(parser) + set_service_parser(parser) _set_grpc_parser(parser) parser.set_defaults(socket_in=SocketType.PULL_BIND, - socket_out=SocketType.PUSH_BIND) + socket_out=SocketType.PUSH_BIND, + read_only=True) parser.add_argument('--max_concurrency', type=int, default=10, help='maximum concurrent client allowed') parser.add_argument('--max_send_size', type=int, default=100, diff --git a/gnes/composer/base.py b/gnes/composer/base.py index 41e77f23..e6fe7137 100644 --- a/gnes/composer/base.py +++ b/gnes/composer/base.py @@ -156,10 +156,24 @@ def build_layers(self) -> List['YamlComposer.Layer']: last_layer = self._layers[idx - 1] for l in self._add_router(last_layer, layer): all_layers.append(copy.deepcopy(l)) - # # add frontend - # for l in self._add_router(all_layers[-1], all_layers[0]): - # all_layers.append(copy.deepcopy(l)) all_layers[0] = copy.deepcopy(self._layers[0]) + + # gRPCfrontend should always on the bind role + assert all_layers[0].is_single_component + assert all_layers[0].components[0]['name'] == 'gRPCFrontend' + + if all_layers[0].components[0]['socket_in'] == str(SocketType.SUB_CONNECT): + # change to sub bind + all_layers[0].components[0]['socket_in'] = str(SocketType.SUB_BIND) + for c in all_layers[-1].components: + c['socket_out'] = str(SocketType.PUB_CONNECT) + + if all_layers[0].components[0]['socket_in'] == str(SocketType.PULL_CONNECT): + # change to sub bind + all_layers[0].components[0]['socket_in'] = str(SocketType.PULL_BIND) + for c in all_layers[-1].components: + c['socket_out'] = str(SocketType.PUSH_CONNECT) + return all_layers @staticmethod @@ -292,11 +306,11 @@ def build_mermaid(all_layers: List['YamlComposer.Layer'], mermaid_leftright: boo # if len(last_layer.components) > 1: # self.mermaid_graph.append('\tend') - style = ['classDef gRPCFrontendCLS fill:#FFAA04,stroke:#277CE8,stroke-width:1px;', - 'classDef EncoderCLS fill:#27E1E8,stroke:#277CE8,stroke-width:1px;', - 'classDef IndexerCLS fill:#27E1E8,stroke:#277CE8,stroke-width:1px;', - 'classDef RouterCLS fill:#2BFFCB,stroke:#277CE8,stroke-width:1px;', - 'classDef PreprocessorCLS fill:#27E1E8,stroke:#277CE8,stroke-width:1px;'] + style = ['classDef gRPCFrontendCLS fill:#FFE0E0,stroke:#FFE0E0,stroke-width:1px;', + 'classDef EncoderCLS fill:#FFDAAF,stroke:#FFDAAF,stroke-width:1px;', + 'classDef IndexerCLS fill:#FFFBC1,stroke:#FFFBC1,stroke-width:1px;', + 'classDef RouterCLS fill:#C9E8D2,stroke:#C9E8D2,stroke-width:1px;', + 'classDef PreprocessorCLS fill:#CEEEEF,stroke:#CEEEEF,stroke-width:1px;'] class_def = ['class %s %s;' % (','.join(v), k) for k, v in cls_dict.items()] mermaid_str = '\n'.join( ['graph %s' % ('LR' if mermaid_leftright else 'TD')] + mermaid_graph + style + class_def) diff --git a/gnes/composer/flask.py b/gnes/composer/flask.py index 36003967..fe132c11 100644 --- a/gnes/composer/flask.py +++ b/gnes/composer/flask.py @@ -12,8 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import tempfile +import io from .base import YamlComposer from ..cli.parser import set_composer_parser @@ -37,23 +36,25 @@ def _create_flask_app(self): # support up to 10 concurrent HTTP requests app = Flask(__name__) + args = set_composer_parser().parse_args([]) + default_html = YamlComposer(args).build_all()['html'] @app.route('/', methods=['GET']) def _get_homepage(): - return YamlComposer(set_composer_parser().parse_args([])).build_all()['html'] + return default_html @app.route('/generate', methods=['POST']) def _regenerate(): data = request.form if request.form else request.json if not data or 'yaml-config' not in data: return '