Skip to content

Commit

Permalink
add bbox functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jesicasusanto committed May 26, 2023
1 parent 6b6b287 commit 66a28a8
Showing 1 changed file with 64 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,23 +64,74 @@ def _initialize_model(self, model_name, checkpoint_dir_path):
return sam_model_registry[model_name](checkpoint=checkpoint_file_path)


# NOTE(review): this looks like the pre-change version of get_screenshot_bbox
# shown by the diff (a second definition with the same name follows it) --
# confirm before relying on it. No return statement is visible in this span.
#
# What the visible lines do: take the image held by `screenshot`, shrink each
# dimension to 10% of its size, and run the SAM automatic mask generator on
# the shrunken image.
def get_screenshot_bbox(self, screenshot: Screenshot) -> Screenshot:
#logger.info("before auto generate masks\n")
#out.append({"size": [h, w], "counts": counts})
#resize sct_img
image = screenshot.image
# Fraction of the original size to keep along each dimension.
resize_ratio = 0.1

# int() truncates, so each new dimension is rounded down.
new_size = [ int(dim * resize_ratio) for dim in image.size]
# Debug output of the computed size, written to stdout.
print(new_size)
image_resized = image.resize(new_size)
# The mask generator is called with an array, so convert the resized image.
new_array = np.array(image_resized)
masks = self.sam_mask_generator.generate(new_array)
def get_screenshot_bbox(self, screenshot: Screenshot) -> str:
    """
    Collect the bounding box of every mask generated for a screenshot.

    The screenshot image is first shrunk with ``resize_image`` and the mask
    generator runs on that shrunken copy, so the boxes are whatever the
    generator reports for the resized image, not for the full-size one.

    Args:
        screenshot (Screenshot): The screenshot object containing the image.

    Returns:
        str: ``str()`` of a list holding the ``"bbox"`` entry of each
            generated mask, in the order the generator returned them.
    """
    downscaled_pixels = np.array(resize_image(screenshot.image))
    # The name `masks` is part of the log output below (f-string `=` form).
    masks = self.sam_mask_generator.generate(downscaled_pixels)
    logger.info(f"{masks=}")
    return str([entry["bbox"] for entry in masks])



def get_click_event_bbox(self, screenshot: Screenshot) -> str:
    """
    Get the bounding box of the clicked object in a screenshot image.

    The screenshot image is downscaled with ``resize_image`` before it is
    given to the SAM predictor. The click position is therefore scaled by the
    same factor before being used as the point prompt, so that the prompt and
    the image share one coordinate frame. The returned box is expressed in the
    coordinates of the downscaled image.

    Args:
        screenshot (Screenshot): The screenshot object containing the image
            and the action event that describes the click.

    Returns:
        str: A string representation of the list ``[x0, y0, w, h]`` for the
            highest-scoring mask, where ``w`` and ``h`` are the differences
            between the largest and smallest foreground column/row indices.
        None: If the screenshot does not represent a click event with the
            mouse pressed, or if the best mask has no foreground pixels.
    """
    action_event = screenshot.action_event
    if action_event.name != "click" or not action_event.mouse_pressed:
        return None

    image = screenshot.image
    image_resized = resize_image(image)
    array_resized = np.array(image_resized)
    self.sam_predictor.set_image(array_resized)

    # The predictor only knows about the resized image, so the click has to be
    # mapped into that image's pixel grid. The scale is derived from the two
    # image sizes rather than from a hard-coded ratio.
    # NOTE(review): assumes mouse_x / mouse_y are expressed in pixels of
    # screenshot.image -- confirm against the recorder.
    original_width, original_height = image.size
    resized_width, resized_height = image_resized.size
    scaled_x = action_event.mouse_x * resized_width / original_width
    scaled_y = action_event.mouse_y * resized_height / original_height

    input_point = np.array([[scaled_x, scaled_y]])
    input_label = np.array([1])  # label 1: the point belongs to the object
    masks, scores, _logits = self.sam_predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )
    best_mask = masks[np.argmax(scores)]  # mask with the highest score

    # Row/column indices of the foreground pixels of the chosen mask.
    rows, cols = np.where(best_mask)
    if rows.size == 0:
        # Nothing was segmented; min()/max() would raise on empty arrays.
        return None

    # Cast to built-in ints so the string form is "[x0, y0, w, h]" regardless
    # of how the installed NumPy version prints its scalar types.
    x0 = int(cols.min())
    y0 = int(rows.min())
    w = int(cols.max()) - x0
    h = int(rows.max()) - y0
    return str([x0, y0, w, h])

def resize_image(image : Image, resize_ratio: float = 0.1) -> Image:
    """
    Resize the given image by a constant ratio along both dimensions.

    Each dimension is multiplied by ``resize_ratio`` and truncated to an
    integer, but never allowed to drop below one pixel so that very small
    images can still be resized.

    Args:
        image (PIL.Image.Image): The image to be resized.
        resize_ratio (float): Factor applied to the width and the height.
            Defaults to 0.1, i.e. the result is one tenth of the original
            size in each dimension.

    Returns:
        PIL.Image.Image: The resized image.

    Raises:
        ValueError: If ``resize_ratio`` is not strictly positive.
    """
    if resize_ratio <= 0:
        raise ValueError(f"resize_ratio must be positive, got {resize_ratio}")
    # max(1, ...) guards against a zero-sized target for tiny inputs.
    new_size = tuple(max(1, int(dim * resize_ratio)) for dim in image.size)
    return image.resize(new_size)

0 comments on commit 66a28a8

Please sign in to comment.