diff --git a/gnes/base/__init__.py b/gnes/base/__init__.py index 61162ea4..e283fcda 100644 --- a/gnes/base/__init__.py +++ b/gnes/base/__init__.py @@ -33,7 +33,6 @@ T = TypeVar('T', bound='TrainableBase') - def register_all_class(cls2file_map: Dict, module_name: str): import importlib for k, v in cls2file_map.items(): @@ -164,7 +163,7 @@ def __init__(self, *args, **kwargs): self._post_init_vars = set() def _post_init_wrapper(self): - if not getattr(self, 'name', None): + if not getattr(self, 'name', None) and os.environ.get('GNES_WARN_UNNAMED_COMPONENT', '1') == '1': _id = str(uuid.uuid4()).split('-')[0] _name = '%s-%s' % (self.__class__.__name__, _id) self.logger.warning( @@ -290,9 +289,15 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False): if stop_on_import_error: raise RuntimeError('Cannot import module, pip install may required') from ex + if node.tag in {'!PipelineEncoder', '!CompositionalEncoder'}: + os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '0' + data = ruamel.yaml.constructor.SafeConstructor.construct_mapping( constructor, node, deep=True) + if node.tag in {'!PipelineEncoder', '!CompositionalEncoder'}: + os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '1' + dump_path = cls._get_dump_path_from_config(data.get('gnes_config', {})) load_from_dump = False if dump_path: @@ -344,6 +349,3 @@ def _dump_instance_to_yaml(data): if p: r['gnes_config'] = p return r - - - diff --git a/tests/test_load_dump_pipeline.py b/tests/test_load_dump_pipeline.py index 4a02f74f..0917f4ef 100644 --- a/tests/test_load_dump_pipeline.py +++ b/tests/test_load_dump_pipeline.py @@ -37,6 +37,18 @@ def test_base(self): b = BaseEncoder.load_yaml(self.yaml_path) self.assertTrue(b.is_trained) + def test_name_warning(self): + d1 = DummyTFEncoder() + d2 = DummyTFEncoder() + d1.name = '' + d2.name = '' + d3 = PipelineEncoder() + d3.component = lambda: [d1, d2] + d3.name = 'aa' + d3.dump_yaml() + print('there should not be any warning after this line') + d31 = BaseEncoder.load_yaml(d3.yaml_full_path) + def test_dummytf(self): d1 = DummyTFEncoder() self.assertEqual(d1.encode(1), 2)