Skip to content

Commit

Permalink
Merge pull request #169 from jesicasusanto/feat/sam_mixin
Browse files Browse the repository at this point in the history
add sam_mixin.py
  • Loading branch information
abrichr authored Jun 15, 2023
2 parents 2c0e860 + e3b4056 commit 8fd1af7
Show file tree
Hide file tree
Showing 3 changed files with 250 additions and 5 deletions.
18 changes: 14 additions & 4 deletions openadapt/strategies/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from loguru import logger
import numpy as np
from openadapt.crud import get_screenshots

from openadapt.events import get_events
from openadapt.models import Recording, Screenshot, WindowEvent
Expand All @@ -19,13 +20,15 @@

from openadapt.strategies.mixins.ocr import OCRReplayStrategyMixin
from openadapt.strategies.mixins.ascii import ASCIIReplayStrategyMixin
from openadapt.strategies.mixins.sam import SAMReplayStrategyMixin
from openadapt.strategies.mixins.summary import SummaryReplayStrategyMixin


class DemoReplayStrategy(
HuggingFaceReplayStrategyMixin,
OCRReplayStrategyMixin,
ASCIIReplayStrategyMixin,
SAMReplayStrategyMixin,
SummaryReplayStrategyMixin,
BaseReplayStrategy,
):
Expand All @@ -35,6 +38,8 @@ def __init__(
):
super().__init__(recording)
self.result_history = []
self.screenshots = get_screenshots(recording)
self.screenshot_idx = 0

def get_next_action_event(
self,
Expand All @@ -47,6 +52,11 @@ def get_next_action_event(
ocr_text = self.get_ocr_text(screenshot)
# logger.info(f"ocr_text=\n{ocr_text}")

screenshot_bbox = self.get_screenshot_bbox(screenshot)
logger.info(f"screenshot_bbox=\n{screenshot_bbox}")

screenshot_click_event_bbox = self.get_click_event_bbox(self.screenshots[self.screenshot_idx])
logger.info(f"self.screenshots[self.screenshot_idx].action_event=\n{screenshot_click_event_bbox}")
event_strs = [
f"<{event}>"
for event in self.recording.action_events
Expand All @@ -58,16 +68,16 @@ def get_next_action_event(
prompt = " ".join(event_strs + history_strs)
N = max(0, len(prompt) - MAX_INPUT_SIZE)
prompt = prompt[N:]
logger.info(f"{prompt=}")
#logger.info(f"{prompt=}")
max_tokens = 10
completion = self.get_completion(prompt, max_tokens)
logger.info(f"{completion=}")
#logger.info(f"{completion=}")

# only take the first <...>
result = completion.split(">")[0].strip(" <>")
logger.info(f"{result=}")
#logger.info(f"{result=}")
self.result_history.append(result)

# TODO: parse result into ActionEvent(s)

self.screenshot_idx += 1
return None
235 changes: 235 additions & 0 deletions openadapt/strategies/mixins/sam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
"""
Implements a ReplayStrategy mixin for getting segmenting images via SAM model.
Uses SAM model:https://github.com/facebookresearch/segment-anything
Usage:
class MyReplayStrategy(SAMReplayStrategyMixin):
...
"""
from pprint import pformat
from mss import mss
import numpy as np
from openadapt import models
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
from PIL import Image
from loguru import logger
from openadapt.events import get_events
from openadapt.utils import display_event, rows2dicts
from openadapt.models import Recording, Screenshot, WindowEvent
from pathlib import Path
import urllib
import numpy as np
import matplotlib.pyplot as plt

from openadapt.strategies.base import BaseReplayStrategy

CHECKPOINT_URL_BASE = "https://dl.fbaipublicfiles.com/segment_anything/"
CHECKPOINT_URL_BY_NAME = {
"default": f"{CHECKPOINT_URL_BASE}sam_vit_h_4b8939.pth",
"vit_l": f"{CHECKPOINT_URL_BASE}sam_vit_l_0b3195.pth",
"vit_b": f"{CHECKPOINT_URL_BASE}sam_vit_b_01ec64.pth",
}
MODEL_NAME = "default"
CHECKPOINT_DIR_PATH = "./checkpoints"
RESIZE_RATIO = 0.1
SHOW_PLOTS = True


class SAMReplayStrategyMixin(BaseReplayStrategy):
def __init__(
self,
recording: Recording,
model_name=MODEL_NAME,
checkpoint_dir_path=CHECKPOINT_DIR_PATH,
):
super().__init__(recording)
self.sam_model = self._initialize_model(model_name, checkpoint_dir_path)
self.sam_predictor = SamPredictor(self.sam_model)
self.sam_mask_generator = SamAutomaticMaskGenerator(self.sam_model)

def _initialize_model(self, model_name, checkpoint_dir_path):
checkpoint_url = CHECKPOINT_URL_BY_NAME[model_name]
checkpoint_file_name = checkpoint_url.split("/")[-1]
checkpoint_file_path = Path(checkpoint_dir_path, checkpoint_file_name)
if not Path.exists(checkpoint_file_path):
Path(checkpoint_dir_path).mkdir(parents=True, exist_ok=True)
logger.info(f"downloading {checkpoint_url=} to {checkpoint_file_path=}")
urllib.request.urlretrieve(checkpoint_url, checkpoint_file_path)
return sam_model_registry[model_name](checkpoint=checkpoint_file_path)

def get_screenshot_bbox(self, screenshot: Screenshot, show_plots=SHOW_PLOTS) -> str:
"""
Get the bounding boxes of objects in a screenshot image with RESIZE_RATIO in XYWH format.
Args:
screenshot (Screenshot): The screenshot object containing the image.
show_plots (bool): Flag indicating whether to display the plots or not. Defaults to SHOW_PLOTS.
Returns:
str: A string representation of a list containing the bounding boxes of objects.
"""
image_resized = resize_image(screenshot.image)
array_resized = np.array(image_resized)
masks = self.sam_mask_generator.generate(array_resized)
bbox_list = []
for mask in masks:
bbox_list.append(mask["bbox"])
if SHOW_PLOTS:
plt.figure(figsize=(20, 20))
plt.imshow(array_resized)
show_anns(masks)
plt.axis("off")
plt.show()
return str(bbox_list)

def get_click_event_bbox(
self, screenshot: Screenshot, show_plots=SHOW_PLOTS
) -> str:
"""
Get the bounding box of the clicked object in a resized image with RESIZE_RATIO in XYWH format.
Args:
screenshot (Screenshot): The screenshot object containing the image.
show_plots (bool): Flag indicating whether to display the plots or not. Defaults to SHOW_PLOTS.
Returns:
str: A string representation of a list containing the bounding box of the clicked object.
None: If the screenshot does not represent a click event with the mouse pressed.
"""
for action_event in screenshot.action_event:
if action_event.name in "click" and action_event.mouse_pressed == True:
logger.info(f"click_action_event=\n{action_event}")
image_resized = resize_image(screenshot.image)
array_resized = np.array(image_resized)

# Resize mouse coordinates
resized_mouse_x = int(action_event.mouse_x * RESIZE_RATIO)
resized_mouse_y = int(action_event.mouse_y * RESIZE_RATIO)
# Add additional points around the clicked point
additional_points = [
[resized_mouse_x - 1, resized_mouse_y - 1], # Top-left
[resized_mouse_x - 1, resized_mouse_y], # Left
[resized_mouse_x - 1, resized_mouse_y + 1], # Bottom-left
[resized_mouse_x, resized_mouse_y - 1], # Top
[resized_mouse_x, resized_mouse_y], # Center (clicked point)
[resized_mouse_x, resized_mouse_y + 1], # Bottom
[resized_mouse_x + 1, resized_mouse_y - 1], # Top-right
[resized_mouse_x + 1, resized_mouse_y], # Right
[resized_mouse_x + 1, resized_mouse_y + 1], # Bottom-right
]
input_point = np.array(additional_points)
self.sam_predictor.set_image(array_resized)
input_labels = np.ones(
input_point.shape[0]
) # Set labels for additional points
masks, scores, _ = self.sam_predictor.predict(
point_coords=input_point,
point_labels=input_labels,
multimask_output=True,
)
best_mask_index = np.argmax(scores)
best_mask = masks[best_mask_index]
rows, cols = np.where(best_mask)
# Calculate bounding box coordinates
x0 = np.min(cols)
y0 = np.min(rows)
x1 = np.max(cols)
y1 = np.max(rows)
w = x1 - x0
h = y1 - y0
input_box = [x0, y0, w, h]
if SHOW_PLOTS:
plt.figure(figsize=(10, 10))
plt.imshow(array_resized)
show_mask(best_mask, plt.gca())
show_box(input_box, plt.gca())
# for point in additional_points :
# show_points(np.array([point]),input_labels,plt.gca())
show_points(input_point, input_labels, plt.gca())
plt.axis("on")
plt.show()
return input_box
return []


def resize_image(image: Image) -> Image:
"""
Resize the given image.
Args:
image (PIL.Image.Image): The image to be resized.
Returns:
PIL.Image.Image: The resized image.
"""
new_size = [int(dim * RESIZE_RATIO) for dim in image.size]
image_resized = image.resize(new_size)
return image_resized


def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=120):
pos_points = coords[labels == 1]
neg_points = coords[labels == 0]
ax.scatter(
pos_points[:, 0],
pos_points[:, 1],
color="green",
marker="*",
s=marker_size,
edgecolor="white",
linewidth=1.25,
)
ax.scatter(
neg_points[:, 0],
neg_points[:, 1],
color="red",
marker="*",
s=marker_size,
edgecolor="white",
linewidth=1.25,
)


def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2], box[3]
ax.add_patch(
plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
)


def show_anns(anns):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
ax = plt.gca()
ax.set_autoscale_on(False)

img = np.ones(
(
sorted_anns[0]["segmentation"].shape[0],
sorted_anns[0]["segmentation"].shape[1],
4,
)
)
img[:, :, 3] = 0
for ann in sorted_anns:
m = ann["segmentation"]
color_mask = np.concatenate([np.random.random(3), [0.35]])
img[m] = color_mask
ax.imshow(img)
2 changes: 1 addition & 1 deletion openadapt/strategies/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,4 +165,4 @@ def get_window_state_diffs(
window_event_states, window_event_states[1:]
)
]
return diffs
return diffs

0 comments on commit 8fd1af7

Please sign in to comment.