From ab6c88ccfe54ba5f96f09510e97b9658c553c1a9 Mon Sep 17 00:00:00 2001 From: Larry Yan Date: Tue, 23 Jul 2019 16:47:33 +0800 Subject: [PATCH] fix(encoder): fix error in cvae encoder --- gnes/encoder/__init__.py | 1 + gnes/encoder/image/cvae.py | 2 +- gnes/encoder/image/cvae_cores/__init__.py | 0 3 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 gnes/encoder/image/cvae_cores/__init__.py diff --git a/gnes/encoder/__init__.py b/gnes/encoder/__init__.py index 1e6ffa0a..e77f2411 100644 --- a/gnes/encoder/__init__.py +++ b/gnes/encoder/__init__.py @@ -40,6 +40,7 @@ 'HashEncoder': 'numeric.hash', 'BasePytorchEncoder': 'image.base', 'TFInceptionEncoder': 'image.inception', + 'CVAEEncoder': 'image.cvae' } register_all_class(_cls2file_map, 'encoder') diff --git a/gnes/encoder/image/cvae.py b/gnes/encoder/image/cvae.py index f297e17d..5489f85a 100644 --- a/gnes/encoder/image/cvae.py +++ b/gnes/encoder/image/cvae.py @@ -40,7 +40,7 @@ def __init__(self, model_dir: str, def post_init(self): import tensorflow as tf - from .cave_cores.model import CVAE + from .cvae_cores.model import CVAE self._model = CVAE(self.latent_dim) self.inputs = tf.placeholder(tf.float32, diff --git a/gnes/encoder/image/cvae_cores/__init__.py b/gnes/encoder/image/cvae_cores/__init__.py new file mode 100644 index 00000000..e69de29b