Skip to content

Commit

Permalink
black sam.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jesicasusanto committed Jun 5, 2023
1 parent ad39a80 commit abcd31d
Showing 1 changed file with 25 additions and 14 deletions.
39 changes: 25 additions & 14 deletions openadapt/strategies/mixins/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class MyReplayStrategy(SAMReplayStrategyMixin):
CHECKPOINT_DIR_PATH = "./checkpoints"
RESIZE_RATIO = 0.1
SHOW_PLOTS = True


class SAMReplayStrategyMixin(BaseReplayStrategy):
def __init__(
self,
Expand Down Expand Up @@ -75,22 +77,24 @@ def get_screenshot_bbox(self, screenshot: Screenshot, show_plots=SHOW_PLOTS) ->
bbox_list = []
for mask in masks:
bbox_list.append(mask["bbox"])
if SHOW_PLOTS :
plt.figure(figsize=(20,20))
if SHOW_PLOTS:
plt.figure(figsize=(20, 20))
plt.imshow(array_resized)
show_anns(masks)
plt.axis('off')
plt.show()
plt.axis("off")
plt.show()
return str(bbox_list)

def get_click_event_bbox(self, screenshot: Screenshot, show_plots=SHOW_PLOTS) -> str:
def get_click_event_bbox(
self, screenshot: Screenshot, show_plots=SHOW_PLOTS
) -> str:
"""
Get the bounding box of the clicked object in a screenshot image in XYWH format.
Args:
screenshot (Screenshot): The screenshot object containing the image.
show_plots (bool): Flag indicating whether to display the plots or not. Defaults to SHOW_PLOTS.
Returns:
str: A string representation of a list containing the bounding box of the clicked object.
None: If the screenshot does not represent a click event with the mouse pressed.
Expand Down Expand Up @@ -122,10 +126,10 @@ def get_click_event_bbox(self, screenshot: Screenshot, show_plots=SHOW_PLOTS) ->
y0 = np.min(rows)
x1 = np.max(cols)
y1 = np.max(rows)
w = x1-x0
h = y1-y0
w = x1 - x0
h = y1 - y0
input_box = [x0, y0, w, h]
if SHOW_PLOTS :
if SHOW_PLOTS:
plt.figure(figsize=(10, 10))
plt.imshow(array_resized)
show_mask(best_mask, plt.gca())
Expand Down Expand Up @@ -193,17 +197,24 @@ def show_box(box, ax):
plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
)


def show_anns(anns):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
ax = plt.gca()
ax.set_autoscale_on(False)

img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
img[:,:,3] = 0
img = np.ones(
(
sorted_anns[0]["segmentation"].shape[0],
sorted_anns[0]["segmentation"].shape[1],
4,
)
)
img[:, :, 3] = 0
for ann in sorted_anns:
m = ann['segmentation']
m = ann["segmentation"]
color_mask = np.concatenate([np.random.random(3), [0.35]])
img[m] = color_mask
ax.imshow(img)
ax.imshow(img)

0 comments on commit abcd31d

Please sign in to comment.