From ff7926d886f3e7a0e93b18fee9b3810b048d1443 Mon Sep 17 00:00:00 2001 From: hanhxiao Date: Mon, 26 Aug 2019 16:55:09 +0800 Subject: [PATCH] fix(encoder): fix eager execution --- gnes/encoder/numeric/pooling.py | 7 ++++-- tests/test_elmo_encoder.py | 41 --------------------------------- 2 files changed, 5 insertions(+), 43 deletions(-) delete mode 100644 tests/test_elmo_encoder.py diff --git a/gnes/encoder/numeric/pooling.py b/gnes/encoder/numeric/pooling.py index c92c6828..e1e263f3 100644 --- a/gnes/encoder/numeric/pooling.py +++ b/gnes/encoder/numeric/pooling.py @@ -28,7 +28,10 @@ def post_init(self): self.torch = torch elif self.backend == 'tensorflow': import tensorflow as tf - tf.enable_eager_execution() + try: + tf.enable_eager_execution() + except ValueError: + pass self.tf = tf def mul_mask(self, x, m): @@ -53,7 +56,7 @@ def masked_reduce_mean(self, x, m, jitter: float = 1e-10): self.torch.sum(m.unsqueeze(2), dim=1) + jitter) elif self.backend == 'tensorflow': return self.tf.reduce_sum(self.mul_mask(x, m), axis=1) / ( - self.tf.reduce_sum(m, axis=1, keepdims=True) + jitter) + self.tf.reduce_sum(m, axis=1, keepdims=True) + jitter) elif self.backend == 'numpy': return np.sum(self.mul_mask(x, m), axis=1) / (np.sum(m, axis=1, keepdims=True) + jitter) diff --git a/tests/test_elmo_encoder.py b/tests/test_elmo_encoder.py deleted file mode 100644 index 1f33c4e2..00000000 --- a/tests/test_elmo_encoder.py +++ /dev/null @@ -1,41 +0,0 @@ -import os -import unittest - -from gnes.encoder.text.elmo import ElmoEncoder - - -@unittest.SkipTest -class TestElmoEncoder(unittest.TestCase): - - def setUp(self): - dirname = os.path.dirname(__file__) - self.dump_path = os.path.join(dirname, 'elmo_encoder.bin') - - self.test_str = [] - with open(os.path.join(dirname, 'tangshi.txt')) as f: - for line in f: - line = line.strip() - if line: - self.test_str.append(line) - - self.elmo_encoder = ElmoEncoder( - model_dir=os.environ.get('ELMO_CI_MODEL', '/zhs.model'), - pooling_strategy="REDUCE_MEAN") - - def test_encoding(self): - vec = self.elmo_encoder.encode(self.test_str) - self.assertEqual(vec.shape[0], len(self.test_str)) - self.assertEqual(vec.shape[1], 1024) - - def test_dump_load(self): - self.elmo_encoder.dump(self.dump_path) - - elmo_encoder2 = ElmoEncoder.load(self.dump_path) - - vec = elmo_encoder2.encode(self.test_str) - self.assertEqual(vec.shape[0], len(self.test_str)) - self.assertEqual(vec.shape[1], 1024) - - def tearDown(self): - if os.path.exists(self.dump_path): - os.remove(self.dump_path)