Commit
Merge pull request #21 from MLDSAI/feat/llm_ocr_demo
add llm_ocr_demo.py and related mixins
Showing 9 changed files with 334 additions and 24 deletions.
@@ -1,3 +1,4 @@
from puterbot.strategies.base import BaseReplayStrategy
from puterbot.strategies.naive import NaiveReplayStrategy
from puterbot.strategies.llm_ocr_demo import LLMOCRDemoReplayStrategy
# add more strategies here
@@ -0,0 +1,68 @@
""" | ||
Implements a ReplayStrategy mixin for generating LLM completions. | ||
Usage: | ||
class MyReplayStrategy(LLMReplayStrategyMixin): | ||
... | ||
""" | ||
|
||
|
||
from loguru import logger | ||
import transformers as tf # RIP TensorFlow | ||
|
||
from puterbot.models import Recording | ||
from puterbot.strategies.base import BaseReplayStrategy | ||
|
||
|
||
MODEL_NAME = "gpt2" # gpt2-xl is bigger and slower | ||
MODEL_MAX_LENGTH = 1024 | ||
|
||
|
||
class LLMReplayStrategyMixin(BaseReplayStrategy): | ||
|
||
def __init__( | ||
self, | ||
recording: Recording, | ||
model_name: str = MODEL_NAME, | ||
model_max_length: str = MODEL_MAX_LENGTH, | ||
): | ||
super().__init__(recording) | ||
|
||
logger.info(f"{model_name=}") | ||
self.tokenizer = tf.AutoTokenizer.from_pretrained(model_name) | ||
self.model = tf.AutoModelForCausalLM.from_pretrained(model_name) | ||
self.model_max_length = model_max_length | ||
|
||
def generate_completion( | ||
self, | ||
prompt: str, | ||
max_tokens: int, | ||
): | ||
model_max_length = self.model_max_length | ||
if model_max_length and len(prompt) > model_max_length: | ||
logger.warning( | ||
f"Truncating from {len(prompt)=} to {model_max_length=}" | ||
) | ||
prompt = prompt[:model_max_length] | ||
logger.info(f"{prompt=} {max_tokens=}") | ||
input_tokens = self.tokenizer(prompt, return_tensors="pt") | ||
pad_token_id = self.tokenizer.eos_token_id | ||
attention_mask = input_tokens["attention_mask"] | ||
|
||
output_tokens = self.model.generate( | ||
input_ids=input_tokens["input_ids"], | ||
attention_mask=attention_mask, | ||
max_length=input_tokens["input_ids"].shape[-1] + max_tokens, | ||
pad_token_id=pad_token_id, | ||
num_return_sequences=1 | ||
) | ||
|
||
N = input_tokens["input_ids"].shape[-1] | ||
completion = self.tokenizer.decode( | ||
output_tokens[:, N:][0], | ||
clean_up_tokenization_spaces=True, | ||
) | ||
logger.info(f"{completion=}") | ||
|
||
return completion |
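For orientation, here is a minimal sketch of how a concrete strategy might consume this mixin on its own, following the docstring's usage note. The class name MyReplayStrategy is taken from the docstring; the prompt string and the 20-token limit are hypothetical, and the get_next_input_event signature mirrors the demo strategy below.

import mss.base

from puterbot.models import Recording
from puterbot.strategies.base import BaseReplayStrategy
from puterbot.strategies.llm_mixin import LLMReplayStrategyMixin


class MyReplayStrategy(LLMReplayStrategyMixin, BaseReplayStrategy):
    # Hypothetical strategy that prompts the LLM directly, with no OCR step.

    def __init__(self, recording: Recording):
        super().__init__(recording)

    def get_next_input_event(self, screenshot: mss.base.ScreenShot):
        # Ask the model for up to 20 tokens beyond the (hypothetical) prompt.
        completion = self.generate_completion("click the OK button, then", 20)
        # A real strategy would parse the completion into an input event here.
        return None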
@@ -0,0 +1,40 @@
""" | ||
Demonstration of LLMReplayStrategyMixin and OCRReplayStrategyMixin. | ||
Usage: | ||
$ python puterbot/replay.py LLMOCRDemoReplayStrategy | ||
""" | ||
|
||
|
||
import mss.base | ||
|
||
from puterbot.events import get_events | ||
from puterbot.models import Recording | ||
from puterbot.strategies.base import BaseReplayStrategy | ||
from puterbot.strategies.llm_mixin import LLMReplayStrategyMixin | ||
from puterbot.strategies.ocr_mixin import OCRReplayStrategyMixin | ||
|
||
|
||
class LLMOCRDemoReplayStrategy( | ||
LLMReplayStrategyMixin, | ||
OCRReplayStrategyMixin, | ||
BaseReplayStrategy, | ||
): | ||
|
||
def __init__( | ||
self, | ||
recording: Recording, | ||
): | ||
super().__init__(recording) | ||
|
||
def get_next_input_event( | ||
self, | ||
screenshot: mss.base.ScreenShot, | ||
): | ||
text = self.get_text(screenshot) | ||
|
||
# N.B. this doesn't make sense and is for demonstration purposes only | ||
completion = self.generate_completion(text, 100) | ||
|
||
return None |
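To see what the mixin does without any puterbot machinery, the same tokenize, generate, decode round trip can be run directly against Hugging Face transformers. This is a standalone sketch, not part of the commit; the prompt string and the 20-token limit are arbitrary.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "The user moves the mouse to"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate up to 20 tokens beyond the prompt; pad with EOS as the mixin does.
outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=inputs["input_ids"].shape[-1] + 20,
    pad_token_id=tokenizer.eos_token_id,
    num_return_sequences=1,
)

# Decode only the newly generated tokens, mirroring generate_completion().
prompt_length = inputs["input_ids"].shape[-1]
completion = tokenizer.decode(
    outputs[0, prompt_length:],
    clean_up_tokenization_spaces=True,
)
print(completion)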