Skip to content

Commit

Permalink
Fix delimiters for masks
Browse files Browse the repository at this point in the history
  • Loading branch information
pshivraj committed Feb 15, 2019
1 parent d73cc04 commit 0a6e699
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 8 deletions.
2 changes: 1 addition & 1 deletion mrcnn/scripts/pre_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def mask_to_h5(train_path):
label_array = []
for mask_file in next(os.walk(path + MASK_PATH))[2]:
if 'png' in mask_file:
class_id = int(mask_file.split('$')[1][:-4])
class_id = int(mask_file.split('__')[1][:-4])
label_array.append(class_id)
mask_ = cv2.imread(path + MASK_PATH + mask_file, 0)
i += 1
Expand Down
10 changes: 3 additions & 7 deletions mrcnn/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ def load_shapes(self, id_list, train_path):
# Add classes
for val in class_mapping.keys():
self.add_class('clomask', val, class_mapping[val])
# self.add_class('clomask', 2, "boxes")
# self.add_class('clomask', 3, "bags")
self.train_path = train_path
# Add images
for i, id_ in enumerate(id_list):
Expand Down Expand Up @@ -78,19 +76,18 @@ def load_mask(self, image_id):
else:
path = self.train_path + info['img_name']
mask = []
label_array = []
class_ids = []
for mask_file in next(os.walk(path + MASK_PATH))[2]:
if 'png' in mask_file:
# these lines have been commented out due to invalid test data file name
# class_id = int(mask_file.split('$')[1][:-4])
# label_array.append(class_id)
class_id = int(mask_file.split('__')[1][:-4])
class_ids.append(class_id)
mask_ = cv2.imread(path + MASK_PATH + mask_file, 0)
mask_ = np.where(mask_ > 128, 1, 0)
# Add mask only if its area is larger than one pixel
if np.sum(mask_) >= 1:
mask.append(np.squeeze(mask_))
mask = np.stack(mask, axis=-1)
class_ids = np.ones(mask.shape[2])
return mask.astype(np.uint8), class_ids.astype(np.int8)


Expand Down Expand Up @@ -121,7 +118,6 @@ def prepare_dataset(self):
train_data = ClomaskDataset()
train_data.load_shapes(train_list, TRAIN_PATH)
train_data.prepare()
print(train_data.class_names)

# initialize validation dataset
validation_data = ClomaskDataset()
Expand Down

0 comments on commit 0a6e699

Please sign in to comment.