Skip to content

Commit

Permalink
Merge pull request #25 from MLDSAI/feat/ascii
Browse files Browse the repository at this point in the history
add ascii_mixin.py
  • Loading branch information
abrichr authored Apr 19, 2023
2 parents 1b84542 + 5c09b2d commit cb28427
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 47 deletions.
2 changes: 1 addition & 1 deletion puterbot/strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from puterbot.strategies.base import BaseReplayStrategy
from puterbot.strategies.naive import NaiveReplayStrategy
from puterbot.strategies.llm_ocr_demo import LLMOCRDemoReplayStrategy
from puterbot.strategies.demo import DemoReplayStrategy
# add more strategies here
45 changes: 45 additions & 0 deletions puterbot/strategies/ascii_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
Implements a ReplayStrategy mixin for converting images to ASCII.
Usage:
class MyReplayStrategy(ASCIIReplayStrategyMixin):
...
"""

from ascii_magic import AsciiArt
from loguru import logger

from puterbot.models import Recording, Screenshot
from puterbot.strategies.base import BaseReplayStrategy


COLUMNS = 120
WIDTH_RATIO = 2.2
MONOCHROME = True


class ASCIIReplayStrategyMixin(BaseReplayStrategy):

def __init__(
self,
recording: Recording,
):
super().__init__(recording)

def get_ascii_text(
self,
screenshot: Screenshot,
monochrome: bool = MONOCHROME,
columns: int = COLUMNS,
width_ratio: float = WIDTH_RATIO,

):
ascii_art = AsciiArt.from_pillow_image(screenshot.image)
ascii_text = ascii_art.to_ascii(
monochrome=monochrome,
columns=columns,
width_ratio=width_ratio,
)
logger.debug(f"ascii_text=\n{ascii_text}")
return ascii_text
57 changes: 57 additions & 0 deletions puterbot/strategies/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
Demonstration of LLM, OCR, and ASCII ReplayStrategyMixins.
Usage:
$ python puterbot/replay.py DemoReplayStrategy
"""

from loguru import logger
import numpy as np

from puterbot.events import get_events
from puterbot.models import Recording, Screenshot
from puterbot.strategies.base import BaseReplayStrategy
from puterbot.strategies.llm_mixin import LLMReplayStrategyMixin
from puterbot.strategies.ocr_mixin import OCRReplayStrategyMixin
from puterbot.strategies.ascii_mixin import ASCIIReplayStrategyMixin


class DemoReplayStrategy(
LLMReplayStrategyMixin,
OCRReplayStrategyMixin,
ASCIIReplayStrategyMixin,
BaseReplayStrategy,
):

def __init__(
self,
recording: Recording,
):
super().__init__(recording)

def get_next_input_event(
self,
screenshot: Screenshot,
):
ascii_text = self.get_ascii_text(screenshot)
logger.info(f"ascii_text=\n{ascii_text}")

ocr_text = self.get_ocr_text(screenshot)
logger.info(f"ocr_text=\n{ocr_text}")

max_tokens = 2

_min = 2
_max = 10
tokens = np.random.permutation(
["click"] * np.random.randint(_min, _max) +
["type"] * np.random.randint(_min, _max) +
["scroll"] * np.random.randint(_min, _max)
)
prompt = " ".join(tokens)
logger.info(f"{prompt=}")
completion = self.get_completion(prompt, max_tokens)
logger.info(f"{completion=}")

return None
9 changes: 3 additions & 6 deletions puterbot/strategies/llm_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
self.model = tf.AutoModelForCausalLM.from_pretrained(model_name)
self.model_max_length = model_max_length

def generate_completion(
def get_completion(
self,
prompt: str,
max_tokens: int,
Expand All @@ -45,24 +45,21 @@ def generate_completion(
f"Truncating from {len(prompt)=} to {model_max_length=}"
)
prompt = prompt[:model_max_length]
logger.info(f"{prompt=} {max_tokens=}")
logger.debug(f"{prompt=} {max_tokens=}")
input_tokens = self.tokenizer(prompt, return_tensors="pt")
pad_token_id = self.tokenizer.eos_token_id
attention_mask = input_tokens["attention_mask"]

output_tokens = self.model.generate(
input_ids=input_tokens["input_ids"],
attention_mask=attention_mask,
max_length=input_tokens["input_ids"].shape[-1] + max_tokens,
pad_token_id=pad_token_id,
num_return_sequences=1
)

N = input_tokens["input_ids"].shape[-1]
completion = self.tokenizer.decode(
output_tokens[:, N:][0],
clean_up_tokenization_spaces=True,
)
logger.info(f"{completion=}")

logger.debug(f"{completion=}")
return completion
38 changes: 0 additions & 38 deletions puterbot/strategies/llm_ocr_demo.py

This file was deleted.

5 changes: 3 additions & 2 deletions puterbot/strategies/ocr_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(

self.ocr = RapidOCR()

def get_text(
def get_ocr_text(
self,
screenshot: Screenshot
):
Expand All @@ -48,6 +48,7 @@ def get_text(
logger.debug(f"{elapse=}")
df_text = get_text_df(result)
text = get_text_from_df(df_text)
logger.debug(f"{text=}")
return text


Expand Down Expand Up @@ -79,7 +80,7 @@ def get_text_df(

confidences = [confidence for coords, text, confidence in result]
df["confidence"] = confidences
logger.info(f"df=\n{df}")
logger.debug(f"df=\n{df}")
return df


Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
alembic==1.8.1
ascii_magic==2.3.0
bokeh==2.4.3
clipboard==0.0.4
deepdiff==6.2.2
Expand Down

0 comments on commit cb28427

Please sign in to comment.