add llm_mixin.py; ocr_mixin.py; llm_ocr_demo.py
Showing 7 changed files with 268 additions and 9 deletions.
puterbot/strategies/__init__.py
@@ -1,3 +1,4 @@
from puterbot.strategies.base import BaseReplayStrategy
from puterbot.strategies.naive import NaiveReplayStrategy
from puterbot.strategies.llm_ocr_demo import LLMOCRDemoReplayStrategy
# add more strategies here
puterbot/strategies/llm_mixin.py
@@ -0,0 +1,56 @@
""" | ||
Implements a ReplayStrategy mixin for generating LLM completions. | ||
""" | ||
|
||
|
||
from loguru import logger | ||
import transformers as tf # RIP TensorFlow | ||
|
||
from puterbot.models import Recording | ||
from puterbot.strategies.base import BaseReplayStrategy | ||
|
||
|
||
MODEL_NAME = "gpt2-xl" # gpt2-xl | ||
MODEL_MAX_LENGTH = 1024 | ||
|
||
|
||
class LLMReplayStrategyMixin(BaseReplayStrategy): | ||
|
||
def __init__( | ||
self, | ||
recording: Recording, | ||
model_name=MODEL_NAME, | ||
): | ||
super().__init__(recording) | ||
|
||
logger.info(f"{model_name=}") | ||
self.tokenizer = tf.AutoTokenizer.from_pretrained(model_name) | ||
self.model = tf.AutoModelForCausalLM.from_pretrained(model_name) | ||
|
||
def generate_completion(self, prompt, max_tokens): | ||
if MODEL_MAX_LENGTH and len(prompt) > MODEL_MAX_LENGTH: | ||
logger.warning( | ||
f"Truncating from {len(prompt)=} to {MODEL_MAX_LENGTH=}" | ||
) | ||
prompt = prompt[:MODEL_MAX_LENGTH] | ||
logger.info(f"{prompt=} {max_tokens=}") | ||
input_tokens = self.tokenizer(prompt, return_tensors="pt") | ||
pad_token_id = self.tokenizer.eos_token_id | ||
attention_mask = input_tokens["attention_mask"] | ||
|
||
output_tokens = self.model.generate( | ||
input_ids=input_tokens["input_ids"], | ||
attention_mask=attention_mask, | ||
max_length=input_tokens["input_ids"].shape[-1] + max_tokens, | ||
pad_token_id=pad_token_id, | ||
num_return_sequences=1 | ||
) | ||
|
||
N = input_tokens["input_ids"].shape[-1] | ||
completion = self.tokenizer.decode( | ||
output_tokens[:, N:][0], | ||
clean_up_tokenization_spaces=True, | ||
) | ||
logger.info(f"{completion=}") | ||
|
||
return completion |
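For orientation, here is a minimal usage sketch of the mixin. It is not part of this commit: the class, function, and prompt below are invented for illustration, and the Recording instance is assumed to come from elsewhere (e.g. loaded from the puterbot database).

# Hypothetical usage sketch, not part of this commit.
from puterbot.models import Recording
from puterbot.strategies.llm_mixin import LLMReplayStrategyMixin


class CompletionDemoStrategy(LLMReplayStrategyMixin):
    """Toy strategy that only exercises generate_completion()."""

    def get_next_input_event(self, screenshot):
        return None  # stub; see llm_ocr_demo.py for a fuller example


def demo(recording: Recording):
    strategy = CompletionDemoStrategy(recording)
    # Generates at most 10 new tokens after the prompt.
    return strategy.generate_completion("Click the OK button, then", 10)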
puterbot/strategies/llm_ocr_demo.py
@@ -0,0 +1,38 @@
""" | ||
Implements a ReplayStrategy mixin for generating LLM completions. | ||
""" | ||
|
||
|
||
from loguru import logger | ||
|
||
import mss.base | ||
|
||
from puterbot.events import get_events | ||
from puterbot.models import Recording | ||
from puterbot.strategies.base import BaseReplayStrategy | ||
from puterbot.strategies.llm_mixin import LLMReplayStrategyMixin | ||
from puterbot.strategies.ocr_mixin import OCRReplayStrategyMixin | ||
|
||
|
||
class LLMOCRDemoReplayStrategy( | ||
LLMReplayStrategyMixin, | ||
OCRReplayStrategyMixin, | ||
BaseReplayStrategy, | ||
): | ||
|
||
def __init__( | ||
self, | ||
recording: Recording, | ||
): | ||
super().__init__(recording) | ||
|
||
def get_next_input_event( | ||
self, | ||
screenshot: mss.base.ScreenShot, | ||
): | ||
text = self.get_text(screenshot) | ||
|
||
# this doesn't make sense and is for demonstrative purposes only | ||
completion = self.generate_completion(text, 100) | ||
|
||
return None |
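To exercise the demo end to end, a driver loop along the following lines would capture the screen with mss and hand it to the strategy. This is a hypothetical sketch, not part of this commit; puterbot's actual replay entry point is not shown in this diff, and the Recording instance is again a placeholder.

# Hypothetical driver, for illustration only.
import mss

from puterbot.models import Recording
from puterbot.strategies.llm_ocr_demo import LLMOCRDemoReplayStrategy


def run_once(recording: Recording):
    strategy = LLMOCRDemoReplayStrategy(recording)
    with mss.mss() as sct:
        screenshot = sct.grab(sct.monitors[0])  # mss.base.ScreenShot
        # Always None in this demo; the completion is only computed/logged.
        return strategy.get_next_input_event(screenshot)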
puterbot/strategies/ocr_mixin.py
@@ -0,0 +1,161 @@
""" | ||
Implements a ReplayStrategy mixin for getting text from images via OCR | ||
""" | ||
|
||
import itertools | ||
|
||
from loguru import logger | ||
from PIL import Image | ||
from rapidocr_onnxruntime import RapidOCR | ||
from sklearn.cluster import DBSCAN | ||
import mss.base | ||
import numpy as np | ||
import pandas as pd | ||
|
||
from puterbot.models import Recording | ||
from puterbot.strategies.base import BaseReplayStrategy | ||
|
||
|
||
# TODO: use group into sections via layout analysis; see: | ||
# https://github.com/RapidAI/RapidOCR/blob/main/python/rapid_structure/docs/README_Layout.md | ||
|
||
|
||
class OCRReplayStrategyMixin(BaseReplayStrategy): | ||
def __init__( | ||
self, | ||
recording: Recording, | ||
): | ||
super().__init__(recording) | ||
|
||
# https://github.com/RapidAI/RapidOCR/blob/main/python/README.md | ||
self.ocr = RapidOCR() | ||
|
||
def get_text(self, screenshot: mss.base.ScreenShot): | ||
image = Image.frombytes( | ||
"RGB", screenshot.size, screenshot.bgra, "raw", "BGRX" | ||
) | ||
arr = np.array(image) | ||
result, elapse = self.ocr(arr) | ||
#det_elapse, cls_elapse, rec_elapse = elapse | ||
#all_elapse = det_elapse + cls_elapse + rec_elapse | ||
logger.debug(f"{result=}") | ||
logger.debug(f"{elapse=}") | ||
df = get_df(result) | ||
text = convert_dataframe_to_string(df) | ||
return text | ||
|
||
|
||
def unnest(df, explode, axis, suffixes=None): | ||
# https://stackoverflow.com/a/53218939 | ||
if axis == 1: | ||
df1 = pd.concat([df[x].explode() for x in explode], axis=1) | ||
return df1.join(df.drop(explode, axis=1), how="left") | ||
else: | ||
df1 = pd.concat( | ||
[ | ||
pd.DataFrame( | ||
df[x].tolist(), | ||
index=df.index, | ||
columns=suffixes, | ||
).add_prefix(x) | ||
for x in explode | ||
], | ||
axis=1, | ||
) | ||
return df1.join( | ||
df.drop(explode, axis=1), | ||
how="left", | ||
) | ||
|
||
|
||
def get_df(result): | ||
""" | ||
Convert RapidOCR result to DataFrame. | ||
Args: | ||
result: list of [coordinates, text, confidence], where | ||
coordinates is itself a list of: | ||
[tl_x, tl_y], | ||
[tr_x, tr_y], | ||
[br_x, br_y], | ||
[bl_x, bl_y] | ||
Returns: | ||
pd.DataFrame | ||
""" | ||
|
||
coords = [coords for coords, text, confidence in result] | ||
columns = ["tl", "tr", "bl", "br"] | ||
df = pd.DataFrame(coords, columns=columns) | ||
df = unnest(df, df.columns, 0, suffixes=["_x", "_y"]) | ||
|
||
texts = [text for coords, text, confidence in result] | ||
df["text"] = texts | ||
|
||
confidences = [confidence for coords, text, confidence in result] | ||
df["confidence"] = confidences | ||
logger.info(f"df=\n{df}") | ||
return df | ||
|
||
|
||
def preprocess_text(text): | ||
return text.strip() | ||
|
||
|
||
def calculate_centroid(row): | ||
x = (row["tl_x"] + row["tr_x"] + row["bl_x"] + row["br_x"]) / 4 | ||
y = (row["tl_y"] + row["tr_y"] + row["bl_y"] + row["br_y"]) / 4 | ||
return x, y | ||
|
||
|
||
def calculate_height(row): | ||
return abs(row["tl_y"] - row["bl_y"]) | ||
|
||
|
||
def sort_rows(df): | ||
df["centroid"] = df.apply(calculate_centroid, axis=1) | ||
df["x"] = df["centroid"].apply(lambda coord: coord[0]) | ||
df["y"] = df["centroid"].apply(lambda coord: coord[1]) | ||
df.sort_values(by=["y", "x"], inplace=True) | ||
return df | ||
|
||
|
||
def cluster_lines(df, eps): | ||
coords = df[["x", "y"]].to_numpy() | ||
cluster_model = DBSCAN(eps=eps, min_samples=1) | ||
df["line_cluster"] = cluster_model.fit_predict(coords) | ||
return df | ||
|
||
|
||
def cluster_words(df): | ||
line_dfs = [] | ||
for line_cluster in df["line_cluster"].unique(): | ||
line_df = df[df["line_cluster"] == line_cluster].copy() | ||
|
||
if len(line_df) > 1: | ||
coords = line_df[["x", "y"]].to_numpy() | ||
eps = line_df["height"].min() | ||
cluster_model = DBSCAN(eps=eps, min_samples=1) | ||
line_df["word_cluster"] = cluster_model.fit_predict(coords) | ||
else: | ||
line_df["word_cluster"] = 0 | ||
|
||
line_dfs.append(line_df) | ||
return pd.concat(line_dfs) | ||
|
||
|
||
def concat_text(df): | ||
df.sort_values(by=["line_cluster", "word_cluster"], inplace=True) | ||
lines = df.groupby("line_cluster")["text"].apply(lambda x: " ".join(x)) | ||
return "\n".join(lines) | ||
|
||
|
||
def convert_dataframe_to_string(df): | ||
df["text"] = df["text"].apply(preprocess_text) | ||
sorted_df = sort_rows(df) | ||
df["height"] = df.apply(calculate_height, axis=1) | ||
eps = df["height"].min() | ||
line_clustered_df = cluster_lines(sorted_df, eps) | ||
word_clustered_df = cluster_words(line_clustered_df) | ||
result = concat_text(word_clustered_df) | ||
return result |
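A small self-contained sketch of the text-reconstruction pipeline, using fabricated data rather than real RapidOCR output: three boxes, two on one line and one below. The coordinates are invented so that the centroids of adjacent boxes fall within one text height of each other, which is the eps DBSCAN uses to group lines and words.

# Fabricated RapidOCR-style result: [corner coords, text, confidence],
# corners clockwise from the top-left (tl, tr, br, bl). Values are made up.
from puterbot.strategies.ocr_mixin import convert_dataframe_to_string, get_df

fake_result = [
    [[[10, 10], [30, 10], [30, 40], [10, 40]], "Hi", 0.99],
    [[[35, 10], [55, 10], [55, 40], [35, 40]], "there", 0.98],
    [[[10, 80], [70, 80], [70, 110], [10, 110]], "Bye", 0.97],
]
df = get_df(fake_result)
print(convert_dataframe_to_string(df))
# Expected output (roughly): the first two boxes share a line cluster
# because their centroids are 25px apart with eps = min height = 30px:
#   Hi there
#   Bye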