Skip to content

Commit

Permalink
feat(SegmentReplayStrategy, drivers): add strategies.replay; refactor…
Browse files Browse the repository at this point in the history
… adapters -> drivers + adapters (#714)

* implemented

* add get_active_window_data parameter include_window_data; fix ActionEvent.from_dict to handle multiple separators; add test_models.py

* update get_default_prompt_adapter

* add TODO

* tests.openadapt.adapters -> drivers

* utils.get_marked_image, .extract_code_block; .strip_backticks

* working segment.py (misses final click in calculator task)

* include_replay_instructions; dev_mode=False

* fix test_openai.py: ValueError -> Exception

* replay.py --record -> --capture

* black/flake8

* remove import

* INCLUDE_CURRENT_SCREENSHOT; handle mouse events without x/y

* add models.Replay; print_config in replay.py
  • Loading branch information
abrichr authored Jun 22, 2024
1 parent c674678 commit 7ef115a
Show file tree
Hide file tree
Showing 22 changed files with 740 additions and 100 deletions.
10 changes: 4 additions & 6 deletions openadapt/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,20 @@

from openadapt.config import config

from . import anthropic, google, openai, replicate, som, ultralytics
from . import prompt, replicate, som, ultralytics


# TODO: remove
def get_default_prompt_adapter() -> ModuleType:
"""Returns the default prompt adapter module.
Returns:
The module corresponding to the default prompt adapter.
"""
return {
"openai": openai,
"anthropic": anthropic,
"google": google,
}[config.DEFAULT_ADAPTER]
return prompt


# TODO: refactor to follow adapters.prompt
def get_default_segmentation_adapter() -> ModuleType:
"""Returns the default image segmentation adapter module.
Expand Down
47 changes: 47 additions & 0 deletions openadapt/adapters/prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Adapter for prompting foundation models."""

from loguru import logger
from typing import Type
from PIL import Image


from openadapt.drivers import anthropic, google, openai


# Define a list of drivers in the order they should be tried.
# `prompt()` below walks this list front to back and returns the first
# successful completion, so earlier entries take priority.
# NOTE(review): the entries are imported modules rather than classes, so
# `list[ModuleType]` looks like a more accurate annotation — confirm.
DRIVER_ORDER: list[Type] = [openai, google, anthropic]


def prompt(
    text: str,
    images: list[Image.Image] | None = None,
    system_prompt: str | None = None,
) -> str:
    """Fetch a prompt completion, falling back through drivers in priority order.

    Each driver in ``DRIVER_ORDER`` is tried in turn. The first one whose
    ``prompt`` call returns without raising wins; failures are logged and the
    next driver is attempted.

    Args:
        text: The main text prompt. Leading and trailing whitespace is stripped
            before it is sent to a driver.
        images: list of images to include in the prompt.
        system_prompt: An optional system-level prompt.

    Returns:
        The result from the first successful driver.

    Raises:
        Exception: If every driver in ``DRIVER_ORDER`` fails. The error raised
            by the last driver is chained as the cause.
    """
    text = text.strip()
    last_error: Exception | None = None
    for driver in DRIVER_ORDER:
        logger.info(f"Trying driver: {driver.__name__}")
        try:
            return driver.prompt(text, images=images, system_prompt=system_prompt)
        except Exception as exc:
            # Deliberately broad: any failure of one driver (network error,
            # missing credentials, API error) must not prevent the remaining
            # drivers from being tried. No debugger breakpoint here: it would
            # hang non-interactive runs and, when ipdb is not installed, the
            # ImportError would abort the fallback loop.
            logger.exception(exc)
            logger.error(f"Driver {driver.__name__} failed with error: {exc}")
            last_error = exc
    # Chain the last driver error so the root cause stays visible to callers.
    raise Exception("All drivers failed to provide a response") from last_error


if __name__ == "__main__":
    # Manual smoke test; this could be extended to use command-line arguments
    # or other input methods.
    completion = prompt("Describe the solar system.")
    print(completion)
24 changes: 20 additions & 4 deletions openadapt/adapters/ultralytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,26 @@ def do_fastsam(
retina_masks: bool = True,
imgsz: int | tuple[int, int] | None = 1024,
# threshold below which boxes will be filtered out
conf: float = 0.4,
min_confidence_threshold: float = 0.4,
# discards all overlapping boxes with IoU > iou_threshold
iou: float = 0.9,
max_iou_threshold: float = 0.9,
) -> Image:
"""Get segmented image via FastSAM.
For usage of thresholds see:
github.com/ultralytics/ultralytics/blob/dacbd48fcf8407098166c6812eeb751deaac0faf
/ultralytics/utils/ops.py#L164
Args:
TODO
min_confidence_threshold (float, optional): The minimum confidence score
that a detection must meet or exceed to be considered valid. Detections
below this threshold will not be marked. Defaults to 0.4.
max_iou_threshold (float, optional): The maximum allowed Intersection over
Union (IoU) value for overlapping detections. Detections that exceed this
IoU threshold are considered for suppression, keeping only the
detection with the highest confidence. Defaults to 0.9.
"""
model = FastSAM(model_name)

imgsz = imgsz or image.size
Expand All @@ -91,8 +107,8 @@ def do_fastsam(
device=device,
retina_masks=retina_masks,
imgsz=imgsz,
conf=conf,
iou=iou,
conf=min_confidence_threshold,
iou=max_iou_threshold,
)

# Prepare a Prompt Process object
Expand Down
File renamed without changes.
File renamed without changes.
23 changes: 12 additions & 11 deletions openadapt/adapters/openai.py → openadapt/drivers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,12 @@ def get_response(
headers=headers,
json=payload,
)
return response
result = response.json()
if "error" in result:
error = result["error"]
message = error["message"]
raise Exception(message)
return result


def get_completion(payload: dict, dev_mode: bool = False) -> str:
Expand All @@ -136,23 +141,19 @@ def get_completion(payload: dict, dev_mode: bool = False) -> str:
Returns:
(str) first message from the response
"""
response = get_response(payload)
response.raise_for_status()
result = response.json()
logger.info(f"result=\n{pformat(result)}")
if "error" in result:
error = result["error"]
message = error["message"]
# TODO: fail after maximum number of attempts
if "retry your request" in message:
try:
result = get_response(payload)
except Exception as exc:
if "retry your request" in str(exc):
return get_completion(payload)
elif dev_mode:
import ipdb

ipdb.set_trace()
# TODO: handle more errors
else:
raise ValueError(result["error"]["message"])
raise exc
logger.info(f"result=\n{pformat(result)}")
choices = result["choices"]
choice = choices[0]
message = choice["message"]
Expand Down
16 changes: 14 additions & 2 deletions openadapt/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,11 +388,11 @@ def from_dict(
suffix_len = len(name_suffix)

key_names = utils.split_by_separators(
action_dict["text"][prefix_len:-suffix_len],
action_dict.get("text", "")[prefix_len:-suffix_len],
key_seps,
)
canonical_key_names = utils.split_by_separators(
action_dict["canonical_text"][prefix_len:-suffix_len],
action_dict.get("canonical_text", "")[prefix_len:-suffix_len],
key_seps,
)
logger.info(f"{key_names=}")
Expand Down Expand Up @@ -920,6 +920,18 @@ def asdict(self) -> dict:
}


class Replay(db.Base):
    """Class representing a replay in the database.

    Each row stores the name and arguments of a replay strategy together with
    a timestamp and a git hash.
    """

    __tablename__ = "replay"

    # Integer primary key.
    id = sa.Column(sa.Integer, primary_key=True)
    # presumably the time at which the replay was started — TODO confirm
    # against the code that creates Replay rows.
    timestamp = sa.Column(ForceFloat)
    # Name of the replay strategy.
    strategy_name = sa.Column(sa.String)
    # Strategy arguments, stored as JSON.
    strategy_args = sa.Column(sa.JSON)
    # presumably the git commit hash of the code that ran the replay —
    # TODO confirm.
    git_hash = sa.Column(sa.String)


def copy_sa_instance(sa_instance: db.Base, **kwargs: dict) -> db.Base:
"""Copy a SQLAlchemy instance.
Expand Down
5 changes: 4 additions & 1 deletion openadapt/playback.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ def play_mouse_event(event: ActionEvent, mouse_controller: mouse.Controller) ->
pressed = event.mouse_pressed
logger.debug(f"{name=} {x=} {y=} {dx=} {dy=} {button_name=} {pressed=}")

mouse_controller.position = (x, y)
if all([val is not None for val in (x, y)]):
mouse_controller.position = (x, y)
else:
logger.warning(f"{x=} {y=}")
if name == "move":
pass
elif name == "click":
Expand Down
56 changes: 55 additions & 1 deletion openadapt/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import matplotlib.pyplot as plt
import numpy as np

from openadapt import common, models, utils
from openadapt import common, contrib, models, utils
from openadapt.config import PERFORMANCE_PLOTS_DIR_PATH, config
from openadapt.models import ActionEvent

Expand Down Expand Up @@ -764,3 +764,57 @@ def plot_segments(
plt.imshow(image)
plt.axis("off")
plt.show()


def get_marked_image(
    original_image: Image.Image,
    masks: list[np.ndarray],
    include_masks: bool = True,
    include_marks: bool = True,
) -> Image.Image:
    """Get a Set-of-Mark image using the original SoM visualizer.

    Every mask is drawn onto the image and labelled with its 1-based position
    in ``masks``.

    Args:
        original_image (Image.Image): The original PIL image.
        masks (list[np.ndarray]): A list of masks representing segments in the
            original image.
        include_masks (bool, optional): If True, masks will be included in the
            output visualizations. Defaults to True.
        include_marks (bool, optional): If True, marks will be included in the
            output visualizations. Defaults to True.

    Returns:
        Image.Image: The marked image, where marks and/or masks are applied
            according to the include flags. If ``masks`` is empty, an
            unannotated image built from the original pixels is returned.
    """
    image_arr = np.asarray(original_image)

    if len(masks) == 0:
        # Nothing to draw. Without this guard the drawing loop never runs and
        # the result variable would be unbound.
        return Image.fromarray(image_arr)

    # The rest of this function is adapted from
    # github.com/microsoft/SoM/blob/main/task_adapter/sam/tasks/inference_sam_m2m_auto.py

    # metadata = MetadataCatalog.get('coco_2017_train_panoptic')
    metadata = None
    visual = contrib.som.visualizer.Visualizer(image_arr, metadata=metadata)
    label_mode = "1"
    alpha = 0.1
    anno_mode = []
    if include_masks:
        anno_mode.append("Mask")
    if include_marks:
        anno_mode.append("Mark")

    # NOTE: the upstream code also filled a uint8 `mask_map` with the labels.
    # It was never read here and cannot hold labels above 255, so it is
    # omitted.
    for label, mask in enumerate(masks, start=1):
        # The value returned for the last mask is the one converted into the
        # final image below.
        demo = visual.draw_binary_mask_with_number(
            mask,
            text=str(label),
            label_mode=label_mode,
            alpha=alpha,
            anno_mode=anno_mode,
        )

    im = demo.get_image()
    marked_image = Image.fromarray(im)

    return marked_image
16 changes: 16 additions & 0 deletions openadapt/prompts/describe_recording--segment.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
Consider the actions in the recording and states of the active window immediately
before each action was taken:

```json
{{ action_windows }}
```

Consider the attached screenshots taken immediately before each action. The order of
the screenshots matches the order of the actions above.

Provide a detailed natural language description of everything that happened
in this recording. This description will be embedded in the context for a future prompt
to replay the recording (subject to proposed modifications in natural language) on a
live system, so make sure to include everything you will need to know.

My career depends on this. Lives are at stake.
73 changes: 73 additions & 0 deletions openadapt/prompts/generate_action_event--segment.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
{% if include_raw_recording %}
Consider the previously recorded actions:

```json
{{ recorded_actions }}
```
{% endif %}


{% if include_raw_recording_description %}
Consider the following description of the previously recorded actions:

```json
{{ recording_description }}
```
{% endif %}


{% if include_replay_instructions %}
Consider the user's proposed modifications in natural language instructions:

```text
{{ replay_instructions }}
```
{% endif %}


{% if include_modified_recording %}
Consider this updated list of actions that have been modified such that replaying them
would have accomplished the user's instructions:

```json
{{ modified_actions }}
```
{% endif %}


{% if include_modified_recording_description %}
Consider the following description of the updated list of actions that have been
modified such that replaying them would have accomplished the user's instructions:

```json
{{ modified_recording_description }}
```
{% endif %}


Consider the actions you've produced (and we have played back) so far:

```json
{{ replayed_actions }}
```

{% if include_active_window %}
Consider the current active window:
```json
{{ current_window }}
```
{% endif %}


The attached image is a screenshot of the current state of the system.

Provide the next action to be replayed in order to accomplish the user's replay
instructions.

Do NOT provide available_segment_descriptions in your response.

Respond with JSON and nothing else.

If you wish to terminate the recording, return an empty object.

My career depends on this. Lives are at stake.
Loading

0 comments on commit 7ef115a

Please sign in to comment.