From 7493af9779093742d0e9fd16df6681366666539c Mon Sep 17 00:00:00 2001 From: Jem Date: Mon, 26 Aug 2019 11:03:22 +0800 Subject: [PATCH 1/2] refactor(preprocessor): add init, change signature --- gnes/preprocessor/helper.py | 2 +- gnes/preprocessor/io_utils/__init__.py | 0 gnes/preprocessor/io_utils/gif.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 gnes/preprocessor/io_utils/__init__.py diff --git a/gnes/preprocessor/helper.py b/gnes/preprocessor/helper.py index 672fe3a2..09e93580 100644 --- a/gnes/preprocessor/helper.py +++ b/gnes/preprocessor/helper.py @@ -135,7 +135,7 @@ def split_video_frames(buffer_data: bytes, return [np.array(Image.open(io.BytesIO(chunk))) for chunk in chunks] -def get_gif(images, fps=4): +def get_gif(images: 'np.ndarray', fps=10): cmd = ['ffmpeg', '-y', '-f', 'rawvideo', '-vcodec', 'rawvideo', diff --git a/gnes/preprocessor/io_utils/__init__.py b/gnes/preprocessor/io_utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/gnes/preprocessor/io_utils/gif.py b/gnes/preprocessor/io_utils/gif.py index 6cb34a32..b841cc9a 100644 --- a/gnes/preprocessor/io_utils/gif.py +++ b/gnes/preprocessor/io_utils/gif.py @@ -47,7 +47,7 @@ def decode_gif(data: bytes, fps: int = -1, def encode_gif( - images: List[np.ndarray], + images: np.ndarray, scale: str, fps: int, pix_fmt: str = 'rgb24'): From 52538276185ffb8c16f10bb70a5f7a1c65aa9076 Mon Sep 17 00:00:00 2001 From: Jem Date: Mon, 26 Aug 2019 13:53:58 +0800 Subject: [PATCH 2/2] refactor(encoder): no for loop in torch encoder now --- gnes/encoder/image/torchvision.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/gnes/encoder/image/torchvision.py b/gnes/encoder/image/torchvision.py index 56fc679a..b6596a89 100644 --- a/gnes/encoder/image/torchvision.py +++ b/gnes/encoder/image/torchvision.py @@ -109,14 +109,9 @@ def _encode(_, img: List['np.ndarray']): if self._use_cuda: img_tensor = img_tensor.cuda() - result_npy = [] - for t in 
img_tensor: - t = torch.unsqueeze(t, 0) - encodes = self._model(t) - encodes = torch.squeeze(encodes, 0) - result_npy.append(encodes.data.cpu().numpy()) - - output = np.array(result_npy, dtype=np.float32) + encodes = self._model(img_tensor) + + output = np.array(encodes.data.cpu().numpy(), dtype=np.float32) return output output = _encode(self, img)