From a13299132b776ccbb1826bf0f3a361e431a0f2ca Mon Sep 17 00:00:00 2001 From: Jem Date: Wed, 17 Jul 2019 22:35:58 +0800 Subject: [PATCH] fix(fasterrcnn): fix bug for gpu --- gnes/preprocessor/image/segmentation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gnes/preprocessor/image/segmentation.py b/gnes/preprocessor/image/segmentation.py index f1b1ddde..b514bede 100644 --- a/gnes/preprocessor/image/segmentation.py +++ b/gnes/preprocessor/image/segmentation.py @@ -17,7 +17,6 @@ def __init__(self, model_name: str, self.model_name = model_name self.model_dir = model_dir self.target_img_size = target_img_size - self.model_name = model_name self._use_cuda = _use_cuda def post_init(self): @@ -37,6 +36,8 @@ def apply(self, doc: 'gnes_pb2.Document'): if doc.raw_bytes: original_image = Image.open(io.BytesIO(doc.raw_bytes)) image_tensor = self._torch_transform(original_image) + if self._use_cuda: + image_tensor = image_tensor.cuda() seg_output = self._model([image_tensor]) chunks = seg_output[0]['boxes'].tolist()