add llm_mixin.py; ocr_mixin.py; llm_ocr_demo.py
abrichr committed Apr 19, 2023
1 parent 601208c commit 12f0ff0
Showing 7 changed files with 268 additions and 9 deletions.
9 changes: 7 additions & 2 deletions puterbot/replay.py
@@ -30,9 +30,14 @@ def replay(
 
     strategy_class_by_name = get_strategy_class_by_name()
     if strategy_name not in strategy_class_by_name:
-        available_strategy_names = ", ".join(strategy_class_by_name.keys())
+        strategy_names = [
+            name
+            for name in strategy_class_by_name.keys()
+            if not name.lower().endswith("mixin")
+        ]
+        available_strategies = ", ".join(strategy_names)
         raise ValueError(
-            f"Invalid {strategy_name=}; {available_strategy_names=}"
+            f"Invalid {strategy_name=}; {available_strategies=}"
         )
 
     strategy_class = strategy_class_by_name[strategy_name]
1 change: 1 addition & 0 deletions puterbot/strategies/__init__.py
@@ -1,3 +1,4 @@
 from puterbot.strategies.base import BaseReplayStrategy
 from puterbot.strategies.naive import NaiveReplayStrategy
+from puterbot.strategies.llm_ocr_demo import LLMOCRDemoReplayStrategy
 # add more strategies here
56 changes: 56 additions & 0 deletions puterbot/strategies/llm_mixin.py
@@ -0,0 +1,56 @@
"""
Implements a ReplayStrategy mixin for generating LLM completions.
"""


from loguru import logger
import transformers as tf # RIP TensorFlow

from puterbot.models import Recording
from puterbot.strategies.base import BaseReplayStrategy


MODEL_NAME = "gpt2-xl"
MODEL_MAX_LENGTH = 1024


class LLMReplayStrategyMixin(BaseReplayStrategy):

def __init__(
self,
recording: Recording,
model_name=MODEL_NAME,
):
super().__init__(recording)

logger.info(f"{model_name=}")
self.tokenizer = tf.AutoTokenizer.from_pretrained(model_name)
self.model = tf.AutoModelForCausalLM.from_pretrained(model_name)

def generate_completion(self, prompt, max_tokens):
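        # NOTE: MODEL_MAX_LENGTH is the model's context limit in tokens, but
        # the check below truncates by characters, so it over-truncates long
        # prompts; a safe approximation rather than an exact token count.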
if MODEL_MAX_LENGTH and len(prompt) > MODEL_MAX_LENGTH:
logger.warning(
f"Truncating from {len(prompt)=} to {MODEL_MAX_LENGTH=}"
)
prompt = prompt[:MODEL_MAX_LENGTH]
logger.info(f"{prompt=} {max_tokens=}")
input_tokens = self.tokenizer(prompt, return_tensors="pt")
pad_token_id = self.tokenizer.eos_token_id
attention_mask = input_tokens["attention_mask"]

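        # generate()'s max_length counts the prompt tokens plus the newly
        # generated ones, so extend the input length by max_tokens.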
output_tokens = self.model.generate(
input_ids=input_tokens["input_ids"],
attention_mask=attention_mask,
max_length=input_tokens["input_ids"].shape[-1] + max_tokens,
pad_token_id=pad_token_id,
num_return_sequences=1
)

N = input_tokens["input_ids"].shape[-1]
completion = self.tokenizer.decode(
output_tokens[:, N:][0],
clean_up_tokenization_spaces=True,
)
logger.info(f"{completion=}")

return completion
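
A minimal usage sketch for the mixin above (hypothetical strategy name; assumes a Recording instance is available; not part of this commit):

# Hypothetical strategy that just asks the model to continue a fixed prompt.
from puterbot.models import Recording
from puterbot.strategies.base import BaseReplayStrategy
from puterbot.strategies.llm_mixin import LLMReplayStrategyMixin


class CompletionDemoStrategy(LLMReplayStrategyMixin, BaseReplayStrategy):

    def get_next_input_event(self, screenshot):
        # Generate up to 10 new tokens continuing a fixed prompt.
        print(self.generate_completion("The quick brown fox", 10))
        return None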
38 changes: 38 additions & 0 deletions puterbot/strategies/llm_ocr_demo.py
@@ -0,0 +1,38 @@
"""
Implements a ReplayStrategy mixin for generating LLM completions.
"""


import mss.base

from puterbot.models import Recording
from puterbot.strategies.base import BaseReplayStrategy
from puterbot.strategies.llm_mixin import LLMReplayStrategyMixin
from puterbot.strategies.ocr_mixin import OCRReplayStrategyMixin


class LLMOCRDemoReplayStrategy(
LLMReplayStrategyMixin,
OCRReplayStrategyMixin,
BaseReplayStrategy,
):

def __init__(
self,
recording: Recording,
):
super().__init__(recording)

def get_next_input_event(
self,
screenshot: mss.base.ScreenShot,
):
text = self.get_text(screenshot)

# this doesn't make sense and is for demonstrative purposes only
completion = self.generate_completion(text, 100)

return None
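
Because each class in the chain calls super().__init__(recording), Python's method resolution order runs every mixin constructor exactly once, so the single super().__init__ above initializes the tokenizer, the model, and the OCR engine. A standalone sketch of that cooperative-initialization pattern (generic names; not from this repository):

class Base:
    def __init__(self, recording):
        self.recording = recording


class MixinA(Base):
    def __init__(self, recording):
        super().__init__(recording)
        print("MixinA initialized")


class MixinB(Base):
    def __init__(self, recording):
        super().__init__(recording)
        print("MixinB initialized")


class Demo(MixinA, MixinB, Base):
    pass


Demo("recording")  # prints "MixinB initialized", then "MixinA initialized"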
9 changes: 2 additions & 7 deletions puterbot/strategies/naive.py
@@ -9,13 +9,8 @@
 from loguru import logger
 import mss.base
 
-from puterbot.events import (
-    get_events,
-)
-from puterbot.utils import (
-    display_event,
-    rows2dicts,
-)
+from puterbot.events import get_events
+from puterbot.utils import display_event, rows2dicts
 from puterbot.models import Recording
 from puterbot.strategies.base import BaseReplayStrategy
 
161 changes: 161 additions & 0 deletions puterbot/strategies/ocr_mixin.py
@@ -0,0 +1,161 @@
"""
Implements a ReplayStrategy mixin for getting text from images via OCR.
"""


from loguru import logger
from PIL import Image
from rapidocr_onnxruntime import RapidOCR
from sklearn.cluster import DBSCAN
import mss.base
import numpy as np
import pandas as pd

from puterbot.models import Recording
from puterbot.strategies.base import BaseReplayStrategy


# TODO: group text into sections via layout analysis; see:
# https://github.com/RapidAI/RapidOCR/blob/main/python/rapid_structure/docs/README_Layout.md


class OCRReplayStrategyMixin(BaseReplayStrategy):
def __init__(
self,
recording: Recording,
):
super().__init__(recording)

# https://github.com/RapidAI/RapidOCR/blob/main/python/README.md
self.ocr = RapidOCR()

def get_text(self, screenshot: mss.base.ScreenShot):
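        # mss exposes raw BGRA bytes; raw mode "BGRX" has PIL read 4-byte
        # pixels, reorder them to RGB, and ignore the alpha byte.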
image = Image.frombytes(
"RGB", screenshot.size, screenshot.bgra, "raw", "BGRX"
)
arr = np.array(image)
result, elapse = self.ocr(arr)
#det_elapse, cls_elapse, rec_elapse = elapse
#all_elapse = det_elapse + cls_elapse + rec_elapse
logger.debug(f"{result=}")
logger.debug(f"{elapse=}")
df = get_df(result)
text = convert_dataframe_to_string(df)
return text


def unnest(df, explode, axis, suffixes=None):
# https://stackoverflow.com/a/53218939
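    # axis=1: explode list-valued columns into extra rows; otherwise, split
    # fixed-length list columns into suffixed scalar columns (e.g. "tl" ->
    # "tl_x", "tl_y"), which is how get_df uses it.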
if axis == 1:
df1 = pd.concat([df[x].explode() for x in explode], axis=1)
return df1.join(df.drop(explode, axis=1), how="left")
else:
df1 = pd.concat(
[
pd.DataFrame(
df[x].tolist(),
index=df.index,
columns=suffixes,
).add_prefix(x)
for x in explode
],
axis=1,
)
return df1.join(
df.drop(explode, axis=1),
how="left",
)


def get_df(result):
"""
Convert RapidOCR result to DataFrame.
Args:
result: list of [coordinates, text, confidence], where
coordinates is itself a list of:
[tl_x, tl_y],
[tr_x, tr_y],
[br_x, br_y],
[bl_x, bl_y]
Returns:
pd.DataFrame
"""

coords = [coords for coords, text, confidence in result]
columns = ["tl", "tr", "bl", "br"]
df = pd.DataFrame(coords, columns=columns)
df = unnest(df, df.columns, 0, suffixes=["_x", "_y"])

texts = [text for coords, text, confidence in result]
df["text"] = texts

confidences = [confidence for coords, text, confidence in result]
df["confidence"] = confidences
logger.info(f"df=\n{df}")
return df


def preprocess_text(text):
return text.strip()


def calculate_centroid(row):
x = (row["tl_x"] + row["tr_x"] + row["bl_x"] + row["br_x"]) / 4
y = (row["tl_y"] + row["tr_y"] + row["bl_y"] + row["br_y"]) / 4
return x, y


def calculate_height(row):
return abs(row["tl_y"] - row["bl_y"])


def sort_rows(df):
df["centroid"] = df.apply(calculate_centroid, axis=1)
df["x"] = df["centroid"].apply(lambda coord: coord[0])
df["y"] = df["centroid"].apply(lambda coord: coord[1])
df.sort_values(by=["y", "x"], inplace=True)
return df


def cluster_lines(df, eps):
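    # DBSCAN with min_samples=1 chains any centroids within eps of each
    # other (in x and y jointly) into one cluster, i.e. one text line.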
coords = df[["x", "y"]].to_numpy()
cluster_model = DBSCAN(eps=eps, min_samples=1)
df["line_cluster"] = cluster_model.fit_predict(coords)
return df


def cluster_words(df):
line_dfs = []
for line_cluster in df["line_cluster"].unique():
line_df = df[df["line_cluster"] == line_cluster].copy()

if len(line_df) > 1:
coords = line_df[["x", "y"]].to_numpy()
eps = line_df["height"].min()
cluster_model = DBSCAN(eps=eps, min_samples=1)
line_df["word_cluster"] = cluster_model.fit_predict(coords)
else:
line_df["word_cluster"] = 0

line_dfs.append(line_df)
return pd.concat(line_dfs)


def concat_text(df):
df.sort_values(by=["line_cluster", "word_cluster"], inplace=True)
lines = df.groupby("line_cluster")["text"].apply(lambda x: " ".join(x))
return "\n".join(lines)


def convert_dataframe_to_string(df):
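    # Pipeline: normalize text, order boxes top-to-bottom, estimate line
    # height, cluster boxes into lines (DBSCAN, eps = min height), then into
    # words within each line, and join line texts with newlines.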
df["text"] = df["text"].apply(preprocess_text)
sorted_df = sort_rows(df)
df["height"] = df.apply(calculate_height, axis=1)
eps = df["height"].min()
line_clustered_df = cluster_lines(sorted_df, eps)
word_clustered_df = cluster_words(line_clustered_df)
result = concat_text(word_clustered_df)
return result
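
To make the clustering concrete, a toy run of the helpers above on fabricated, RapidOCR-shaped output (hand-made boxes and text; not part of this commit):

from puterbot.strategies.ocr_mixin import convert_dataframe_to_string, get_df

# Each entry is [tl, tr, br, bl] corners, text, confidence.
result = [
    [[[0, 0], [15, 0], [15, 30], [0, 30]], "Hi", 0.99],
    [[[20, 0], [50, 0], [50, 30], [20, 30]], "there", 0.98],
    [[[0, 100], [30, 100], [30, 130], [0, 130]], "Bye", 0.97],
]

df = get_df(result)
print(convert_dataframe_to_string(df))
# The centroids of "Hi" (7.5, 15) and "there" (35, 15) lie within one line
# height (30px) of each other, so they cluster into one line; "Bye" sits
# ~100px lower and lands on its own line:
#   Hi there
#   Bye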
3 changes: 3 additions & 0 deletions requirements.txt
@@ -14,6 +14,9 @@ pyinstaller
 pywin32==306
 git+https://github.com/abrichr/pynput.git
 pytest==7.1.3
+rapidocr-onnxruntime==1.2.3
+scikit-learn==1.2.2
+scipy==1.9.3
 sqlalchemy==1.4.43
 torch==2.0.0
 tqdm==4.64.0
