From d4f69ef3a2216148605cecd67b4df5694e2e12bd Mon Sep 17 00:00:00 2001 From: hanhxiao Date: Fri, 30 Aug 2019 16:44:22 +0800 Subject: [PATCH] test(pipeline): test pipeline load from yaml --- gnes/base/__init__.py | 8 +++++++- gnes/encoder/base.py | 4 +++- tests/contrib/dummy2.py | 13 +++++++++++++ tests/contrib/dummy3.py | 13 +++++++++++++ tests/test_pipeline_train.py | 2 ++ tests/test_pipeline_train_ext.py | 27 ++++++++++++++++++++++++++ tests/yaml/pipeline-multi-encoder2.yml | 10 ++++++++++ 7 files changed, 75 insertions(+), 2 deletions(-) create mode 100644 tests/contrib/dummy2.py create mode 100644 tests/contrib/dummy3.py create mode 100644 tests/test_pipeline_train_ext.py create mode 100644 tests/yaml/pipeline-multi-encoder2.yml diff --git a/gnes/base/__init__.py b/gnes/base/__init__.py index 5d30d0bf..2eb43ae7 100644 --- a/gnes/base/__init__.py +++ b/gnes/base/__init__.py @@ -109,6 +109,7 @@ def register_class(cls): setattr(cls, f_name, profiling(getattr(cls, f_name))) if getattr(cls, 'train', None): + # print('registered train func of %s'%cls) setattr(cls, 'train', TrainableType._as_train_func(getattr(cls, 'train'))) reg_cls_set.add(cls.__name__) @@ -124,7 +125,8 @@ def arg_wrapper(self, *args, **kwargs): self.logger.warning('"%s" has been trained already, ' 'training it again will override the previous training' % self.__class__.__name__) f = func(self, *args, **kwargs) - self.is_trained = True + if not isinstance(self, CompositionalTrainableBase): + self.is_trained = True return f return arg_wrapper @@ -387,6 +389,10 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._components = None # type: List[T] + @property + def is_trained(self): + return self.components and all(c.is_trained for c in self.components) + @property def components(self) -> Union[List[T], Dict[str, T]]: return self._components diff --git a/gnes/encoder/base.py b/gnes/encoder/base.py index dd3080d0..11b69391 100644 --- a/gnes/encoder/base.py +++ b/gnes/encoder/base.py @@ -80,6 +80,8 @@ def train(self, data, *args, **kwargs): if not self.components: raise NotImplementedError for idx, be in enumerate(self.components): - be.train(data, *args, **kwargs) + if not be.is_trained: + be.train(data, *args, **kwargs) + if idx + 1 < len(self.components): data = be.encode(data, *args, **kwargs) diff --git a/tests/contrib/dummy2.py b/tests/contrib/dummy2.py new file mode 100644 index 00000000..eca2655d --- /dev/null +++ b/tests/contrib/dummy2.py @@ -0,0 +1,13 @@ +from gnes.component import BaseEncoder +from gnes.helper import train_required + + +class DummyEncoder2(BaseEncoder): + + def train(self, *args, **kwargs): + self.logger.info('you just trained me!') + pass + + @train_required + def encode(self, x): + return x + 1 diff --git a/tests/contrib/dummy3.py b/tests/contrib/dummy3.py new file mode 100644 index 00000000..238a0743 --- /dev/null +++ b/tests/contrib/dummy3.py @@ -0,0 +1,13 @@ +from gnes.component import BaseEncoder +from gnes.helper import train_required + + +class DummyEncoder3(BaseEncoder): + + def train(self, *args, **kwargs): + self.logger.info('you just trained me!') + pass + + @train_required + def encode(self, x): + return x + 1 diff --git a/tests/test_pipeline_train.py b/tests/test_pipeline_train.py index 4470701a..cca7719b 100644 --- a/tests/test_pipeline_train.py +++ b/tests/test_pipeline_train.py @@ -49,3 +49,5 @@ def test_load_yaml(self): self.assertRaises(RuntimeError, p.encode, 1) p.train(1) self.assertEqual(5, p.encode(1)) + p = BaseEncoder.load_yaml(os.path.join(self.dirname, 'yaml', 'pipeline-multi-encoder.yml')) + self.assertRaises(RuntimeError, p.encode, 1) diff --git a/tests/test_pipeline_train_ext.py b/tests/test_pipeline_train_ext.py new file mode 100644 index 00000000..9be1ba8a --- /dev/null +++ b/tests/test_pipeline_train_ext.py @@ -0,0 +1,27 @@ +import os +import unittest + +from gnes.encoder.base import BaseEncoder +from gnes.helper import PathImporter + + +class TestPipeTrain(unittest.TestCase): + def setUp(self): + self.dirname = os.path.dirname(__file__) + PathImporter.add_modules(*('{0}/contrib/dummy2.py,{0}/contrib/dummy3.py'.format(self.dirname).split(','))) + + def tearDown(self): + if os.path.exists('dummy-pipeline.bin'): + os.remove('dummy-pipeline.bin') + if os.path.exists('dummy-pipeline.yml'): + os.remove('dummy-pipeline.yml') + + def test_load_yaml(self): + p = BaseEncoder.load_yaml(os.path.join(self.dirname, 'yaml', 'pipeline-multi-encoder2.yml')) + self.assertFalse(p.is_trained) + self.assertRaises(RuntimeError, p.encode, 1) + p.train(1) + self.assertTrue(p.is_trained) + self.assertEqual(5, p.encode(1)) + p = BaseEncoder.load_yaml(os.path.join(self.dirname, 'yaml', 'pipeline-multi-encoder2.yml')) + self.assertRaises(RuntimeError, p.encode, 1) diff --git a/tests/yaml/pipeline-multi-encoder2.yml b/tests/yaml/pipeline-multi-encoder2.yml new file mode 100644 index 00000000..3ee23f27 --- /dev/null +++ b/tests/yaml/pipeline-multi-encoder2.yml @@ -0,0 +1,10 @@ +!PipelineEncoder +components: + - !DummyEncoder2 + gnes_config: + is_trained: true + - !DummyEncoder2 {} + - !DummyEncoder3 {} + - !DummyEncoder3 {} +gnes_config: + name: dummy-pipeline \ No newline at end of file