Skip to content

Commit

Permalink
Used the OCR text to add a note at the front of the prompt string indicating whether the screen changed
Browse files Browse the repository at this point in the history
  • Loading branch information
dianzrong committed May 1, 2023
1 parent 02cfb5a commit f5e935d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
27 changes: 26 additions & 1 deletion puterbot/strategies/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from puterbot.strategies.ocr_mixin import OCRReplayStrategyMixin
from puterbot.strategies.ascii_mixin import ASCIIReplayStrategyMixin

from gensim.summarization.summarizer import summarize
from nltk.corpus import wordnet


class DemoReplayStrategy(
LLMReplayStrategyMixin,
Expand All @@ -44,6 +47,28 @@ def get_next_input_event(
ocr_text = self.get_ocr_text(screenshot)
#logger.info(f"ocr_text=\n{ocr_text}")

# identify what the window is
summarized_ocr = summarize(ocr_text, word_count=1, split=False)

index = self.screenshots.index(screenshot)
if index != 0:
last_screenshot = self.screenshots[index - 1]
last_summarized_ocr = summarize(last_screenshot, word_count=1, split=False)

# check if the last screenshot and current screenshot picture the same image
synonyms_screenshot = set(wordnet.synsets(summarized_ocr))
synonyms_last_screenshot = set(wordnet.synsets(last_summarized_ocr))

common_synonyms = list(synonyms_screenshot & synonyms_last_screenshot)

# may want to change the number of required common synonyms
if len(common_synonyms) > 0:
window_changed = "True"
else:
window_changed = "False"
else:
window_changed = "False"

event_strs = [
f"<{event}>"
for event in self.recording.input_events
Expand All @@ -52,7 +77,7 @@ def get_next_input_event(
f"<{completion}>"
for completion in self.result_history
]
prompt = " ".join(event_strs + history_strs)
prompt = " ".join(event_strs + history_strs + summarized_ocr)
N = max(0, len(prompt) - MAX_INPUT_SIZE)
prompt = prompt[N:]
logger.info(f"{prompt=}")
Expand Down
6 changes: 3 additions & 3 deletions puterbot/strategies/llm_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ def get_completion(
max_tokens: int,
):
max_input_size = self.max_input_size
if max_input_size and len(prompt) > max_input_size:
if max_input_size and len(prompt) - 1 > max_input_size:
logger.warning(
f"Truncating from {len(prompt)=} to {max_input_size=}"
f"Truncating from {len(prompt) - 1=} to {max_input_size=}"
)
prompt = prompt[max_input_size:]
logger.debug(f"{prompt=} {max_tokens=}")
input_tokens = self.tokenizer(prompt, return_tensors="pt")
input_tokens = self.tokenizer(prompt[1:], return_tensors="pt")
pad_token_id = self.tokenizer.eos_token_id
attention_mask = input_tokens["attention_mask"]
output_tokens = self.model.generate(
Expand Down

0 comments on commit f5e935d

Please sign in to comment.