Skip to content

Commit

Permalink
add SHOW_PLOTS
Browse files Browse the repository at this point in the history
  • Loading branch information
jesicasusanto committed May 29, 2023
1 parent f221598 commit 8ea9773
Showing 1 changed file with 37 additions and 13 deletions.
50 changes: 37 additions & 13 deletions openadapt/strategies/mixins/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class MyReplayStrategy(SAMReplayStrategyMixin):
MODEL_NAME = "default"
CHECKPOINT_DIR_PATH = "./checkpoints"
RESIZE_RATIO = 0.1

SHOW_PLOTS = True
class SAMReplayStrategyMixin(BaseReplayStrategy):
def __init__(
self,
Expand All @@ -57,12 +57,13 @@ def _initialize_model(self, model_name, checkpoint_dir_path):
urllib.request.urlretrieve(checkpoint_url, checkpoint_file_path)
return sam_model_registry[model_name](checkpoint=checkpoint_file_path)

def get_screenshot_bbox(self, screenshot: Screenshot) -> str:
def get_screenshot_bbox(self, screenshot: Screenshot, show_plots=SHOW_PLOTS) -> str:
"""
Get the bounding boxes of objects in a screenshot image in XYWH format.
Args:
screenshot (Screenshot): The screenshot object containing the image.
show_plots (bool): Flag indicating whether to display the plots or not. Defaults to SHOW_PLOTS.
Returns:
str: A string representation of a list containing the bounding boxes of objects.
Expand All @@ -72,19 +73,26 @@ def get_screenshot_bbox(self, screenshot: Screenshot) -> str:
array_resized = np.array(image_resized)
masks = self.sam_mask_generator.generate(array_resized)
bbox_list = []
plt.figure(figsize=(10, 10))
plt.imshow(array_resized)
for mask in masks:
bbox_list.append(mask["bbox"])
show_mask(mask["segmentation"], plt.gca())
show_box(mask["bbox"], plt.gca())
if SHOW_PLOTS :
plt.figure(figsize=(20,20))
plt.imshow(array_resized)
show_anns(masks)
plt.axis('off')
plt.show()
return str(bbox_list)

def get_click_event_bbox(self, screenshot: Screenshot) -> str:
def get_click_event_bbox(self, screenshot: Screenshot, show_plots=SHOW_PLOTS) -> str:
"""
Get the bounding box of the clicked object in a screenshot image in XYWH format.
Args:
screenshot (Screenshot): The screenshot object containing the image.
show_plots (bool): Flag indicating whether to display the plots or not. Defaults to SHOW_PLOTS.
Returns:
str: A string representation of a list containing the bounding box of the clicked object.
None: If the screenshot does not represent a click event with the mouse pressed.
Expand Down Expand Up @@ -117,13 +125,14 @@ def get_click_event_bbox(self, screenshot: Screenshot) -> str:
w = np.max(cols)
h = np.max(rows)
input_box = [x0, y0, w, h]
plt.figure(figsize=(10, 10))
plt.imshow(array_resized)
show_mask(best_mask, plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis("on")
plt.show()
if SHOW_PLOTS :
plt.figure(figsize=(10, 10))
plt.imshow(array_resized)
show_mask(best_mask, plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis("on")
plt.show()
return [x0, y0, w - x0, h - y0]
return []

Expand Down Expand Up @@ -183,3 +192,18 @@ def show_box(box, ax):
ax.add_patch(
plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
)

def show_anns(anns):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
ax = plt.gca()
ax.set_autoscale_on(False)

img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
img[:,:,3] = 0
for ann in sorted_anns:
m = ann['segmentation']
color_mask = np.concatenate([np.random.random(3), [0.35]])
img[m] = color_mask
ax.imshow(img)

0 comments on commit 8ea9773

Please sign in to comment.