Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
tests(base): add unit test for load a dumped pipeline from yaml
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Jul 22, 2019
1 parent af7b2f8 commit 499682c
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions tests/test_load_dump_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
import os
import unittest

from gnes.encoder.base import BaseEncoder
from gnes.encoder.base import BaseEncoder, PipelineEncoder


class DummyTFEncoder(BaseEncoder):
def post_init(self):
import tensorflow as tf
self.a = tf.get_variable(name='a', shape=[])
self.sess = tf.Session()

def encode(self, a, *args):
return self.sess.run(self.a + 1, feed_dict={self.a: a})


class TestLoadDumpPipeline(unittest.TestCase):
Expand All @@ -22,5 +32,17 @@ def test_base(self):
b = BaseEncoder.load_yaml(self.yaml_path)
self.assertTrue(b.is_trained)

def test_dummytf(self):
d1 = DummyTFEncoder()
self.assertEqual(d1.encode(1), 2)

d2 = DummyTFEncoder()
self.assertEqual(d2.encode(2), 3)

d3 = PipelineEncoder()
d3.component = lambda: [d1, d2]
self.assertEqual(d2.encode(1), 3)

def tearDown(self):
os.remove(self.dump_path)
if os.path.exists(self.dump_path):
os.remove(self.dump_path)

0 comments on commit 499682c

Please sign in to comment.