Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix(fasterrcnn): fix bug for gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
jemmyshin committed Jul 17, 2019
1 parent f2667bb commit a132991
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion gnes/preprocessor/image/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ def __init__(self, model_name: str,
self.model_name = model_name
self.model_dir = model_dir
self.target_img_size = target_img_size
self.model_name = model_name
self._use_cuda = _use_cuda

def post_init(self):
Expand All @@ -37,6 +36,8 @@ def apply(self, doc: 'gnes_pb2.Document'):
if doc.raw_bytes:
original_image = Image.open(io.BytesIO(doc.raw_bytes))
image_tensor = self._torch_transform(original_image)
if self._use_cuda:
image_tensor = image_tensor.cuda()

seg_output = self._model([image_tensor])
chunks = seg_output[0]['boxes'].tolist()
Expand Down

0 comments on commit a132991

Please sign in to comment.