diff --git a/gnes/encoder/image/cvae.py b/gnes/encoder/image/cvae.py
new file mode 100644
index 00000000..e3b0e711
--- /dev/null
+++ b/gnes/encoder/image/cvae.py
@@ -0,0 +1,69 @@
+# Tencent is pleased to support the open source community by making GNES available.
+#
+# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+import numpy as np
+from gnes.helper import batch_iterator
+from ..base import BaseImageEncoder
+from PIL import Image
+
+
+class CVAEEncoder(BaseImageEncoder):
+
+    def __init__(self, model_dir: str,
+                 latent_dim: int = 300,
+                 batch_size: int = 64,
+                 select_method: str = 'MEAN',
+                 use_gpu: bool = True,
+                 *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.model_dir = model_dir
+        self.latent_dim = latent_dim
+        self.batch_size = batch_size
+        self.select_method = select_method
+        self.use_gpu = use_gpu
+
+    def post_init(self):
+        import tensorflow as tf
+        from .cvae_cores.model import CVAE
+
+        self._model = CVAE(self.latent_dim)
+        self.inputs = tf.placeholder(tf.float32,
+                                     (None, 120, 120, 3))
+
+        self.mean, self.var = self._model.encode(self.inputs)
+
+        config = tf.ConfigProto(log_device_placement=False)
+        if self.use_gpu:
+            config.gpu_options.allow_growth = True
+        self.sess = tf.Session(config=config)
+        self.saver = tf.train.Saver()
+        self.saver.restore(self.sess, self.model_dir)
+
+    def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
+        ret = []
+        img = [(np.array(Image.fromarray(im).resize((120, 120)),
+                dtype=np.float32) / 255) for im in img]
+        for _im in batch_iterator(img, self.batch_size):
+            _mean, _var = self.sess.run((self.mean, self.var),
+                                        feed_dict={self.inputs: _im})
+            if self.select_method == 'MEAN':
+                ret.append(_mean)
+            elif self.select_method == 'VAR':
+                ret.append(_var)
+            elif self.select_method == 'MEAN_VAR':
+                ret.append(np.concatenate([_mean, _var], axis=1))
+        return np.concatenate(ret, axis=0).astype(np.float32)
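For reviewers, a minimal usage sketch of the encoder above (not part of the diff). The checkpoint path and dummy images are hypothetical; it assumes a TF checkpoint for the CVAE graph exists at model_dir, and depending on how GNES wires components, post_init may also be invoked by the framework rather than manually.

import numpy as np
from gnes.encoder.image.cvae import CVAEEncoder

encoder = CVAEEncoder(model_dir='/tmp/cvae/model.ckpt',  # hypothetical checkpoint prefix
                      select_method='MEAN_VAR',
                      use_gpu=False)
encoder.post_init()  # builds the TF graph and restores weights from model_dir

# two dummy RGB images standing in for real image chunks
imgs = [np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8) for _ in range(2)]
vecs = encoder.encode(imgs)
print(vecs.shape)  # (2, 600): mean and log-variance concatenated when select_method == 'MEAN_VAR'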
diff --git a/gnes/encoder/image/cvae_cores/model.py b/gnes/encoder/image/cvae_cores/model.py
new file mode 100644
index 00000000..2df695d4
--- /dev/null
+++ b/gnes/encoder/image/cvae_cores/model.py
@@ -0,0 +1,114 @@
+# Tencent is pleased to support the open source community by making GNES available.
+#
+# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+import numpy as np
+
+
+class CVAE(tf.keras.Model):
+    def __init__(self, latent_dim):
+        super(CVAE, self).__init__()
+        self.latent_dim = latent_dim
+        self.inference_net = tf.keras.Sequential(
+            [
+                tf.keras.layers.InputLayer(input_shape=(120, 120, 3)),
+                tf.keras.layers.Conv2D(
+                    filters=32, kernel_size=3, strides=(2, 2),
+                    padding='SAME',
+                    activation='relu'),
+                tf.keras.layers.Conv2D(
+                    filters=32, kernel_size=3, strides=(2, 2),
+                    padding='SAME',
+                    activation='relu'),
+                tf.keras.layers.Conv2D(
+                    filters=32, kernel_size=3, strides=(2, 2),
+                    padding='SAME',
+                    activation='relu'),
+                tf.keras.layers.Flatten(),
+                # No activation
+                tf.keras.layers.Dense(latent_dim + latent_dim),
+            ]
+        )
+
+        self.generative_net = tf.keras.Sequential(
+            [
+                tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
+                tf.keras.layers.Dense(units=15*15*32,
+                                      activation=tf.nn.relu),
+                tf.keras.layers.Reshape(target_shape=(15, 15, 32)),
+                tf.keras.layers.Conv2DTranspose(
+                    filters=32,
+                    kernel_size=3,
+                    strides=(2, 2),
+                    padding="SAME",
+                    activation='relu'),
+                tf.keras.layers.Conv2DTranspose(
+                    filters=32,
+                    kernel_size=3,
+                    strides=(2, 2),
+                    padding="SAME",
+                    activation='relu'),
+                tf.keras.layers.Conv2DTranspose(
+                    filters=32,
+                    kernel_size=3,
+                    strides=(2, 2),
+                    padding="SAME",
+                    activation='relu'),
+                # No activation
+                tf.keras.layers.Conv2DTranspose(
+                    filters=3, kernel_size=3, strides=(1, 1), padding="SAME"),
+            ]
+        )
+
+    def sample(self, eps=None):
+        if eps is None:
+            eps = tf.random_normal(shape=(100, self.latent_dim))
+        return self.decode(eps, apply_sigmoid=True)
+
+    def encode(self, x):
+        mean, logvar = tf.split(self.inference_net(x), num_or_size_splits=2, axis=1)
+        return mean, logvar
+
+    def reparameterize(self, mean, logvar):
+        eps = tf.random_normal(shape=tf.shape(mean))
+        return eps * tf.exp(logvar * .5) + mean
+
+    def decode(self, z, apply_sigmoid=False):
+        logits = self.generative_net(z)
+        if apply_sigmoid:
+            probs = tf.sigmoid(logits)
+            return probs
+
+        return logits
+
+    def compute_loss(self, x):
+        mean, logvar = self.encode(x)
+        z = self.reparameterize(mean, logvar)
+        x_logit = self.decode(z)
+
+        cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit,
+                                                            labels=x)
+        logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])
+        logpz = CVAE.log_normal_pdf(z, 0., 0.)
+        logqz_x = CVAE.log_normal_pdf(z, mean, logvar)
+
+        return -tf.reduce_mean(logpx_z + logpz - logqz_x)
+
+    @staticmethod
+    def log_normal_pdf(sample, mean, logvar, raxis=1):
+        log2pi = tf.math.log(2. * np.pi)
+        return tf.reduce_sum(
+            -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
+            axis=raxis)
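For context on the model file, a rough sketch of how its pieces fit together in TF 1.x graph mode (not part of this PR; the optimizer, learning rate and training loop shown here are hypothetical). compute_loss wires encode -> reparameterize -> decode and returns the negative ELBO, so training reduces to minimizing it over batches of 120x120 RGB images scaled to [0, 1].

import tensorflow as tf
from gnes.encoder.image.cvae_cores.model import CVAE

model = CVAE(latent_dim=300)
x = tf.placeholder(tf.float32, (None, 120, 120, 3))  # same input shape the encoder feeds

loss = model.compute_loss(x)  # single-sample estimate of -E[log p(x|z)] + KL(q(z|x) || p(z))
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)  # hypothetical optimizer choice

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # for each batch of images scaled to [0, 1]:
    #     sess.run(train_op, feed_dict={x: batch})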