Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
abrichr committed May 17, 2023
1 parent f92f362 commit 163308b
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 9 deletions.
38 changes: 38 additions & 0 deletions puterbot/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from functools import wraps
import time

from joblib import Memory
from loguru import logger


from puterbot import config


def default(val, default):
return val if val is not None else default


def cache(dir_path=None, enabled=None, verbosity=None, **cache_kwargs):
"""TODO"""

cache_dir_path = default(dir_path, config.CACHE_DIR_PATH)
cache_enabled = default(enabled, config.CACHE_ENABLED)
cache_verbosity = default(verbosity, config.CACHE_VERBOSITY)
def decorator(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
logger.debug(f"{cache_enabled=}")
if cache_enabled:
memory = Memory(cache_dir_path, verbose=cache_verbosity)
nonlocal fn
fn = memory.cache(fn, **cache_kwargs)
cache_hit = fn.check_call_in_cache(*args, **kwargs)
logger.debug(f"{fn=} {cache_hit=}")
start_time = time.time()
logger.debug(f"{fn=} {start_time=}")
rval = fn(*args, **kwargs)
duration = time.time() - start_time
logger.debug(f"{fn=} {duration=}")
return rval
return wrapper
return decorator
3 changes: 3 additions & 0 deletions puterbot/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@


_DEFAULTS = {
"CACHE_DIR_PATH": ".cache",
"CACHE_ENABLED": True,
"CACHE_VERBOSITY": 1,
"DB_ECHO": False,
"DB_FNAME": "openadapt.db",
"OPENAI_API_KEY": None,
Expand Down
26 changes: 25 additions & 1 deletion puterbot/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,26 @@ class Recording(db.Base):
back_populates="recording",
order_by="ActionEvent.timestamp",
)
screenshots = sa.orm.relationship(
"Screenshot",
back_populates="recording",
order_by="Screenshot.timestamp",
)
window_events = sa.orm.relationship(
"WindowEvent",
back_populates="recording",
order_by="WindowEvent.timestamp",
)

_processed_action_events = None

@property
def processed_action_events(self):
from puterbot import events
if not self._processed_action_events:
self._processed_action_events = events.get_events(self)
return self._processed_action_events



class ActionEvent(db.Base):
Expand Down Expand Up @@ -164,11 +184,13 @@ class Screenshot(db.Base):
recording_timestamp = sa.Column(sa.ForeignKey("recording.timestamp"))
timestamp = sa.Column(sa.DateTime)
png_data = sa.Column(sa.LargeBinary)
# TODO: replace prev with prev_timestamp?

recording = sa.orm.relationship("Recording", back_populates="screenshots")

# TODO: convert to png_data on save
sct_img = None

# TODO: replace prev with prev_timestamp?
prev = None
_image = None
_diff = None
Expand Down Expand Up @@ -228,6 +250,8 @@ class WindowEvent(db.Base):
width = sa.Column(sa.Integer)
height = sa.Column(sa.Integer)

recording = sa.orm.relationship("Recording", back_populates="window_events")

@classmethod
def get_active_window_state(cls):
return window.get_active_window_state()
1 change: 1 addition & 0 deletions puterbot/strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from puterbot.strategies.base import BaseReplayStrategy
from puterbot.strategies.naive import NaiveReplayStrategy
from puterbot.strategies.demo import DemoReplayStrategy
from puterbot.strategies.stateful import StatefulReplayStrategy
# add more strategies here
1 change: 1 addition & 0 deletions puterbot/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(
):
self.recording = recording
self.max_frame_times = max_frame_times
self.processed_action_events = recording.processed_action_events
self.action_events = []
self.screenshots = []
self.window_states = []
Expand Down
5 changes: 3 additions & 2 deletions puterbot/strategies/mixins/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ class MyReplayStrategy(OpenAIReplayStrategyMixin):
...
"""

import openai
import tiktoken

from loguru import logger
from puterbot.strategies.base import BaseReplayStrategy
import openai import tiktoken

from puterbot.strategies.base import BaseReplayStrategy
from puterbot import cache, config, models


Expand Down
4 changes: 0 additions & 4 deletions puterbot/strategies/naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@ def __init__(
self.sleep = sleep
self.prev_timestamp = None
self.action_event_idx = -1
self.processed_action_events = events.get_events(
recording,
process=PROCESS_EVENTS,
)
event_dicts = utils.rows2dicts(self.processed_action_events)
logger.info(f"event_dicts=\n{pformat(event_dicts)}")

Expand Down
27 changes: 25 additions & 2 deletions puterbot/strategies/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
"""

from loguru import logger
import deepdiff
import numpy as np

from puterbot import events, models, strategies
)
from puterbot.strategies.mixins.openai import OpenAIReplayStrategyMixin


class StatefulReplayStrategy(
strategies.mixins.openai.OpenAIReplayStrategyMixin,
OpenAIReplayStrategyMixin,
strategies.base.BaseReplayStrategy,
):

Expand All @@ -28,6 +29,7 @@ def __init__(
def get_next_action_event(
self,
screenshot: models.Screenshot,
window_event: models.WindowEvent,
):
event_strs = [
f"<{event}>"
Expand All @@ -37,6 +39,11 @@ def get_next_action_event(
f"<{completion}>"
for completion in self.result_history
]

state_diffs = get_state_diffs(self.processed_action_events)
# TODO XXX

"""
prompt = " ".join(event_strs + history_strs)
N = max(0, len(prompt) - MAX_INPUT_SIZE)
prompt = prompt[N:]
Expand All @@ -53,3 +60,19 @@ def get_next_action_event(
# TODO: parse result into ActionEvent(s)
return None
"""


def get_state_diffs(action_events):
window_events = [
action_event.window_event
for action_event in action_events
]
diffs = [
deepdiff.DeepDiff(prev_window_event.state, window_event.state)
for prev_window_event, window_event in zip(
window_events, window_events[1:]
)
]
return diffs
import ipdb; ipdb.set_trace()

0 comments on commit 163308b

Please sign in to comment.