Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
Merge pull request #17 from gnes-ai/image_encoder_hotfix
Browse files Browse the repository at this point in the history
fix(image encoder): enable batching and define use_cuda via args
  • Loading branch information
numb3r3 authored Jul 17, 2019
2 parents 25f0380 + 06e1a29 commit 94f1496
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 10 deletions.
1 change: 1 addition & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.git/
.pyre/
.idea/
docker-push.sh
Expand Down
7 changes: 3 additions & 4 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@ COPY setup.py ./setup.py

RUN python -c "import distutils.core;s=distutils.core.run_setup('setup.py').install_requires;f=open('requirements_tmp.txt', 'w');[f.write(v+'\n') for v in s];f.close()" && cat requirements_tmp.txt

RUN pip install -r requirements_tmp.txt
RUN pip --no-cache-dir install -r requirements_tmp.txt

FROM dependency as base

ADD . ./

RUN pip install .[all] \
&& rm -rf /tmp/*

WORKDIR /
RUN pip --no-cache-dir install .[all] \
&& rm -rf /tmp/*

ENTRYPOINT ["gnes"]
9 changes: 6 additions & 3 deletions gnes/encoder/image/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,24 @@
import numpy as np

from ..base import BaseImageEncoder
from ...helper import batching


class BasePytorchEncoder(BaseImageEncoder):

def __init__(self, model_name: str,
layers: List[str],
model_dir: str,
batch_size: int = 64, *args, **kwargs):
batch_size: int = 64,
use_cuda: bool = False,
*args, **kwargs):
super().__init__(*args, **kwargs)

self.batch_size = batch_size
self.model_dir = model_dir
self.model_name = model_name
self.layers = layers
self.is_trained = True
self._use_cuda = False
self._use_cuda = use_cuda

def post_init(self):
import torch
Expand Down Expand Up @@ -72,6 +74,7 @@ def forward(self, x):
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self._model = self._model.to(self._device)

@batching
def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
import torch
self._model.eval()
Expand Down
8 changes: 5 additions & 3 deletions gnes/encoder/image/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
from gnes.helper import batch_iterator
from ..base import BaseImageEncoder
from ...helper import batching
from PIL import Image


Expand All @@ -25,14 +26,14 @@ class TFInceptionEncoder(BaseImageEncoder):
def __init__(self, model_dir: str,
batch_size: int = 64,
select_layer: str = 'PreLogitsFlatten',
use_gpu: bool = True,
use_cuda: bool = False,
*args, **kwargs):
super().__init__(*args, **kwargs)

self.model_dir = model_dir
self.batch_size = batch_size
self.select_layer = select_layer
self.use_gpu = use_gpu
self._use_cuda = use_cuda
self.inception_size_x = 299
self.inception_size_y = 299

Expand All @@ -53,12 +54,13 @@ def post_init(self):
dropout_keep_prob=1.0)

config = tf.ConfigProto(log_device_placement=False)
if self.use_gpu:
if self._use_cuda:
config.gpu_options.allow_growth = True
self.sess = tf.Session(config=config)
self.saver = tf.train.Saver()
self.saver.restore(self.sess, self.model_dir)

@batching
def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
ret = []
img = [(np.array(Image.fromarray(im).resize((self.inception_size_x,
Expand Down

0 comments on commit 94f1496

Please sign in to comment.