This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

fix(batching): enable to process three dimension output in batching
jemmyshin committed Aug 12, 2019
1 parent b0f22d0 commit 64163cb
Showing 2 changed files with 14 additions and 13 deletions.
22 changes: 10 additions & 12 deletions gnes/encoder/image/base.py
@@ -74,7 +74,6 @@ def forward(self, x):
         self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self._model = self._model.to(self._device)

-    @batching
     def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
         import torch
         self._model.eval()
@@ -87,13 +86,21 @@ def _padding(img: List['np.ndarray']):
                    if im.shape[0] < max_lenth else im for im in img]
             return img, max_lenth

-        @batching
+        # for video
+        if len(img[0].shape) == 4:
+            img, max_lenth = _padding(img)
+        # for image
+        else:
+            max_lenth = -1
+
+        @batching(chunk_dim=max_lenth)
         def _encode(_, img: List['np.ndarray']):
             import copy

             if len(img[0].shape) == 4:
                 img_ = copy.deepcopy(img)
                 img_ = np.concatenate((list(img_[i] for i in range(len(img_)))), axis=0)
+
                 img_for_torch = np.array(img_, dtype=np.float32).transpose(0, 3, 1, 2)
             else:
                 img_for_torch = np.array(img, dtype=np.float32).transpose(0, 3, 1, 2)
@@ -110,17 +117,8 @@ def _encode(_, img: List['np.ndarray']):
                 result_npy.append(encodes.data.cpu().numpy())

             output = np.array(result_npy, dtype=np.float32)
-
-            if len(img[0].shape) == 4:
-                output = output.reshape((len(img), max_lenth, -1))
             return output

-        # for video
-        if len(img[0].shape) == 4:
-            padding_image, max_lenth = _padding(img)
-            output = _encode(self, padding_image)
-        # for image
-        else:
-            output = _encode(self, img)
+        output = _encode(self, img)

         return output
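
Taken together, the base.py change moves the video/image branching ahead of the batching step: encode first pads a list of videos (4-D arrays of shape (frames, height, width, channels)) to a common frame count max_lenth, then hands that count to the decorator via chunk_dim so the flat per-frame output can later be folded back into one row per video. Below is a minimal sketch of the padding idea, assuming zero-padding along the frame axis; the helper name pad_to_max is illustrative, not from the codebase.

import numpy as np
from typing import List, Tuple

def pad_to_max(videos: List[np.ndarray]) -> Tuple[List[np.ndarray], int]:
    # Illustrative stand-in for `_padding` above: every video is a
    # (frames, H, W, C) array; shorter videos are zero-padded along
    # axis 0 so all of them share the same frame count.
    max_len = max(v.shape[0] for v in videos)
    padded = [np.concatenate(
                  (v, np.zeros((max_len - v.shape[0],) + v.shape[1:], dtype=v.dtype)),
                  axis=0)
              if v.shape[0] < max_len else v
              for v in videos]
    return padded, max_len

# e.g. two videos with 3 and 5 frames of 32x32 RGB:
videos = [np.ones((3, 32, 32, 3)), np.ones((5, 32, 32, 3))]
padded, max_len = pad_to_max(videos)
assert max_len == 5 and all(v.shape == (5, 32, 32, 3) for v in padded)

Once every video has max_lenth frames, all frames can be concatenated into one 4-D batch for the model, and the per-video grouping is recoverable afterwards by a pure reshape, which is exactly what the helper.py change below provides.
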
5 changes: 4 additions & 1 deletion gnes/helper.py
@@ -375,7 +375,7 @@ def pooling_torch(data_tensor, mask_tensor, pooling_strategy):

 def batching(func: Callable[[Any], np.ndarray] = None, *,
              batch_size: Union[int, Callable] = None, num_batch=None,
-             iter_axis: int = 0, concat_axis: int = 0):
+             iter_axis: int = 0, concat_axis: int = 0, chunk_dim=-1):
     def _batching(func):
         @wraps(func)
         def arg_wrapper(self, data, label=None, *args, **kwargs):
@@ -418,6 +418,9 @@ def arg_wrapper(self, data, label=None, *args, **kwargs):
             if len(final_result) and concat_axis is not None and isinstance(final_result[0], np.ndarray):
                 final_result = np.concatenate(final_result, concat_axis)

+            if chunk_dim != -1:
+                final_result = final_result.reshape((-1, chunk_dim, final_result.shape[1]))
+
             if len(final_result):
                 return final_result
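
The helper.py side supplies that regrouping: after batching concatenates the per-batch results along concat_axis into a flat (rows, dim) array, a non-default chunk_dim reshapes it to (-1, chunk_dim, dim), i.e. the three-dimensional output named in the commit title. A standalone sketch of just this reshape follows; regroup is a simplified illustration, not the GNES decorator itself.

import numpy as np

def regroup(final_result: np.ndarray, chunk_dim: int = -1) -> np.ndarray:
    # Mirrors the new branch in `arg_wrapper`: leave the flat (rows, dim)
    # result untouched for images, but fold padded video frames back into
    # (-1, chunk_dim, dim), one row of chunk_dim frame embeddings per video.
    if chunk_dim != -1:
        return final_result.reshape((-1, chunk_dim, final_result.shape[1]))
    return final_result

flat = np.arange(2 * 5 * 8, dtype=np.float32).reshape(10, 8)  # 2 videos x 5 frames, dim 8
print(regroup(flat, chunk_dim=5).shape)  # (2, 5, 8)
print(regroup(flat).shape)               # (10, 8) -- image case unchanged

Since chunk_dim defaults to -1, image encoders keep the old two-dimensional behaviour; the reshape only fires when encode passes a real frame count.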

