Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix(encoder): fix tf scope error in cvae encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
Larryjianfeng committed Jul 23, 2019
1 parent ab6c88c commit eb48779
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions gnes/encoder/image/cvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,20 @@ def __init__(self, model_dir: str,
def post_init(self):
import tensorflow as tf
from .cvae_cores.model import CVAE
g = tf.Graph()
with g.as_default():
self._model = CVAE(self.latent_dim)
self.inputs = tf.placeholder(tf.float32,
(None, 120, 120, 3))

self._model = CVAE(self.latent_dim)
self.inputs = tf.placeholder(tf.float32,
(None, 120, 120, 3))
self.mean, self.var = self._model.encode(self.inputs)

self.mean, self.var = self._model.encode(self.inputs)

config = tf.ConfigProto(log_device_placement=False)
if self.use_gpu:
config.gpu_options.allow_growth = True
self.sess = tf.Session(config=config)
self.saver = tf.train.Saver()
self.saver.restore(self.sess, self.model_dir)
config = tf.ConfigProto(log_device_placement=False)
if self.use_gpu:
config.gpu_options.allow_growth = True
self.sess = tf.Session(config=config)
self.saver = tf.train.Saver()
self.saver.restore(self.sess, self.model_dir)

def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
ret = []
Expand Down

0 comments on commit eb48779

Please sign in to comment.