diff --git a/gnes/encoder/image/inception.py b/gnes/encoder/image/inception.py index a572b945..87bdeee7 100644 --- a/gnes/encoder/image/inception.py +++ b/gnes/encoder/image/inception.py @@ -19,7 +19,7 @@ from PIL import Image from ..base import BaseImageEncoder -from ...helper import batching, batch_iterator +from ...helper import batching, batch_iterator, get_first_available_gpu class TFInceptionEncoder(BaseImageEncoder): @@ -42,7 +42,8 @@ def post_init(self): import tensorflow as tf from .inception_cores.inception_v4 import inception_v4 from .inception_cores.inception_utils import inception_arg_scope - + import os + os.environ['CUDA_VISIBLE_DEVICES'] = get_first_available_gpu() g = tf.Graph() with g.as_default(): arg_scope = inception_arg_scope() diff --git a/gnes/helper.py b/gnes/helper.py index a14e435c..9758aa02 100644 --- a/gnes/helper.py +++ b/gnes/helper.py @@ -55,9 +55,9 @@ def get_first_available_gpu(): return r[0] raise ValueError except ImportError: - return 0 + return -1 except ValueError: - return 0 + return -1 class FileLock: