Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix(base): fix gnes_config mixed in kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Jul 25, 2019
1 parent 68c15fa commit c52c2cc
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
19 changes: 12 additions & 7 deletions gnes/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,17 @@ def __call__(cls, *args, **kwargs):
# do _preload_package
getattr(cls, '_pre_init', lambda *x: None)()

if 'gnes_config' in kwargs:
gnes_config = kwargs.pop('gnes_config')
else:
gnes_config = {}

obj = type.__call__(cls, *args, **kwargs)

# set attribute
for k, v in TrainableType.default_gnes_config.items():
if k in kwargs:
v = kwargs[k]
if k in gnes_config:
v = gnes_config[k]
if not hasattr(obj, k):
setattr(obj, k, v)

Expand Down Expand Up @@ -295,9 +300,6 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False):
data = ruamel.yaml.constructor.SafeConstructor.construct_mapping(
constructor, node, deep=True)

if node.tag in {'!PipelineEncoder', '!CompositionalEncoder'}:
os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '1'

dump_path = cls._get_dump_path_from_config(data.get('gnes_config', {}))
load_from_dump = False
if dump_path:
Expand All @@ -314,14 +316,17 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False):
# maybe there are some hanging kwargs in "parameter"
tmp_a = (cls._convert_env_var(v) for v in a)
tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in {**k, **p}.items()}
obj = cls(*tmp_a, **tmp_p, **data.get('gnes_config', {}))
obj = cls(*tmp_a, **tmp_p, gnes_config=data.get('gnes_config', {}))
else:
tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in data.get('parameter', {}).items()}
obj = cls(**tmp_p, **data.get('gnes_config', {}))
obj = cls(**tmp_p, gnes_config=data.get('gnes_config', {}))

obj.logger.info('initialize %s from a yaml config' % cls.__name__)
cls.init_from_yaml = False

if node.tag in {'!PipelineEncoder', '!CompositionalEncoder'}:
os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '1'

return obj, data, load_from_dump

@staticmethod
Expand Down
6 changes: 4 additions & 2 deletions tests/test_load_dump_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@ def test_name_warning(self):
d2.name = ''
d3 = PipelineEncoder()
d3.component = lambda: [d1, d2]
d3.name = 'aa'
d3.name = 'dummy-pipeline'
d3.work_dir = './'
d3.dump()
d3.dump_yaml()
print('there should not be any warning after this line')
d31 = BaseEncoder.load_yaml(d3.yaml_full_path)
BaseEncoder.load_yaml(d3.yaml_full_path)

def test_dummytf(self):
d1 = DummyTFEncoder()
Expand Down

0 comments on commit c52c2cc

Please sign in to comment.