diff --git a/gnes/preprocessor/image/segmentation.py b/gnes/preprocessor/image/segmentation.py index 87b05683..0dadbc46 100644 --- a/gnes/preprocessor/image/segmentation.py +++ b/gnes/preprocessor/image/segmentation.py @@ -40,15 +40,11 @@ def apply(self, doc: 'gnes_pb2.Document'): image_tensor = image_tensor.cuda() seg_output = self._model([image_tensor]) - chunks = seg_output[0]['boxes'].tolist() + weight = seg_output[0]['scores'].tolist() - if len(chunks) == 0: - c = doc.chunks.add() - c.doc_id = doc.doc_id - c.blob.CopyFrom(array2blob(np.array(original_image.resize((self.target_img_size, - self.target_img_size))))) - c.offset_1d = 1 - c.weight = 1. + length = len(list(filter(lambda x: x >= 0.5, weight))) + chunks = seg_output[0]['boxes'].tolist()[:length] + weight = weight[:length] for ci, ele in enumerate(zip(chunks, weight)): c = doc.chunks.add() @@ -56,6 +52,13 @@ def apply(self, doc: 'gnes_pb2.Document'): c.blob.CopyFrom(array2blob(self._crop_image_reshape(original_image, ele[0]))) c.offset_1d = ci c.weight = ele[1] + + c = doc.chunks.add() + c.doc_id = doc.doc_id + c.blob.CopyFrom(array2blob(np.array(original_image.resize((self.target_img_size, + self.target_img_size))))) + c.offset_1d = len(chunks) + c.weight = 1. else: self.logger.error('bad document: "raw_bytes" is empty!')