Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
test(pipeline): test pipeline load from yaml
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Aug 30, 2019
1 parent 8524b1f commit d4f69ef
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 2 deletions.
8 changes: 7 additions & 1 deletion gnes/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def register_class(cls):
setattr(cls, f_name, profiling(getattr(cls, f_name)))

if getattr(cls, 'train', None):
# print('registered train func of %s'%cls)
setattr(cls, 'train', TrainableType._as_train_func(getattr(cls, 'train')))

reg_cls_set.add(cls.__name__)
Expand All @@ -124,7 +125,8 @@ def arg_wrapper(self, *args, **kwargs):
self.logger.warning('"%s" has been trained already, '
'training it again will override the previous training' % self.__class__.__name__)
f = func(self, *args, **kwargs)
self.is_trained = True
if not isinstance(self, CompositionalTrainableBase):
self.is_trained = True
return f

return arg_wrapper
Expand Down Expand Up @@ -387,6 +389,10 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._components = None # type: List[T]

@property
def is_trained(self):
return self.components and all(c.is_trained for c in self.components)

@property
def components(self) -> Union[List[T], Dict[str, T]]:
return self._components
Expand Down
4 changes: 3 additions & 1 deletion gnes/encoder/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def train(self, data, *args, **kwargs):
if not self.components:
raise NotImplementedError
for idx, be in enumerate(self.components):
be.train(data, *args, **kwargs)
if not be.is_trained:
be.train(data, *args, **kwargs)

if idx + 1 < len(self.components):
data = be.encode(data, *args, **kwargs)
13 changes: 13 additions & 0 deletions tests/contrib/dummy2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from gnes.component import BaseEncoder
from gnes.helper import train_required


class DummyEncoder2(BaseEncoder):

def train(self, *args, **kwargs):
self.logger.info('you just trained me!')
pass

@train_required
def encode(self, x):
return x + 1
13 changes: 13 additions & 0 deletions tests/contrib/dummy3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from gnes.component import BaseEncoder
from gnes.helper import train_required


class DummyEncoder3(BaseEncoder):

def train(self, *args, **kwargs):
self.logger.info('you just trained me!')
pass

@train_required
def encode(self, x):
return x + 1
2 changes: 2 additions & 0 deletions tests/test_pipeline_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,5 @@ def test_load_yaml(self):
self.assertRaises(RuntimeError, p.encode, 1)
p.train(1)
self.assertEqual(5, p.encode(1))
p = BaseEncoder.load_yaml(os.path.join(self.dirname, 'yaml', 'pipeline-multi-encoder.yml'))
self.assertRaises(RuntimeError, p.encode, 1)
27 changes: 27 additions & 0 deletions tests/test_pipeline_train_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os
import unittest

from gnes.encoder.base import BaseEncoder
from gnes.helper import PathImporter


class TestPipeTrain(unittest.TestCase):
def setUp(self):
self.dirname = os.path.dirname(__file__)
PathImporter.add_modules(*('{0}/contrib/dummy2.py,{0}/contrib/dummy3.py'.format(self.dirname).split(',')))

def tearDown(self):
if os.path.exists('dummy-pipeline.bin'):
os.remove('dummy-pipeline.bin')
if os.path.exists('dummy-pipeline.yml'):
os.remove('dummy-pipeline.yml')

def test_load_yaml(self):
p = BaseEncoder.load_yaml(os.path.join(self.dirname, 'yaml', 'pipeline-multi-encoder2.yml'))
self.assertFalse(p.is_trained)
self.assertRaises(RuntimeError, p.encode, 1)
p.train(1)
self.assertTrue(p.is_trained)
self.assertEqual(5, p.encode(1))
p = BaseEncoder.load_yaml(os.path.join(self.dirname, 'yaml', 'pipeline-multi-encoder2.yml'))
self.assertRaises(RuntimeError, p.encode, 1)
10 changes: 10 additions & 0 deletions tests/yaml/pipeline-multi-encoder2.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
!PipelineEncoder
components:
- !DummyEncoder2
gnes_config:
is_trained: true
- !DummyEncoder2 {}
- !DummyEncoder3 {}
- !DummyEncoder3 {}
gnes_config:
name: dummy-pipeline

0 comments on commit d4f69ef

Please sign in to comment.