diff --git a/gnes/encoder/numeric/pooling.py b/gnes/encoder/numeric/pooling.py
index e1e263f3..bcfe1ecf 100644
--- a/gnes/encoder/numeric/pooling.py
+++ b/gnes/encoder/numeric/pooling.py
@@ -1,3 +1,4 @@
+import os
 from typing import Tuple
 
 import numpy as np
@@ -27,18 +28,21 @@ def post_init(self):
             import torch
             self.torch = torch
         elif self.backend == 'tensorflow':
+            os.environ['CUDA_VISIBLE_DEVICES'] = '0' if self.on_gpu else '-1'
             import tensorflow as tf
-            try:
-                tf.enable_eager_execution()
-            except ValueError:
-                pass
+            self._tf_graph = tf.Graph()
+            config = tf.ConfigProto(device_count={'GPU': 1 if self.on_gpu else 0})
+            config.gpu_options.allow_growth = True
+            config.log_device_placement = False
+            self._sess = tf.Session(graph=self._tf_graph, config=config)
             self.tf = tf
 
     def mul_mask(self, x, m):
         if self.backend in {'pytorch', 'torch'}:
             return self.torch.mul(x, m.unsqueeze(2))
         elif self.backend == 'tensorflow':
-            return x * self.tf.expand_dims(m, axis=-1)
+            with self._tf_graph.as_default():
+                return x * self.tf.expand_dims(m, axis=-1)
         elif self.backend == 'numpy':
             return x * np.expand_dims(m, axis=-1)
 
@@ -46,7 +50,8 @@ def minus_mask(self, x, m, offset: int = 1e30):
         if self.backend in {'pytorch', 'torch'}:
             return x - (1.0 - m).unsqueeze(2) * offset
         elif self.backend == 'tensorflow':
-            return x - self.tf.expand_dims(1.0 - m, axis=-1) * offset
+            with self._tf_graph.as_default():
+                return x - self.tf.expand_dims(1.0 - m, axis=-1) * offset
         elif self.backend == 'numpy':
             return x - np.expand_dims(1.0 - m, axis=-1) * offset
 
@@ -55,8 +60,9 @@ def masked_reduce_mean(self, x, m, jitter: float = 1e-10):
             return self.torch.div(self.torch.sum(self.mul_mask(x, m), dim=1),
                                   self.torch.sum(m.unsqueeze(2), dim=1) + jitter)
         elif self.backend == 'tensorflow':
-            return self.tf.reduce_sum(self.mul_mask(x, m), axis=1) / (
-                    self.tf.reduce_sum(m, axis=1, keepdims=True) + jitter)
+            with self._tf_graph.as_default():
+                return self.tf.reduce_sum(self.mul_mask(x, m), axis=1) / (
+                        self.tf.reduce_sum(m, axis=1, keepdims=True) + jitter)
         elif self.backend == 'numpy':
             return np.sum(self.mul_mask(x, m), axis=1) / (np.sum(m, axis=1, keepdims=True) + jitter)
 
@@ -64,7 +70,8 @@ def masked_reduce_max(self, x, m):
         if self.backend in {'pytorch', 'torch'}:
            return self.torch.max(self.minus_mask(x, m), 1)[0]
         elif self.backend == 'tensorflow':
-            return self.tf.reduce_max(self.minus_mask(x, m), axis=1)
+            with self._tf_graph.as_default():
+                return self.tf.reduce_max(self.minus_mask(x, m), axis=1)
         elif self.backend == 'numpy':
             return np.max(self.minus_mask(x, m), axis=1)
 
@@ -73,16 +80,21 @@ def encode(self, data: Tuple, *args, **kwargs):
         seq_tensor, mask_tensor = data
 
         if self.pooling_strategy == 'REDUCE_MEAN':
-            return self.masked_reduce_mean(seq_tensor, mask_tensor)
+            r = self.masked_reduce_mean(seq_tensor, mask_tensor)
         elif self.pooling_strategy == 'REDUCE_MAX':
-            return self.masked_reduce_max(seq_tensor, mask_tensor)
+            r = self.masked_reduce_max(seq_tensor, mask_tensor)
         elif self.pooling_strategy == 'REDUCE_MEAN_MAX':
             if self.backend in {'pytorch', 'torch'}:
-                return self.torch.cat((self.masked_reduce_mean(seq_tensor, mask_tensor),
-                                       self.masked_reduce_max(seq_tensor, mask_tensor)), dim=1)
+                r = self.torch.cat((self.masked_reduce_mean(seq_tensor, mask_tensor),
+                                    self.masked_reduce_max(seq_tensor, mask_tensor)), dim=1)
             elif self.backend == 'tensorflow':
-                return self.tf.concat([self.masked_reduce_mean(seq_tensor, mask_tensor),
-                                       self.masked_reduce_max(seq_tensor, mask_tensor)], axis=1)
+                with self._tf_graph.as_default():
+                    r = self.tf.concat([self.masked_reduce_mean(seq_tensor, mask_tensor),
+                                        self.masked_reduce_max(seq_tensor, mask_tensor)], axis=1)
             elif self.backend == 'numpy':
-                return np.concatenate([self.masked_reduce_mean(seq_tensor, mask_tensor),
-                                       self.masked_reduce_max(seq_tensor, mask_tensor)], axis=1)
+                r = np.concatenate([self.masked_reduce_mean(seq_tensor, mask_tensor),
+                                    self.masked_reduce_max(seq_tensor, mask_tensor)], axis=1)
+
+        if self.backend == 'tensorflow':
+            r = self._sess.run(r)
+        return r
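Note: the diff replaces TensorFlow eager execution with an explicit graph/session pair. The session is created once in post_init, every op is built inside `with self._tf_graph.as_default()`, and encode() finally materializes the symbolic result with `self._sess.run(r)`. Below is a minimal standalone sketch of that pattern (not part of the diff), assuming TensorFlow 1.x; the names `seq`, `mask`, and the shapes are illustrative only.

# Sketch only: graph/session pattern used by the change, TF 1.x assumed.
import numpy as np
import tensorflow as tf  # ConfigProto/Session/placeholder exist only in TF 1.x

graph = tf.Graph()
config = tf.ConfigProto(device_count={'GPU': 0})  # CPU-only for this example
config.gpu_options.allow_growth = True
sess = tf.Session(graph=graph, config=config)

with graph.as_default():
    # ops must live in the same graph the session is bound to
    seq = tf.placeholder(tf.float32, shape=[None, None, 4])   # [batch, time, dim]
    mask = tf.placeholder(tf.float32, shape=[None, None])     # [batch, time]
    masked_mean = tf.reduce_sum(seq * tf.expand_dims(mask, -1), axis=1) / (
        tf.reduce_sum(mask, axis=1, keepdims=True) + 1e-10)

batch = np.random.rand(2, 3, 4).astype('float32')
valid = np.array([[1, 1, 0], [1, 0, 0]], dtype='float32')
print(sess.run(masked_mean, feed_dict={seq: batch, mask: valid}))  # [2, 4] pooled output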