diff --git a/gnes/encoder/__init__.py b/gnes/encoder/__init__.py
index 5c7a89c4..46dd6267 100644
--- a/gnes/encoder/__init__.py
+++ b/gnes/encoder/__init__.py
@@ -42,7 +42,8 @@
     'CVAEEncoder': 'image.cvae',
     'IncepMixtureEncoder': 'video.incep_mixture',
     'VladEncoder': 'numeric.vlad',
-    'MfccEncoder': 'audio.mfcc'
+    'MfccEncoder': 'audio.mfcc',
+    'PoolingEncoder': 'numeric.pooling'
 }
 
 register_all_class(_cls2file_map, 'encoder')
diff --git a/gnes/encoder/numeric/pooling.py b/gnes/encoder/numeric/pooling.py
index c49297d5..c92c6828 100644
--- a/gnes/encoder/numeric/pooling.py
+++ b/gnes/encoder/numeric/pooling.py
@@ -3,6 +3,7 @@
 import numpy as np
 
 from ..base import BaseNumericEncoder
+from ...helper import as_numpy_array
 
 
 class PoolingEncoder(BaseNumericEncoder):
@@ -12,7 +13,7 @@ def __init__(self, pooling_strategy: str = 'REDUCE_MEAN',
         super().__init__(*args, **kwargs)
 
         valid_poolings = {'REDUCE_MEAN', 'REDUCE_MAX', 'REDUCE_MEAN_MAX'}
-        valid_backends = {'tensorflow', 'numpy', 'pytorch'}
+        valid_backends = {'tensorflow', 'numpy', 'pytorch', 'torch'}
 
         if pooling_strategy not in valid_poolings:
             raise ValueError('"pooling_strategy" must be one of %s' % valid_poolings)
@@ -21,46 +22,50 @@ def __init__(self, pooling_strategy: str = 'REDUCE_MEAN',
         self.pooling_strategy = pooling_strategy
         self.backend = backend
 
-    def mul_mask(self, x, m):
-        if self.backend == 'pytorch':
+    def post_init(self):
+        if self.backend in {'pytorch', 'torch'}:
             import torch
-            return torch.mul(x, m.unsqueeze(2))
+            self.torch = torch
         elif self.backend == 'tensorflow':
             import tensorflow as tf
-            return x * tf.expand_dims(m, axis=-1)
+            tf.enable_eager_execution()
+            self.tf = tf
+
+    def mul_mask(self, x, m):
+        if self.backend in {'pytorch', 'torch'}:
+            return self.torch.mul(x, m.unsqueeze(2))
+        elif self.backend == 'tensorflow':
+            return x * self.tf.expand_dims(m, axis=-1)
         elif self.backend == 'numpy':
-            return 0
+            return x * np.expand_dims(m, axis=-1)
 
     def minus_mask(self, x, m, offset: int = 1e30):
-        if self.backend == 'pytorch':
+        if self.backend in {'pytorch', 'torch'}:
             return x - (1.0 - m).unsqueeze(2) * offset
         elif self.backend == 'tensorflow':
-            import tensorflow as tf
-            return x - tf.expand_dims(1.0 - m, axis=-1) * offset
+            return x - self.tf.expand_dims(1.0 - m, axis=-1) * offset
         elif self.backend == 'numpy':
-            return 0
+            return x - np.expand_dims(1.0 - m, axis=-1) * offset
 
     def masked_reduce_mean(self, x, m, jitter: float = 1e-10):
-        if self.backend == 'pytorch':
-            import torch
-            return torch.div(torch.sum(self.mul_mask(x, m), dim=1),
-                             torch.sum(m.unsqueeze(2), dim=1) + jitter)
+        if self.backend in {'pytorch', 'torch'}:
+            return self.torch.div(self.torch.sum(self.mul_mask(x, m), dim=1),
+                                  self.torch.sum(m.unsqueeze(2), dim=1) + jitter)
         elif self.backend == 'tensorflow':
-            import tensorflow as tf
-            return tf.reduce_sum(self.mul_mask(x, m), axis=1) / (tf.reduce_sum(m, axis=1, keepdims=True) + jitter)
+            return self.tf.reduce_sum(self.mul_mask(x, m), axis=1) / (
+                    self.tf.reduce_sum(m, axis=1, keepdims=True) + jitter)
         elif self.backend == 'numpy':
             return np.sum(self.mul_mask(x, m), axis=1) / (np.sum(m, axis=1, keepdims=True) + jitter)
 
     def masked_reduce_max(self, x, m):
-        if self.backend == 'pytorch':
-            import torch
-            return torch.max(self.minus_mask(x, m), 1)[0]
+        if self.backend in {'pytorch', 'torch'}:
+            return self.torch.max(self.minus_mask(x, m), 1)[0]
         elif self.backend == 'tensorflow':
-            import tensorflow as tf
-            return tf.reduce_max(self.minus_mask(x, m), axis=1)
+            return self.tf.reduce_max(self.minus_mask(x, m), axis=1)
         elif self.backend == 'numpy':
             return np.max(self.minus_mask(x, m), axis=1)
 
+    @as_numpy_array
     def encode(self, data: Tuple, *args, **kwargs):
         seq_tensor, mask_tensor = data
 
@@ -69,14 +74,12 @@ def encode(self, data: Tuple, *args, **kwargs):
         elif self.pooling_strategy == 'REDUCE_MAX':
             return self.masked_reduce_max(seq_tensor, mask_tensor)
         elif self.pooling_strategy == 'REDUCE_MEAN_MAX':
-            if self.backend == 'torch':
-                import torch
-                return torch.cat((self.masked_reduce_mean(seq_tensor, mask_tensor),
-                                  self.masked_reduce_max(seq_tensor, mask_tensor)), dim=1)
+            if self.backend in {'pytorch', 'torch'}:
+                return self.torch.cat((self.masked_reduce_mean(seq_tensor, mask_tensor),
+                                       self.masked_reduce_max(seq_tensor, mask_tensor)), dim=1)
             elif self.backend == 'tensorflow':
-                import tensorflow as tf
-                return tf.concat([self.masked_reduce_mean(seq_tensor, mask_tensor),
-                                  self.masked_reduce_max(seq_tensor, mask_tensor)], axis=1)
+                return self.tf.concat([self.masked_reduce_mean(seq_tensor, mask_tensor),
+                                       self.masked_reduce_max(seq_tensor, mask_tensor)], axis=1)
             elif self.backend == 'numpy':
                 return np.concatenate([self.masked_reduce_mean(seq_tensor, mask_tensor),
                                        self.masked_reduce_max(seq_tensor, mask_tensor)], axis=1)
diff --git a/gnes/helper.py b/gnes/helper.py
index b52c0f59..2b408f2d 100644
--- a/gnes/helper.py
+++ b/gnes/helper.py
@@ -481,6 +481,19 @@ def countdown(t: int, logger=None, reason: str = 'I am blocking this thread'):
         sys.stdout.flush()
 
 
+def as_numpy_array(func, dtype=np.float32):
+    @wraps(func)
+    def arg_wrapper(self, *args, **kwargs):
+        r = func(self, *args, **kwargs)
+        r_type = type(r).__name__
+        if r_type in {'ndarray', 'EagerTensor', 'Tensor', 'list'}:
+            return np.array(r, dtype)
+        else:
+            raise TypeError('unrecognized type %s: %s' % (r_type, type(r)))
+
+    return arg_wrapper
+
+
 def train_required(func):
     @wraps(func)
     def arg_wrapper(self, *args, **kwargs):
diff --git a/tests/test_pooling_encoder.py b/tests/test_pooling_encoder.py
new file mode 100644
index 00000000..14e845b9
--- /dev/null
+++ b/tests/test_pooling_encoder.py
@@ -0,0 +1,35 @@
+import unittest
+
+import numpy as np
+import torch
+from numpy.testing import assert_allclose
+
+from gnes.encoder.numeric.pooling import PoolingEncoder
+
+
+class TestEncoder(unittest.TestCase):
+    def setUp(self):
+        self.seq_data = np.random.random([5, 10])
+        self.seq_embed_data = np.random.random([5, 10, 32])
+        self.mask_data = np.array(self.seq_data > 0.5, np.float32)
+        self.data = [
+            (torch.tensor(self.seq_embed_data, dtype=torch.float32), torch.tensor(self.mask_data, dtype=torch.float32)),
+            (self.seq_embed_data, self.mask_data),
+            (self.seq_embed_data, self.mask_data)]
+
+    def _test_strategy(self, strategy):
+        pe_to = PoolingEncoder(strategy, 'torch')
+        pe_tf = PoolingEncoder(strategy, 'tensorflow')
+        pe_np = PoolingEncoder(strategy, 'numpy')
+        return [pe.encode(self.data[idx]) for idx, pe in enumerate([pe_to, pe_tf, pe_np])]
+
+    def test_all(self):
+        for s in {'REDUCE_MEAN', 'REDUCE_MAX', 'REDUCE_MEAN_MAX'}:
+            with self.subTest(strategy=s):
+                r = self._test_strategy(s)
+                for rr in r:
+                    print(type(rr))
+                    print(rr)
+                    print('---')
+                assert_allclose(r[0], r[1], rtol=1e-5)
+                assert_allclose(r[1], r[2], rtol=1e-5)
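
Reviewer note: a minimal usage sketch of the new encoder (not part of the diff), mirroring tests/test_pooling_encoder.py with the numpy backend; it assumes post_init() is invoked automatically after construction, as the test relies on.

import numpy as np
from gnes.encoder.numeric.pooling import PoolingEncoder

# (batch, seq_len, dim) token embeddings and a (batch, seq_len) 0/1 mask
seq = np.random.random([5, 10, 32]).astype(np.float32)
mask = np.array(np.random.random([5, 10]) > 0.5, np.float32)

encoder = PoolingEncoder(pooling_strategy='REDUCE_MEAN', backend='numpy')
vec = encoder.encode((seq, mask))  # @as_numpy_array casts the result to a float32 ndarray
print(vec.shape)                   # (5, 32): one pooled vector per sequence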