diff --git a/gnes/base/__init__.py b/gnes/base/__init__.py index 61162ea4..e731dea2 100644 --- a/gnes/base/__init__.py +++ b/gnes/base/__init__.py @@ -33,7 +33,6 @@ T = TypeVar('T', bound='TrainableBase') - def register_all_class(cls2file_map: Dict, module_name: str): import importlib for k, v in cls2file_map.items(): @@ -78,12 +77,17 @@ def __call__(cls, *args, **kwargs): # do _preload_package getattr(cls, '_pre_init', lambda *x: None)() + if 'gnes_config' in kwargs: + gnes_config = kwargs.pop('gnes_config') + else: + gnes_config = {} + obj = type.__call__(cls, *args, **kwargs) # set attribute for k, v in TrainableType.default_gnes_config.items(): - if k in kwargs: - v = kwargs[k] + if k in gnes_config: + v = gnes_config[k] if not hasattr(obj, k): setattr(obj, k, v) @@ -164,7 +168,7 @@ def __init__(self, *args, **kwargs): self._post_init_vars = set() def _post_init_wrapper(self): - if not getattr(self, 'name', None): + if not getattr(self, 'name', None) and os.environ.get('GNES_WARN_UNNAMED_COMPONENT', '1') == '1': _id = str(uuid.uuid4()).split('-')[0] _name = '%s-%s' % (self.__class__.__name__, _id) self.logger.warning( @@ -290,6 +294,9 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False): if stop_on_import_error: raise RuntimeError('Cannot import module, pip install may required') from ex + if node.tag in {'!PipelineEncoder', '!CompositionalEncoder'}: + os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '0' + data = ruamel.yaml.constructor.SafeConstructor.construct_mapping( constructor, node, deep=True) @@ -309,14 +316,17 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False): # maybe there are some hanging kwargs in "parameter" tmp_a = (cls._convert_env_var(v) for v in a) tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in {**k, **p}.items()} - obj = cls(*tmp_a, **tmp_p, **data.get('gnes_config', {})) + obj = cls(*tmp_a, **tmp_p, gnes_config=data.get('gnes_config', {})) else: tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in data.get('parameter', {}).items()} - obj = cls(**tmp_p, **data.get('gnes_config', {})) + obj = cls(**tmp_p, gnes_config=data.get('gnes_config', {})) obj.logger.info('initialize %s from a yaml config' % cls.__name__) cls.init_from_yaml = False + if node.tag in {'!PipelineEncoder', '!CompositionalEncoder'}: + os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '1' + return obj, data, load_from_dump @staticmethod @@ -344,6 +354,3 @@ def _dump_instance_to_yaml(data): if p: r['gnes_config'] = p return r - - - diff --git a/gnes/composer/base.py b/gnes/composer/base.py index 548c1b44..89797008 100644 --- a/gnes/composer/base.py +++ b/gnes/composer/base.py @@ -373,8 +373,9 @@ def rule3(): 'socket_out': str(SocketType.PUSH_BIND), 'port_in': last_layer.components[0]['port_out'], 'port_out': self._get_random_port()}) - layer.components[0]['socket_in'] = str(SocketType.PULL_CONNECT) - layer.components[0]['port_in'] = r['port_out'] + for c in layer.components: + c['socket_in'] = str(SocketType.PULL_CONNECT) + c['port_in'] = r['port_out'] router_layer.append(r) router_layers.append(router_layer) @@ -458,7 +459,6 @@ def rule10(): c['port_in'] = r0['port_out'] def rule8(): - last_layer.components[0]['socket_out'] = str(SocketType.PUSH_CONNECT) router_layer = YamlComposer.Layer(layer_id=self._num_layer) self._num_layer += 1 r = CommentedMap({'name': 'Router', @@ -503,11 +503,33 @@ def rule9(): last_layer.components[0]['socket_out'] = str(SocketType.PUSH_CONNECT) layer.components[0]['socket_in'] = str(SocketType.PULL_BIND) + def rule11(): + # a shortcut fn: (N)-2-(N) with push pull connection + router_layer = YamlComposer.Layer(layer_id=self._num_layer) + self._num_layer += 1 + r = CommentedMap({'name': 'Router', + 'yaml_path': None, + 'socket_in': str(SocketType.PULL_BIND), + 'socket_out': str(SocketType.PUSH_BIND), + 'port_in': self._get_random_port(), + 'port_out': self._get_random_port()}) + + for c in last_layer.components: + c['socket_out'] = str(SocketType.PUSH_CONNECT) + c['port_out'] = r['port_in'] + for c in layer.components: + c['socket_in'] = str(SocketType.PULL_CONNECT) + c['port_in'] = r['port_out'] + router_layer.append(r) + router_layers.append(router_layer) + router_layers = [] # type: List['self.Layer'] # bind the last out to current in - last_layer.components[0]['port_out'] = self._get_random_port() - layer.components[0]['port_in'] = last_layer.components[0]['port_out'] + if last_layer.is_single_component: + last_layer.components[0]['port_out'] = self._get_random_port() + for c in layer.components: + c['port_in'] = last_layer.components[0]['port_out'] # 1-to-? if layer.is_single_component: # 1-to-(1) @@ -530,8 +552,13 @@ def rule9(): rule6() elif last_layer.is_homo_multi_component: # (N)-to-? + last_layer.components[0]['port_out'] = self._get_random_port() + last_income = self.Layer.get_value(last_layer.components[0], 'income') + for c in layer.components: + c['port_in'] = last_layer.components[0]['port_out'] + if layer.is_single_component: if last_income == 'pull': # (N)-to-1 @@ -560,7 +587,7 @@ def rule9(): else: raise NotImplementedError('replica type: %s is not recognized!' % last_income) elif last_layer.is_heto_single_component: - rule3() + rule8() else: rule8() return [last_layer, *router_layers] diff --git a/tests/test_load_dump_pipeline.py b/tests/test_load_dump_pipeline.py index 4a02f74f..0a4d8143 100644 --- a/tests/test_load_dump_pipeline.py +++ b/tests/test_load_dump_pipeline.py @@ -37,6 +37,20 @@ def test_base(self): b = BaseEncoder.load_yaml(self.yaml_path) self.assertTrue(b.is_trained) + def test_name_warning(self): + d1 = DummyTFEncoder() + d2 = DummyTFEncoder() + d1.name = '' + d2.name = '' + d3 = PipelineEncoder() + d3.component = lambda: [d1, d2] + d3.name = 'dummy-pipeline' + d3.work_dir = './' + d3.dump() + d3.dump_yaml() + print('there should not be any warning after this line') + BaseEncoder.load_yaml(d3.yaml_full_path) + def test_dummytf(self): d1 = DummyTFEncoder() self.assertEqual(d1.encode(1), 2)