From 499682ce942c5fac778d8c09f40f95606439114d Mon Sep 17 00:00:00 2001 From: hanhxiao Date: Mon, 22 Jul 2019 15:40:49 +0800 Subject: [PATCH] tests(base): add unit test for load a dumped pipeline from yaml --- tests/test_load_dump_pipeline.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/tests/test_load_dump_pipeline.py b/tests/test_load_dump_pipeline.py index aa7b36b4..b080bfe8 100644 --- a/tests/test_load_dump_pipeline.py +++ b/tests/test_load_dump_pipeline.py @@ -1,7 +1,17 @@ import os import unittest -from gnes.encoder.base import BaseEncoder +from gnes.encoder.base import BaseEncoder, PipelineEncoder + + +class DummyTFEncoder(BaseEncoder): + def post_init(self): + import tensorflow as tf + self.a = tf.get_variable(name='a', shape=[]) + self.sess = tf.Session() + + def encode(self, a, *args): + return self.sess.run(self.a + 1, feed_dict={self.a: a}) class TestLoadDumpPipeline(unittest.TestCase): @@ -22,5 +32,17 @@ def test_base(self): b = BaseEncoder.load_yaml(self.yaml_path) self.assertTrue(b.is_trained) + def test_dummytf(self): + d1 = DummyTFEncoder() + self.assertEqual(d1.encode(1), 2) + + d2 = DummyTFEncoder() + self.assertEqual(d2.encode(2), 3) + + d3 = PipelineEncoder() + d3.component = lambda: [d1, d2] + self.assertEqual(d2.encode(1), 3) + def tearDown(self): - os.remove(self.dump_path) + if os.path.exists(self.dump_path): + os.remove(self.dump_path)