Skip to content

Commit

Permalink
fix(replay): fix poetry lock, ulralytics.py, update ERROR_REPORTING_DSN
Browse files Browse the repository at this point in the history
* fix dependencies, update ultralytics.py

* update ERROR_REPORTING_DSN

* black
  • Loading branch information
abrichr authored Oct 28, 2024
1 parent cf3a57b commit 0ce11f3
Show file tree
Hide file tree
Showing 3 changed files with 3,323 additions and 2,771 deletions.
132 changes: 19 additions & 113 deletions openadapt/adapters/ultralytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@


from ultralytics import FastSAM
from ultralytics.models.fastsam import FastSAMPrompt
from ultralytics.models.fastsam import FastSAMPredictor
from ultralytics.models.sam import Predictor as SAMPredictor
import fire
import numpy as np
import ultralytics

from openadapt import cache
Expand All @@ -41,13 +40,11 @@
SAM_MODEL_NAMES = (
"sam_b.pt", # base
"sam_l.pt", # large
# "mobile_sam.pt",
)
MODEL_NAMES = FASTSAM_MODEL_NAMES + SAM_MODEL_NAMES
DEFAULT_MODEL_NAME = MODEL_NAMES[0]


# TODO: rename
def fetch_segmented_image(
image: Image.Image,
model_name: str = DEFAULT_MODEL_NAME,
Expand All @@ -74,14 +71,12 @@ def fetch_segmented_image(
def do_fastsam(
image: Image,
model_name: str,
# TODO: inject from config
device: str = "cpu",
retina_masks: bool = True,
imgsz: int | tuple[int, int] | None = 1024,
# threshold below which boxes will be filtered out
min_confidence_threshold: float = 0.4,
# discards all overlapping boxes with IoU > iou_threshold
max_iou_threshold: float = 0.9,
max_det: int = 1000,
max_retries: int = 5,
retry_delay_seconds: float = 0.1,
) -> Image:
Expand All @@ -90,100 +85,35 @@ def do_fastsam(
For usage of thresholds see:
github.com/ultralytics/ultralytics/blob/dacbd48fcf8407098166c6812eeb751deaac0faf
/ultralytics/utils/ops.py#L164
Args:
TODO
min_confidence_threshold (float, optional): The minimum confidence score
that a detection must meet or exceed to be considered valid. Detections
below this threshold will not be marked. Defaults to 0.00.
max_iou_threshold (float, optional): The maximum allowed Intersection over
Union (IoU) value for overlapping detections. Detections that exceed this
IoU threshold are considered for suppression, keeping only the
detection with the highest confidence. Defaults to 0.05.
"""
model = FastSAM(model_name)

imgsz = imgsz or image.size

# Run inference on image
everything_results = model(
image,
device=device,
retina_masks=retina_masks,
imgsz=imgsz,
conf=min_confidence_threshold,
iou=max_iou_threshold,
max_det=max_det,
)

# Prepare a Prompt Process object
prompt_process = FastSAMPrompt(image, everything_results, device="cpu")

# Everything prompt
annotations = prompt_process.everything_prompt()

# TODO: support other modes once issues are fixed
# https://github.com/ultralytics/ultralytics/issues/13218#issuecomment-2142960103

# Bbox default shape [0,0,0,0] -> [x1,y1,x2,y2]
# annotations = prompt_process.box_prompt(bbox=[200, 200, 300, 300])

# Text prompt
# annotations = prompt_process.text_prompt(text='a photo of a dog')

# Point prompt
# points default [[0,0]] [[x1,y1],[x2,y2]]
# point_label default [0] [1,0] 0:background, 1:foreground
# annotations = prompt_process.point_prompt(points=[[200, 200]], pointlabel=[1])

assert len(annotations) == 1, len(annotations)
annotation = annotations[0]

# hide original image
annotation.orig_img = np.ones(annotation.orig_img.shape)

# TODO: in memory, e.g. with prompt_process.fast_show_mask()
with TemporaryDirectory() as tmp_dir:
# Force the output format to PNG to prevent JPEG compression artefacts
annotation.path = annotation.path.replace(".jpg", ".png")
prompt_process.plot(
[annotation],
tmp_dir,
with_contours=False,
retina=False,
assert len(everything_results) == 1, len(everything_results)
annotation = everything_results[0]

segmented_image = Image.fromarray(
annotation.plot(
img=np.ones(annotation.orig_img.shape, dtype=annotation.orig_img.dtype),
kpt_line=False,
labels=False,
boxes=False,
probs=False,
color_mode="instance",
)
result_name = os.path.basename(annotation.path)
logger.info(f"{annotation.path=}")
segmented_image_path = Path(tmp_dir) / result_name
segmented_image = Image.open(segmented_image_path)

# Ensure the image is fully loaded before deletion to avoid errors or incomplete operations,
# as some operating systems and file systems lock files during read or processing.
segmented_image.load()

# Attempt to delete the file with retries and delay
retries = 0

while retries < max_retries:
try:
os.remove(segmented_image_path)
break # If deletion succeeds, exit loop
except OSError as e:
if e.errno == errno.ENOENT: # File not found
break
else:
retries += 1
time.sleep(retry_delay_seconds)

if retries == max_retries:
logger.warning(f"Failed to delete {segmented_image_path}")
# Check if the dimensions of the original and segmented images differ
# XXX TODO this is a hack, this plotting code should be refactored, but the
# bug may exist in ultralytics, since they seem to resize as well; see:
# https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/plotting.py#L238
# https://github.com/ultralytics/ultralytics/issues/561#issuecomment-1403079910
)

if image.size != segmented_image.size:
logger.warning(f"{image.size=} != {segmented_image.size=}, resizing...")
# Resize segmented_image to match original using nearest neighbor interpolation
segmented_image = segmented_image.resize(image.size, Image.NEAREST)

assert image.size == segmented_image.size, (image.size, segmented_image.size)
Expand All @@ -194,7 +124,6 @@ def do_fastsam(
def do_sam(
image: Image.Image,
model_name: str,
# TODO: add params
) -> Image.Image:
# Create SAMPredictor
overrides = dict(
Expand All @@ -207,20 +136,7 @@ def do_sam(
predictor = SAMPredictor(overrides=overrides)

# Segment with additional args
# results = predictor(source=image, crop_n_layers=1, points_stride=64)
results = predictor(
source=image,
# crop_n_layers=3,
# crop_overlap_ratio=0.5,
# crop_downscale_factor=1,
# point_grids=None,
# points_stride=12,
# points_batch_size=128,
# conf_thres=0.8,
# stability_score_thresh=0.95,
# stability_score_offset=0.95,
# crop_nms_thresh=0.8,
)
results = predictor(source=image)
mask_ims = results_to_mask_images(results)
segmented_image = colorize_masks(mask_ims)
return segmented_image
Expand All @@ -238,8 +154,7 @@ def results_to_mask_images(


def colorize_masks(masks: list[Image.Image]) -> Image.Image:
"""
Takes a list of PIL images containing binary masks and returns a new PIL.Image
"""Takes a list of PIL images containing binary masks and returns a new PIL.Image
where each mask is colored differently using a unique color for each mask.
Args:
Expand All @@ -249,15 +164,11 @@ def colorize_masks(masks: list[Image.Image]) -> Image.Image:
PIL.Image: A new image with each mask in a different color.
"""
if not masks:
return None # Return None if the list is empty
return None

# Assuming all masks are the same size, get dimensions
width, height = masks[0].size

# Create an empty array with 3 color channels (RGB)
result_image = np.zeros((height, width, 3), dtype=np.uint8)

# Generate unique colors using HSV color space
num_masks = len(masks)
colors = [
tuple(
Expand All @@ -271,17 +182,12 @@ def colorize_masks(masks: list[Image.Image]) -> Image.Image:
]

for idx, mask in enumerate(masks):
# Convert PIL Image to numpy array
mask_array = np.array(mask)

# Apply the color to the mask
for c in range(3):
# Only colorize where the mask is True (assuming mask is binary: 0 or 255)
result_image[:, :, c] += (mask_array / 255 * colors[idx][c]).astype(
np.uint8
)

# Convert the result back to a PIL image
return Image.fromarray(result_image)


Expand Down
3 changes: 2 additions & 1 deletion openadapt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ class SegmentationAdapter(str, Enum):
# Error reporting
ERROR_REPORTING_ENABLED: bool = True
ERROR_REPORTING_DSN: ClassVar = (
"https://[email protected]/3798"
# "https://[email protected]/3798"
"https://[email protected]/8771",
)
ERROR_REPORTING_BRANCH: ClassVar = "main"

Expand Down
Loading

0 comments on commit 0ce11f3

Please sign in to comment.