Skip to content

Commit

Permalink
Merge pull request #21 from MLDSAI/feat/llm_ocr_demo
Browse files Browse the repository at this point in the history
add llm_ocr_demo.py and related mixins
  • Loading branch information
abrichr authored Apr 19, 2023
2 parents 601208c + 61764bf commit 09e36f8
Show file tree
Hide file tree
Showing 9 changed files with 334 additions and 24 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ may base your implementation off of `naive.py`.

See https://github.com/MLDSAI/puterbot/issues for ideas on where to start.

See `strategies/llm_ocr_demo.py` for example usage of a Large Language Model
and Optical Character Recognition.

### Evaluation Criteria

Your submission will be evaluated based on the following criteria:
Expand Down
23 changes: 9 additions & 14 deletions puterbot/replay.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,8 @@
from pprint import pformat
import importlib
import time

from loguru import logger
from pynput import keyboard, mouse
import fire

from puterbot.crud import (
get_latest_recording,
)
from puterbot.utils import (
configure_logging,
get_strategy_class_by_name,
)
from puterbot.crud import get_latest_recording
from puterbot.utils import configure_logging, get_strategy_class_by_name


LOG_LEVEL = "INFO"
Expand All @@ -30,9 +20,14 @@ def replay(

strategy_class_by_name = get_strategy_class_by_name()
if strategy_name not in strategy_class_by_name:
available_strategy_names = ", ".join(strategy_class_by_name.keys())
strategy_names = [
name
for name in strategy_class_by_name.keys()
if not name.lower().endswith("mixin")
]
available_strategies = ", ".join(strategy_names)
raise ValueError(
f"Invalid {strategy_name=}; {available_strategy_names=}"
f"Invalid {strategy_name=}; {available_strategies=}"
)

strategy_class = strategy_class_by_name[strategy_name]
Expand Down
1 change: 1 addition & 0 deletions puterbot/strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from puterbot.strategies.base import BaseReplayStrategy
from puterbot.strategies.naive import NaiveReplayStrategy
from puterbot.strategies.llm_ocr_demo import LLMOCRDemoReplayStrategy
# add more strategies here
25 changes: 22 additions & 3 deletions puterbot/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,33 @@
"""

from abc import ABC, abstractmethod
import time

from loguru import logger
from pynput import keyboard, mouse
import mss.base
import numpy as np

from puterbot.models import Recording, InputEvent
from puterbot.playback import play_input_event
from puterbot.utils import get_screenshot


MAX_FRAME_TIMES = 1000


class BaseReplayStrategy(ABC):

def __init__(
self,
recording: Recording,
max_frame_times: int = MAX_FRAME_TIMES,
):
self.recording = recording
self.max_frame_times = max_frame_times
self.input_events = []
self.screenshots = []
self.frame_times = []

@abstractmethod
def get_next_input_event(
Expand All @@ -29,9 +38,7 @@ def get_next_input_event(
) -> InputEvent:
pass

def run(
self,
):
def run(self):
keyboard_controller = keyboard.Controller()
mouse_controller = mouse.Controller()
while True:
Expand All @@ -41,10 +48,22 @@ def run(
input_event = self.get_next_input_event(screenshot)
except StopIteration:
break
self.log_fps()
self.input_events.append(input_event)
if input_event:
play_input_event(
input_event,
mouse_controller,
keyboard_controller,
)

def log_fps(self):
t = time.time()
self.frame_times.append(t)
dts = np.diff(self.frame_times)
if len(dts) > 1:
mean_dt = np.mean(dts)
fps = len(dts) / mean_dt
logger.info(f"{fps=:.2f}")
if len(self.frame_times) > self.max_frame_times:
self.frame_times.pop(0)
68 changes: 68 additions & 0 deletions puterbot/strategies/llm_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""
Implements a ReplayStrategy mixin for generating LLM completions.
Usage:
class MyReplayStrategy(LLMReplayStrategyMixin):
...
"""


from loguru import logger
import transformers as tf # RIP TensorFlow

from puterbot.models import Recording
from puterbot.strategies.base import BaseReplayStrategy


MODEL_NAME = "gpt2" # gpt2-xl is bigger and slower
MODEL_MAX_LENGTH = 1024


class LLMReplayStrategyMixin(BaseReplayStrategy):

def __init__(
self,
recording: Recording,
model_name: str = MODEL_NAME,
model_max_length: str = MODEL_MAX_LENGTH,
):
super().__init__(recording)

logger.info(f"{model_name=}")
self.tokenizer = tf.AutoTokenizer.from_pretrained(model_name)
self.model = tf.AutoModelForCausalLM.from_pretrained(model_name)
self.model_max_length = model_max_length

def generate_completion(
self,
prompt: str,
max_tokens: int,
):
model_max_length = self.model_max_length
if model_max_length and len(prompt) > model_max_length:
logger.warning(
f"Truncating from {len(prompt)=} to {model_max_length=}"
)
prompt = prompt[:model_max_length]
logger.info(f"{prompt=} {max_tokens=}")
input_tokens = self.tokenizer(prompt, return_tensors="pt")
pad_token_id = self.tokenizer.eos_token_id
attention_mask = input_tokens["attention_mask"]

output_tokens = self.model.generate(
input_ids=input_tokens["input_ids"],
attention_mask=attention_mask,
max_length=input_tokens["input_ids"].shape[-1] + max_tokens,
pad_token_id=pad_token_id,
num_return_sequences=1
)

N = input_tokens["input_ids"].shape[-1]
completion = self.tokenizer.decode(
output_tokens[:, N:][0],
clean_up_tokenization_spaces=True,
)
logger.info(f"{completion=}")

return completion
40 changes: 40 additions & 0 deletions puterbot/strategies/llm_ocr_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""
Demonstration of LLMReplayStrategyMixin and OCRReplayStrategyMixin.
Usage:
$ python puterbot/replay.py LLMOCRDemoReplayStrategy
"""


import mss.base

from puterbot.events import get_events
from puterbot.models import Recording
from puterbot.strategies.base import BaseReplayStrategy
from puterbot.strategies.llm_mixin import LLMReplayStrategyMixin
from puterbot.strategies.ocr_mixin import OCRReplayStrategyMixin


class LLMOCRDemoReplayStrategy(
LLMReplayStrategyMixin,
OCRReplayStrategyMixin,
BaseReplayStrategy,
):

def __init__(
self,
recording: Recording,
):
super().__init__(recording)

def get_next_input_event(
self,
screenshot: mss.base.ScreenShot,
):
text = self.get_text(screenshot)

# N.B. this doesn't make sense and is for demonstration purposes only
completion = self.generate_completion(text, 100)

return None
9 changes: 2 additions & 7 deletions puterbot/strategies/naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,8 @@
from loguru import logger
import mss.base

from puterbot.events import (
get_events,
)
from puterbot.utils import (
display_event,
rows2dicts,
)
from puterbot.events import get_events
from puterbot.utils import display_event, rows2dicts
from puterbot.models import Recording
from puterbot.strategies.base import BaseReplayStrategy

Expand Down
Loading

0 comments on commit 09e36f8

Please sign in to comment.