Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
Merge pull request #152 from gnes-ai/torch_encoder
Browse files Browse the repository at this point in the history
refactor(encoder): no for loop in torch encoder now
  • Loading branch information
mergify[bot] authored Aug 26, 2019
2 parents 37ca2f6 + 5253827 commit 8ff885a
Showing 1 changed file with 3 additions and 8 deletions.
11 changes: 3 additions & 8 deletions gnes/encoder/image/torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,9 @@ def _encode(_, img: List['np.ndarray']):
if self._use_cuda:
img_tensor = img_tensor.cuda()

result_npy = []
for t in img_tensor:
t = torch.unsqueeze(t, 0)
encodes = self._model(t)
encodes = torch.squeeze(encodes, 0)
result_npy.append(encodes.data.cpu().numpy())

output = np.array(result_npy, dtype=np.float32)
encodes = self._model(img_tensor)

output = np.array(encodes.data.cpu().numpy(), dtype=np.float32)
return output

output = _encode(self, img)
Expand Down

0 comments on commit 8ff885a

Please sign in to comment.