Skip to content

Commit

Permalink
#103 : Refactor InputEvent => ActionEvent
Browse files Browse the repository at this point in the history
  • Loading branch information
0dm committed May 10, 2023
1 parent 659b9d7 commit 7e537d4
Show file tree
Hide file tree
Showing 13 changed files with 186 additions and 162 deletions.
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ python -m puterbot.record "testing out puterbot"
Wait until all three event writers have started:
```
| INFO | __mp_main__:write_events:230 - event_type='screen' starting
| INFO | __mp_main__:write_events:230 - event_type='input' starting
| INFO | __mp_main__:write_events:230 - event_type='action' starting
| INFO | __mp_main__:write_events:230 - event_type='window' starting
```

Expand Down Expand Up @@ -92,12 +92,12 @@ More ReplayStrategies coming soon! (see [Contributing](#Contributing)).

Our goal is to automate the task described and demonstrated in a `Recording`.
That is, given a new `Screenshot`, we want to generate the appropriate
`InputEvent`(s) based on the previously recorded `InputEvent`s in order to
`ActionEvent`(s) based on the previously recorded `ActionEvent`s in order to
accomplish the task specified in the `Recording.task_description`, while
accounting for differences in screen resolution, window size, application
behavior, etc.

If it's not clear what `InputEvent` is appropriate for the given `Screenshot`,
If it's not clear what `ActionEvent` is appropriate for the given `Screenshot`,
(e.g. if the GUI application is behaving in a way we haven't seen before),
we can ask the user to take over temporarily to demonstrate the appropriate
course of action.
Expand All @@ -107,9 +107,9 @@ course of action.
The dataset consists of the following entities:
1. `Recording`: Contains information about the screen dimensions, platform, and
other metadata.
2. `InputEvent`: Represents a user input event such as a mouse click or key
press. Each `InputEvent` has an associated `Screenshot` taken immediately
before the event occurred. `InputEvent`s are aggregated to remove
2. `ActionEvent`: Represents a user action event such as a mouse click or key
press. Each `ActionEvent` has an associated `Screenshot` taken immediately
before the event occurred. `ActionEvent`s are aggregated to remove
unnecessary events (see [visualize](#visualize).)
3. `Screenshot`: Contains the PNG data of a screenshot taken during the
recording.
Expand All @@ -119,7 +119,7 @@ The dataset consists of the following entities:
You can assume that you have access to the following functions:
- `create_recording("doing taxes")`: Creates a recording.
- `get_latest_recording()`: Gets the latest recording.
- `get_events(recording)`: Returns a list of `InputEvent` objects for the given
- `get_events(recording)`: Returns a list of `ActionEvent` objects for the given
recording.

### Instructions
Expand All @@ -142,7 +142,7 @@ feedback and iterate on the approach.
Your submission will be evaluated based on the following criteria:

1. **Functionality** : Your implementation should correctly generate the new
`InputEvent` objects that can be replayed in order to accomplish the task in
`ActionEvent` objects that can be replayed in order to accomplish the task in
the original recording.

2. **Code Quality** : Your code should be well-structured, clean, and easy to
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""rename input_event to action_event
Revision ID: 20f9c2afb42c
Revises: 5139d7df38f6
Create Date: 2023-05-10 11:22:37.266810
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '20f9c2afb42c'
down_revision = '5139d7df38f6'
branch_labels = None
depends_on = None


def upgrade() -> None:
op.rename_table('input_event', 'action_event')


def downgrade() -> None:
op.rename_table('action_event', 'input_event')
12 changes: 6 additions & 6 deletions puterbot/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import sqlalchemy as sa

from puterbot.db import Session
from puterbot.models import InputEvent, Screenshot, Recording, WindowEvent
from puterbot.models import ActionEvent, Screenshot, Recording, WindowEvent


BATCH_SIZE = 1

db = Session()
input_events = []
action_events = []
screenshots = []
window_events = []

Expand Down Expand Up @@ -42,13 +42,13 @@ def _insert(event_data, table, buffer=None):
return result


def insert_input_event(recording_timestamp, event_timestamp, event_data):
def insert_action_event(recording_timestamp, event_timestamp, event_data):
event_data = {
**event_data,
"timestamp": event_timestamp,
"recording_timestamp": recording_timestamp,
}
_insert(event_data, InputEvent, input_events)
_insert(event_data, ActionEvent, action_events)


def insert_screenshot(recording_timestamp, event_timestamp, event_data):
Expand Down Expand Up @@ -97,8 +97,8 @@ def _get(table, recording_timestamp):
)


def get_input_events(recording):
return _get(InputEvent, recording.timestamp)
def get_action_events(recording):
return _get(ActionEvent, recording.timestamp)


def get_screenshots(recording, precompute_diffs=True):
Expand Down
86 changes: 43 additions & 43 deletions puterbot/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@

from puterbot.common import KEY_EVENTS, MOUSE_EVENTS
from puterbot.crud import (
get_input_events,
get_action_events,
get_window_events,
get_screenshots,
)
from puterbot.models import InputEvent
from puterbot.models import ActionEvent
from puterbot.utils import (
get_double_click_distance_pixels,
get_double_click_interval_seconds,
Expand All @@ -25,42 +25,42 @@

def get_events(recording, process=True, meta=None):
start_time = time.time()
input_events = get_input_events(recording)
action_events = get_action_events(recording)
window_events = get_window_events(recording)
screenshots = get_screenshots(recording)

raw_input_event_dicts = rows2dicts(input_events)
logger.debug(f"raw_input_event_dicts=\n{pformat(raw_input_event_dicts)}")
raw_action_event_dicts = rows2dicts(action_events)
logger.debug(f"raw_action_event_dicts=\n{pformat(raw_action_event_dicts)}")

num_input_events = len(input_events)
num_action_events = len(action_events)
num_window_events = len(window_events)
num_screenshots = len(screenshots)

num_input_events_raw = num_input_events
num_action_events_raw = num_action_events
num_window_events_raw = num_window_events
num_screenshots_raw = num_screenshots
duration_raw = input_events[-1].timestamp - input_events[0].timestamp
duration_raw = action_events[-1].timestamp - action_events[0].timestamp

num_process_iters = 0
if process:
while True:
logger.info(
f"{num_process_iters=} "
f"{num_input_events=} "
f"{num_action_events=} "
f"{num_window_events=} "
f"{num_screenshots=}"
)
input_events, window_events, screenshots = process_events(
input_events, window_events, screenshots,
action_events, window_events, screenshots = process_events(
action_events, window_events, screenshots,
)
if (
len(input_events) == num_input_events and
len(action_events) == num_action_events and
len(window_events) == num_window_events and
len(screenshots) == num_screenshots
):
break
num_process_iters += 1
num_input_events = len(input_events)
num_action_events = len(action_events)
num_window_events = len(window_events)
num_screenshots = len(screenshots)
if num_process_iters == MAX_PROCESS_ITERS:
Expand All @@ -71,8 +71,8 @@ def get_events(recording, process=True, meta=None):
lambda num, raw_num: f"{num} of {raw_num} ({(num / raw_num):.2%})"
)
meta["num_process_iters"] = num_process_iters
meta["num_input_events"] = format_num(
num_input_events, num_input_events_raw,
meta["num_action_events"] = format_num(
num_action_events, num_action_events_raw,
)
meta["num_window_events"] = format_num(
num_window_events, num_window_events_raw,
Expand All @@ -81,16 +81,16 @@ def get_events(recording, process=True, meta=None):
num_screenshots, num_screenshots_raw,
)

duration = input_events[-1].timestamp - input_events[0].timestamp
if len(input_events) > 1:
duration = action_events[-1].timestamp - action_events[0].timestamp
if len(action_events) > 1:
assert duration > 0, duration
meta["duration"] = format_num(duration, duration_raw)

end_time = time.time()
duration = end_time - start_time
logger.info(f"{duration=}")

return input_events # , window_events, screenshots
return action_events # , window_events, screenshots


def make_parent_event(child, extra=None):
Expand All @@ -108,7 +108,7 @@ def make_parent_event(child, extra=None):
extra = extra or {}
for key, val in extra.items():
event_dict[key] = val
return InputEvent(**event_dict)
return ActionEvent(**event_dict)


def merge_consecutive_mouse_move_events(events, by_diff_distance=True):
Expand Down Expand Up @@ -232,7 +232,7 @@ def get_merged_events(
return merged_events


return merge_consecutive_input_events(
return merge_consecutive_action_events(
"mouse_move", events, is_target_event, get_merged_events,
)

Expand All @@ -256,7 +256,7 @@ def get_merged_events(to_merge, state):
return [merged_event]


return merge_consecutive_input_events(
return merge_consecutive_action_events(
"mouse_scroll", events, is_target_event, get_merged_events,
)

Expand Down Expand Up @@ -381,7 +381,7 @@ def get_merged_events(to_merge, state):
return merged


return merge_consecutive_input_events(
return merge_consecutive_action_events(
"mouse_click", events, is_target_event, get_merged_events,
)

Expand Down Expand Up @@ -467,7 +467,7 @@ def get_merged_events(to_merge, state):
merged_events.append(merged_event)
return merged_events

return merge_consecutive_input_events(
return merge_consecutive_action_events(
"keyboard", events, is_target_event, get_merged_events,
)

Expand Down Expand Up @@ -530,15 +530,15 @@ def get_merged_events(to_merge, state):
return merged_events


return merge_consecutive_input_events(
return merge_consecutive_action_events(
"redundant_mouse_move", events, is_target_event, get_merged_events,
)


def merge_consecutive_input_events(
def merge_consecutive_action_events(
name, events, is_target_event, get_merged_events,
):
"""Merge consecutive input events into a single event"""
"""Merge consecutive action events into a single event"""

num_events_before = len(events)
state = {"dt": 0}
Expand Down Expand Up @@ -573,11 +573,11 @@ def include_merged_events(to_merge):


def discard_unused_events(
referred_events, input_events, referred_timestamp_key,
referred_events, action_events, referred_timestamp_key,
):
referred_event_timestamps = set([
getattr(input_event, referred_timestamp_key)
for input_event in input_events
getattr(action_event, referred_timestamp_key)
for action_event in action_events
])
num_referred_events_before = len(referred_events)
referred_events = [
Expand All @@ -593,13 +593,13 @@ def discard_unused_events(
return referred_events


def process_events(input_events, window_events, screenshots):
num_input_events = len(input_events)
def process_events(action_events, window_events, screenshots):
num_action_events = len(action_events)
num_window_events = len(window_events)
num_screenshots = len(screenshots)
num_total = num_input_events + num_window_events + num_screenshots
num_total = num_action_events + num_window_events + num_screenshots
logger.info(
f"before {num_input_events=} {num_window_events=} {num_screenshots=} "
f"before {num_action_events=} {num_window_events=} {num_screenshots=} "
f"{num_total=}"
)
process_fns = [
Expand All @@ -610,9 +610,9 @@ def process_events(input_events, window_events, screenshots):
merge_consecutive_mouse_click_events,
]
for process_fn in process_fns:
input_events = process_fn(input_events)
action_events = process_fn(action_events)
# TODO: keep events in which window_event_timestamp is updated
for prev_event, event in zip(input_events, input_events[1:]):
for prev_event, event in zip(action_events, action_events[1:]):
try:
assert prev_event.timestamp <= event.timestamp, (
process_fn, prev_event, event,
Expand All @@ -621,26 +621,26 @@ def process_events(input_events, window_events, screenshots):
logger.exception(exc)
import ipdb; ipdb.set_trace()
window_events = discard_unused_events(
window_events, input_events, "window_event_timestamp",
window_events, action_events, "window_event_timestamp",
)
screenshots = discard_unused_events(
screenshots, input_events, "screenshot_timestamp",
screenshots, action_events, "screenshot_timestamp",
)
num_input_events_ = len(input_events)
num_action_events_ = len(action_events)
num_window_events_ = len(window_events)
num_screenshots_ = len(screenshots)
num_total_ = num_input_events_ + num_window_events_ + num_screenshots_
pct_input_events = num_input_events_ / num_input_events
num_total_ = num_action_events_ + num_window_events_ + num_screenshots_
pct_action_events = num_action_events_ / num_action_events
pct_window_events = num_window_events_ / num_window_events
pct_screenshots = num_screenshots_ / num_screenshots
pct_total = num_total_ / num_total
logger.info(
f"after {num_input_events_=} {num_window_events_=} {num_screenshots_=} "
f"after {num_action_events_=} {num_window_events_=} {num_screenshots_=} "
f"{num_total=}"
)
logger.info(
f"{pct_input_events=} {pct_window_events=} {pct_screenshots=} "
f"{pct_action_events=} {pct_window_events=} {pct_screenshots=} "
f"{pct_total=}"

)
return input_events, window_events, screenshots
return action_events, window_events, screenshots
12 changes: 6 additions & 6 deletions puterbot/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ class Recording(Base):
platform = sa.Column(sa.String)
task_description = sa.Column(sa.String)

input_events = sa.orm.relationship("InputEvent", back_populates="recording")
action_events = sa.orm.relationship("ActionEvent", back_populates="recording")


class InputEvent(Base):
__tablename__ = "input_event"
class ActionEvent(Base):
__tablename__ = "action_event"

id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
Expand All @@ -46,10 +46,10 @@ class InputEvent(Base):
canonical_key_name = sa.Column(sa.String)
canonical_key_char = sa.Column(sa.String)
canonical_key_vk = sa.Column(sa.String)
parent_id = sa.Column(sa.Integer, sa.ForeignKey("input_event.id"))
parent_id = sa.Column(sa.Integer, sa.ForeignKey("action_event.id"))

children = sa.orm.relationship("InputEvent")
recording = sa.orm.relationship("Recording", back_populates="input_events")
children = sa.orm.relationship("ActionEvent")
recording = sa.orm.relationship("Recording", back_populates="action_events")
screenshot = sa.orm.relationship("Screenshot")
window_event = sa.orm.relationship("WindowEvent")

Expand Down
Loading

0 comments on commit 7e537d4

Please sign in to comment.