Skip to content

Commit

Permalink
fix sam_mixin.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jesicasusanto committed May 24, 2023
1 parent 8f2d67e commit bb707a6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ cache

# db
*.db
checkpoints/*
# pth
*.pth
17 changes: 8 additions & 9 deletions openadapt/strategies/sam_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,23 +61,22 @@ def _initialize_model(self, model_name, checkpoint_dir_path):

def get_autosegmented_screenshot(self, screenshot: Screenshot) -> Screenshot:
masks = self.sam_mask_generator.generate(screenshot.array)
segmented_image = apply_masks(screenshot.image, masks)
segmented_image = apply_masks(masks)

# Create a new Screenshot object with the segmented image
segmented_screenshot = Screenshot()
segmented_screenshot.sct_img = pil_to_sct(segmented_image)

return segmented_screenshot

def apply_masks(self, image, masks):
mask_img = np.zeros_like(image)

for ann in masks:
def apply_masks(self, anns):
img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
for ann in sorted_anns:
m = ann['segmentation']
color_mask = np.random.random(3)
mask_img[m] = color_mask

segmented_image = Image.fromarray(mask_img)
color_mask = np.concatenate([np.random.random(3), [0.35]])
img[m] = color_mask
segmented_image = Image.fromarray(img)
return segmented_image

def pil_to_sct(self, image):
Expand Down

0 comments on commit bb707a6

Please sign in to comment.