diff --git a/gnes/encoder/image/torchvision.py b/gnes/encoder/image/torchvision.py index 56fc679a..b6596a89 100644 --- a/gnes/encoder/image/torchvision.py +++ b/gnes/encoder/image/torchvision.py @@ -109,14 +109,9 @@ def _encode(_, img: List['np.ndarray']): if self._use_cuda: img_tensor = img_tensor.cuda() - result_npy = [] - for t in img_tensor: - t = torch.unsqueeze(t, 0) - encodes = self._model(t) - encodes = torch.squeeze(encodes, 0) - result_npy.append(encodes.data.cpu().numpy()) - - output = np.array(result_npy, dtype=np.float32) + encodes = self._model(img_tensor) + + output = np.array(encodes.data.cpu().numpy(), dtype=np.float32) return output output = _encode(self, img) diff --git a/gnes/preprocessor/helper.py b/gnes/preprocessor/helper.py index 672fe3a2..09e93580 100644 --- a/gnes/preprocessor/helper.py +++ b/gnes/preprocessor/helper.py @@ -135,7 +135,7 @@ def split_video_frames(buffer_data: bytes, return [np.array(Image.open(io.BytesIO(chunk))) for chunk in chunks] -def get_gif(images, fps=4): +def get_gif(images: 'np.ndarray', fps=10): cmd = ['ffmpeg', '-y', '-f', 'rawvideo', '-vcodec', 'rawvideo', diff --git a/gnes/preprocessor/io_utils/__init__.py b/gnes/preprocessor/io_utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/gnes/preprocessor/io_utils/gif.py b/gnes/preprocessor/io_utils/gif.py index 6cb34a32..b841cc9a 100644 --- a/gnes/preprocessor/io_utils/gif.py +++ b/gnes/preprocessor/io_utils/gif.py @@ -47,7 +47,7 @@ def decode_gif(data: bytes, fps: int = -1, def encode_gif( - images: List[np.ndarray], + images: np.ndarray, scale: str, fps: int, pix_fmt: str = 'rgb24'):