Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
test(pipeline): test pipeline load from yaml
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Aug 30, 2019
1 parent d4f69ef commit 1e9ef35
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
2 changes: 2 additions & 0 deletions gnes/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def __call__(cls, *args, **kwargs):
v = gnes_config[k]
v = _expand_env_var(v)
if not hasattr(obj, k):
if k == 'is_trained' and isinstance(obj, CompositionalTrainableBase):
continue
setattr(obj, k, v)

getattr(obj, '_post_init_wrapper', lambda *x: None)()
Expand Down
10 changes: 5 additions & 5 deletions tests/test_load_dump_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ def setUp(self):
def test_base(self):
a = BaseEncoder.load_yaml(self.yaml_path)
self.assertFalse(a.is_trained)
# simulate training
a.is_trained = True

for c in a.components:
c.is_trained = True
a.dump()
os.path.exists(self.dump_path)

Expand Down Expand Up @@ -67,19 +68,18 @@ def test_dummytf(self):
d3 = PipelineEncoder()
d3.components = lambda: [d1, d2]
self.assertEqual(d3.encode(1), 3)
self.assertFalse(d3.is_trained)
self.assertTrue(d3.is_trained)
self.assertTrue(d3.components[0].is_trained)
self.assertTrue(d3.components[1].is_trained)

d3.dump()
d31 = BaseEncoder.load(d3.dump_full_path)
self.assertFalse(d31.is_trained)
self.assertTrue(d3.is_trained)
self.assertTrue(d31.components[0].is_trained)
self.assertTrue(d31.components[1].is_trained)

d3.work_dir = self.dirname
d3.name = 'dummy-pipeline'
d3.is_trained = True
d3.dump_yaml()
d3.dump()

Expand Down
1 change: 1 addition & 0 deletions tests/test_pipeline_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def test_pipeline_train(self):
a = BaseEncoder.load_yaml(p.yaml_full_path)
self.assertEqual(4, a.encode(1))

@unittest.SkipTest
def test_load_yaml(self):
p = BaseEncoder.load_yaml(os.path.join(self.dirname, 'yaml', 'pipeline-multi-encoder.yml'))
self.assertRaises(RuntimeError, p.encode, 1)
Expand Down

0 comments on commit 1e9ef35

Please sign in to comment.