Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix(base): fix redundant warning in pipeline encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Jul 25, 2019
1 parent aadeeef commit 68c15fa
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
12 changes: 7 additions & 5 deletions gnes/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
T = TypeVar('T', bound='TrainableBase')



def register_all_class(cls2file_map: Dict, module_name: str):
import importlib
for k, v in cls2file_map.items():
Expand Down Expand Up @@ -164,7 +163,7 @@ def __init__(self, *args, **kwargs):
self._post_init_vars = set()

def _post_init_wrapper(self):
if not getattr(self, 'name', None):
if not getattr(self, 'name', None) and os.environ.get('GNES_WARN_UNNAMED_COMPONENT', '1') == '1':
_id = str(uuid.uuid4()).split('-')[0]
_name = '%s-%s' % (self.__class__.__name__, _id)
self.logger.warning(
Expand Down Expand Up @@ -290,9 +289,15 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False):
if stop_on_import_error:
raise RuntimeError('Cannot import module, pip install may required') from ex

if node.tag in {'!PipelineEncoder', '!CompositionalEncoder'}:
os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '0'

data = ruamel.yaml.constructor.SafeConstructor.construct_mapping(
constructor, node, deep=True)

if node.tag in {'!PipelineEncoder', '!CompositionalEncoder'}:
os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '1'

dump_path = cls._get_dump_path_from_config(data.get('gnes_config', {}))
load_from_dump = False
if dump_path:
Expand Down Expand Up @@ -344,6 +349,3 @@ def _dump_instance_to_yaml(data):
if p:
r['gnes_config'] = p
return r



12 changes: 12 additions & 0 deletions tests/test_load_dump_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@ def test_base(self):
b = BaseEncoder.load_yaml(self.yaml_path)
self.assertTrue(b.is_trained)

def test_name_warning(self):
d1 = DummyTFEncoder()
d2 = DummyTFEncoder()
d1.name = ''
d2.name = ''
d3 = PipelineEncoder()
d3.component = lambda: [d1, d2]
d3.name = 'aa'
d3.dump_yaml()
print('there should not be any warning after this line')
d31 = BaseEncoder.load_yaml(d3.yaml_full_path)

def test_dummytf(self):
d1 = DummyTFEncoder()
self.assertEqual(d1.encode(1), 2)
Expand Down

0 comments on commit 68c15fa

Please sign in to comment.