Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
feat(encoder): add convolution variational autoencoder
Browse files Browse the repository at this point in the history
  • Loading branch information
Larryjianfeng committed Jul 17, 2019
1 parent 25f0380 commit abb0841
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 0 deletions.
69 changes: 69 additions & 0 deletions gnes/encoder/image/cvae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Tencent is pleased to support the open source community by making GNES available.
#
# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
import numpy as np
from gnes.helper import batch_iterator
from ..base import BaseImageEncoder
from PIL import Image


class CVAEEncoder(BaseImageEncoder):

def __init__(self, model_dir: str,
latent_dim: int = 300,
batch_size: int = 64,
select_method: str = 'MEAN',
use_gpu: bool = True,
*args, **kwargs):
super().__init__(*args, **kwargs)

self.model_dir = model_dir
self.latent_dim = latent_dim
self.batch_size = batch_size
self.select_method = select_method
self.use_gpu = use_gpu

def post_init(self):
import tensorflow as tf
from .cave_cores.model import CVAE

self._model = CVAE(self.latent_dim)
self.inputs = tf.placeholder(tf.float32,
(None, 120, 120, 3))

self.mean, self.var = self._model.encode(self.inputs)

config = tf.ConfigProto(log_device_placement=False)
if self.use_gpu:
config.gpu_options.allow_growth = True
self.sess = tf.Session(config=config)
self.saver = tf.train.Saver()
self.saver.restore(self.sess, self.model_dir)

def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
ret = []
img = [(np.array(Image.fromarray(im).resize((120, 120)),
dtype=np.float32)/255) for im in img]
for _im in batch_iterator(img, self.batch_size):
_mean, _var = self.sess.run((self.mean, self.var),
feed_dict={self.inputs: _im})
if self.select_method == 'MEAN':
ret.append(_mean)
elif self.select_method == 'VAR':
ret.append(_var)
elif self.select_method == 'MEAN_VAR':
ret.append(np.concatenate([_mean, _var]), axis=1)
return np.concatenate(ret, axis=0).astype(np.float32)
114 changes: 114 additions & 0 deletions gnes/encoder/image/cvae_cores/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Tencent is pleased to support the open source community by making GNES available.
#
# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
import numpy as np


class CVAE(tf.keras.Model):
def __init__(self, latent_dim):
super(CVAE, self).__init__()
self.latent_dim = latent_dim
self.inference_net = tf.keras.Sequential(
[
tf.keras.layers.InputLayer(input_shape=(120, 120, 3)),
tf.keras.layers.Conv2D(
filters=32, kernel_size=3, strides=(2, 2),
padding='SAME',
activation='relu'),
tf.keras.layers.Conv2D(
filters=32, kernel_size=3, strides=(2, 2),
padding='SAME',
activation='relu'),
tf.keras.layers.Conv2D(
filters=32, kernel_size=3, strides=(2, 2),
padding='SAME',
activation='relu'),
tf.keras.layers.Flatten(),
# No activation
tf.keras.layers.Dense(latent_dim + latent_dim),
]
)

self.generative_net = tf.keras.Sequential(
[
tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
tf.keras.layers.Dense(units=15*15*32,
activation=tf.nn.relu),
tf.keras.layers.Reshape(target_shape=(15, 15, 32)),
tf.keras.layers.Conv2DTranspose(
filters=32,
kernel_size=3,
strides=(2, 2),
padding="SAME",
activation='relu'),
tf.keras.layers.Conv2DTranspose(
filters=32,
kernel_size=3,
strides=(2, 2),
padding="SAME",
activation='relu'),
tf.keras.layers.Conv2DTranspose(
filters=32,
kernel_size=3,
strides=(2, 2),
padding="SAME",
activation='relu'),
# No activation
tf.keras.layers.Conv2DTranspose(
filters=3, kernel_size=3, strides=(1, 1), padding="SAME"),
]
)

def sample(self, eps=None):
if eps is None:
eps = tf.random_normal(shape=(100, self.latent_dim))
return self.decode(eps, apply_sigmoid=True)

def encode(self, x):
mean, logvar = tf.split(self.inference_net(x), num_or_size_splits=2, axis=1)
return mean, logvar

def reparameterize(self, mean, logvar):
eps = tf.random_normal(shape=tf.shape(mean))
return eps * tf.exp(logvar * .5) + mean

def decode(self, z, apply_sigmoid=False):
logits = self.generative_net(z)
if apply_sigmoid:
probs = tf.sigmoid(logits)
return probs

return logits

def compute_loss(self, x):
mean, logvar = self.encode(x)
z = self.reparameterize(mean, logvar)
x_logit = self.decode(z)

cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit,
labels=x)
logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])
logpz = CVAE.log_normal_pdf(z, 0., 0.)
logqz_x = CVAE.log_normal_pdf(z, mean, logvar)

return -tf.reduce_mean(logpx_z + logpz - logqz_x)

@staticmethod
def log_normal_pdf(sample, mean, logvar, raxis=1):
log2pi = tf.math.log(2. * np.pi)
return tf.reduce_sum(
-.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
axis=raxis)

0 comments on commit abb0841

Please sign in to comment.