Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
feat(encoder): separate pooling as an indep. encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Aug 26, 2019
1 parent 3956100 commit b4444cc
Showing 1 changed file with 29 additions and 17 deletions.
46 changes: 29 additions & 17 deletions gnes/encoder/numeric/pooling.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Tuple

import numpy as np
Expand Down Expand Up @@ -27,26 +28,30 @@ def post_init(self):
import torch
self.torch = torch
elif self.backend == 'tensorflow':
os.environ['CUDA_VISIBLE_DEVICES'] = '0' if self.on_gpu else '-1'
import tensorflow as tf
try:
tf.enable_eager_execution()
except ValueError:
pass
self._tf_graph = tf.Graph()
config = tf.ConfigProto(device_count={'GPU': 1 if self.on_gpu else 0})
config.gpu_options.allow_growth = True
config.log_device_placement = False
self._sess = tf.Session(graph=self._tf_graph, config=config)
self.tf = tf

def mul_mask(self, x, m):
    """Zero out the masked-off positions of a sequence tensor.

    The mask gains a trailing singleton axis so it broadcasts over the
    feature dimension (assumes x has one more trailing axis than m,
    e.g. (batch, seq_len, dim) vs (batch, seq_len) — TODO confirm).

    :param x: sequence tensor in the active backend's type
    :param m: 0/1 mask tensor, one entry per sequence position
    :return: x with masked positions multiplied to zero, same type as x
    """
    if self.backend in {'pytorch', 'torch'}:
        return self.torch.mul(x, m.unsqueeze(2))
    elif self.backend == 'tensorflow':
        # ops must be created inside the encoder's private graph — the one
        # self._sess (built in post_init) executes in encode()
        with self._tf_graph.as_default():
            return x * self.tf.expand_dims(m, axis=-1)
    elif self.backend == 'numpy':
        return x * np.expand_dims(m, axis=-1)

def minus_mask(self, x, m, offset: float = 1e30):
    """Push masked-off positions of a sequence tensor towards -infinity.

    Subtracts `offset` wherever the mask is 0, so a subsequent max-reduce
    over the sequence axis ignores those positions. The mask gains a
    trailing singleton axis to broadcast over the feature dimension.

    :param x: sequence tensor in the active backend's type
    :param m: 0/1 mask tensor, one entry per sequence position
    :param offset: large value subtracted at masked positions
        (annotated float — the default 1e30 is a float literal)
    :return: shifted tensor, same type as x
    """
    if self.backend in {'pytorch', 'torch'}:
        return x - (1.0 - m).unsqueeze(2) * offset
    elif self.backend == 'tensorflow':
        # ops must be created inside the encoder's private graph — the one
        # self._sess (built in post_init) executes in encode()
        with self._tf_graph.as_default():
            return x - self.tf.expand_dims(1.0 - m, axis=-1) * offset
    elif self.backend == 'numpy':
        return x - np.expand_dims(1.0 - m, axis=-1) * offset

Expand All @@ -55,16 +60,18 @@ def masked_reduce_mean(self, x, m, jitter: float = 1e-10):
return self.torch.div(self.torch.sum(self.mul_mask(x, m), dim=1),
self.torch.sum(m.unsqueeze(2), dim=1) + jitter)
elif self.backend == 'tensorflow':
return self.tf.reduce_sum(self.mul_mask(x, m), axis=1) / (
self.tf.reduce_sum(m, axis=1, keepdims=True) + jitter)
with self._tf_graph.as_default():
return self.tf.reduce_sum(self.mul_mask(x, m), axis=1) / (
self.tf.reduce_sum(m, axis=1, keepdims=True) + jitter)
elif self.backend == 'numpy':
return np.sum(self.mul_mask(x, m), axis=1) / (np.sum(m, axis=1, keepdims=True) + jitter)

def masked_reduce_max(self, x, m):
    """Max-pool a sequence tensor over its sequence axis, honoring the mask.

    Masked positions are first shifted towards -infinity via
    `minus_mask`, so they can never win the max.

    :param x: sequence tensor in the active backend's type
    :param m: 0/1 mask tensor, one entry per sequence position
    :return: tensor reduced over axis 1 (the sequence axis)
    """
    if self.backend in {'pytorch', 'torch'}:
        # torch.max over a dim returns (values, indices); keep values only
        return self.torch.max(self.minus_mask(x, m), 1)[0]
    elif self.backend == 'tensorflow':
        # ops must be created inside the encoder's private graph — the one
        # self._sess (built in post_init) executes in encode()
        with self._tf_graph.as_default():
            return self.tf.reduce_max(self.minus_mask(x, m), axis=1)
    elif self.backend == 'numpy':
        return np.max(self.minus_mask(x, m), axis=1)

Expand All @@ -73,16 +80,21 @@ def encode(self, data: Tuple, *args, **kwargs):
seq_tensor, mask_tensor = data

if self.pooling_strategy == 'REDUCE_MEAN':
return self.masked_reduce_mean(seq_tensor, mask_tensor)
r = self.masked_reduce_mean(seq_tensor, mask_tensor)
elif self.pooling_strategy == 'REDUCE_MAX':
return self.masked_reduce_max(seq_tensor, mask_tensor)
r = self.masked_reduce_max(seq_tensor, mask_tensor)
elif self.pooling_strategy == 'REDUCE_MEAN_MAX':
if self.backend in {'pytorch', 'torch'}:
return self.torch.cat((self.masked_reduce_mean(seq_tensor, mask_tensor),
self.masked_reduce_max(seq_tensor, mask_tensor)), dim=1)
r = self.torch.cat((self.masked_reduce_mean(seq_tensor, mask_tensor),
self.masked_reduce_max(seq_tensor, mask_tensor)), dim=1)
elif self.backend == 'tensorflow':
return self.tf.concat([self.masked_reduce_mean(seq_tensor, mask_tensor),
self.masked_reduce_max(seq_tensor, mask_tensor)], axis=1)
with self._tf_graph.as_default():
r = self.tf.concat([self.masked_reduce_mean(seq_tensor, mask_tensor),
self.masked_reduce_max(seq_tensor, mask_tensor)], axis=1)
elif self.backend == 'numpy':
return np.concatenate([self.masked_reduce_mean(seq_tensor, mask_tensor),
self.masked_reduce_max(seq_tensor, mask_tensor)], axis=1)
r = np.concatenate([self.masked_reduce_mean(seq_tensor, mask_tensor),
self.masked_reduce_max(seq_tensor, mask_tensor)], axis=1)

if self.backend == 'tensorflow':
r = self._sess.run(r)
return r

0 comments on commit b4444cc

Please sign in to comment.