diff --git a/gnes/encoder/image/base.py b/gnes/encoder/image/base.py index 5d1d05a1..0b4612bb 100644 --- a/gnes/encoder/image/base.py +++ b/gnes/encoder/image/base.py @@ -74,7 +74,6 @@ def forward(self, x): self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") self._model = self._model.to(self._device) - @batching def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray: import torch self._model.eval() @@ -87,13 +86,21 @@ def _padding(img: List['np.ndarray']): if im.shape[0] < max_lenth else im for im in img] return img, max_lenth - @batching + # for video + if len(img[0].shape) == 4: + img, max_lenth = _padding(img) + # for image + else: + max_lenth = -1 + + @batching(chunk_dim=max_lenth) def _encode(_, img: List['np.ndarray']): import copy if len(img[0].shape) == 4: img_ = copy.deepcopy(img) img_ = np.concatenate((list(img_[i] for i in range(len(img_)))), axis=0) + img_for_torch = np.array(img_, dtype=np.float32).transpose(0, 3, 1, 2) else: img_for_torch = np.array(img, dtype=np.float32).transpose(0, 3, 1, 2) @@ -110,17 +117,8 @@ def _encode(_, img: List['np.ndarray']): result_npy.append(encodes.data.cpu().numpy()) output = np.array(result_npy, dtype=np.float32) - - if len(img[0].shape) == 4: - output = output.reshape((len(img), max_lenth, -1)) return output - # for video - if len(img[0].shape) == 4: - padding_image, max_lenth = _padding(img) - output = _encode(self, padding_image) - # for image - else: - output = _encode(self, img) + output = _encode(self, img) return output diff --git a/gnes/helper.py b/gnes/helper.py index ae7f7444..0b202611 100644 --- a/gnes/helper.py +++ b/gnes/helper.py @@ -375,7 +375,7 @@ def pooling_torch(data_tensor, mask_tensor, pooling_strategy): def batching(func: Callable[[Any], np.ndarray] = None, *, batch_size: Union[int, Callable] = None, num_batch=None, - iter_axis: int = 0, concat_axis: int = 0): + iter_axis: int = 0, concat_axis: int = 0, chunk_dim=-1): def _batching(func): @wraps(func) def arg_wrapper(self, data, label=None, *args, **kwargs): @@ -418,6 +418,9 @@ def arg_wrapper(self, data, label=None, *args, **kwargs): if len(final_result) and concat_axis is not None and isinstance(final_result[0], np.ndarray): final_result = np.concatenate(final_result, concat_axis) + if chunk_dim != -1: + final_result = final_result.reshape((-1, chunk_dim, final_result.shape[1])) + if len(final_result): return final_result