diff --git a/gnes/base/__init__.py b/gnes/base/__init__.py index e283fcda..e731dea2 100644 --- a/gnes/base/__init__.py +++ b/gnes/base/__init__.py @@ -77,12 +77,17 @@ def __call__(cls, *args, **kwargs): # do _preload_package getattr(cls, '_pre_init', lambda *x: None)() + if 'gnes_config' in kwargs: + gnes_config = kwargs.pop('gnes_config') + else: + gnes_config = {} + obj = type.__call__(cls, *args, **kwargs) # set attribute for k, v in TrainableType.default_gnes_config.items(): - if k in kwargs: - v = kwargs[k] + if k in gnes_config: + v = gnes_config[k] if not hasattr(obj, k): setattr(obj, k, v) @@ -295,9 +300,6 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False): data = ruamel.yaml.constructor.SafeConstructor.construct_mapping( constructor, node, deep=True) - if node.tag in {'!PipelineEncoder', '!CompositionalEncoder'}: - os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '1' - dump_path = cls._get_dump_path_from_config(data.get('gnes_config', {})) load_from_dump = False if dump_path: @@ -314,14 +316,17 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False): # maybe there are some hanging kwargs in "parameter" tmp_a = (cls._convert_env_var(v) for v in a) tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in {**k, **p}.items()} - obj = cls(*tmp_a, **tmp_p, **data.get('gnes_config', {})) + obj = cls(*tmp_a, **tmp_p, gnes_config=data.get('gnes_config', {})) else: tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in data.get('parameter', {}).items()} - obj = cls(**tmp_p, **data.get('gnes_config', {})) + obj = cls(**tmp_p, gnes_config=data.get('gnes_config', {})) obj.logger.info('initialize %s from a yaml config' % cls.__name__) cls.init_from_yaml = False + if node.tag in {'!PipelineEncoder', '!CompositionalEncoder'}: + os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '1' + return obj, data, load_from_dump @staticmethod diff --git a/tests/test_load_dump_pipeline.py b/tests/test_load_dump_pipeline.py index 0917f4ef..0a4d8143 100644 --- a/tests/test_load_dump_pipeline.py +++ b/tests/test_load_dump_pipeline.py @@ -44,10 +44,12 @@ def test_name_warning(self): d2.name = '' d3 = PipelineEncoder() d3.component = lambda: [d1, d2] - d3.name = 'aa' + d3.name = 'dummy-pipeline' + d3.work_dir = './' + d3.dump() d3.dump_yaml() print('there should not be any warning after this line') - d31 = BaseEncoder.load_yaml(d3.yaml_full_path) + BaseEncoder.load_yaml(d3.yaml_full_path) def test_dummytf(self): d1 = DummyTFEncoder()