diff --git a/gnes/encoder/image/base.py b/gnes/encoder/image/base.py index cd3f8c58..90691dbd 100644 --- a/gnes/encoder/image/base.py +++ b/gnes/encoder/image/base.py @@ -27,15 +27,16 @@ class BasePytorchEncoder(BaseImageEncoder): def __init__(self, model_name: str, layers: List[str], model_dir: str, - batch_size: int = 64, *args, **kwargs): + batch_size: int = 64, + use_cuda: bool = False, + *args, **kwargs): super().__init__(*args, **kwargs) self.batch_size = batch_size self.model_dir = model_dir self.model_name = model_name self.layers = layers - self.is_trained = True - self._use_cuda = False + self._use_cuda = use_cuda def post_init(self): import torch diff --git a/gnes/encoder/image/inception.py b/gnes/encoder/image/inception.py index 27712987..11dce759 100644 --- a/gnes/encoder/image/inception.py +++ b/gnes/encoder/image/inception.py @@ -26,14 +26,14 @@ class TFInceptionEncoder(BaseImageEncoder): def __init__(self, model_dir: str, batch_size: int = 64, select_layer: str = 'PreLogitsFlatten', - use_gpu: bool = True, + use_cuda: bool = False, *args, **kwargs): super().__init__(*args, **kwargs) self.model_dir = model_dir self.batch_size = batch_size self.select_layer = select_layer - self.use_gpu = use_gpu + self._use_cuda = use_cuda self.inception_size_x = 299 self.inception_size_y = 299 @@ -54,7 +54,7 @@ def post_init(self): dropout_keep_prob=1.0) config = tf.ConfigProto(log_device_placement=False) - if self.use_gpu: + if self._use_cuda: config.gpu_options.allow_growth = True self.sess = tf.Session(config=config) self.saver = tf.train.Saver()