From f5e935d37eb719f178357159f3a77a24df5aef5e Mon Sep 17 00:00:00 2001 From: Diian-r Date: Sun, 30 Apr 2023 23:21:23 -0400 Subject: [PATCH] Used the OCR text to note in the prompt to add a note at the front of the string to indicate whether the screen changed --- puterbot/strategies/demo.py | 27 ++++++++++++++++++++++++++- puterbot/strategies/llm_mixin.py | 6 +++--- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/puterbot/strategies/demo.py b/puterbot/strategies/demo.py index 176a714d1..f2893f4a7 100644 --- a/puterbot/strategies/demo.py +++ b/puterbot/strategies/demo.py @@ -19,6 +19,9 @@ from puterbot.strategies.ocr_mixin import OCRReplayStrategyMixin from puterbot.strategies.ascii_mixin import ASCIIReplayStrategyMixin +from gensim.summarization.summarizer import summarize +from nltk.corpus import wordnet + class DemoReplayStrategy( LLMReplayStrategyMixin, @@ -44,6 +47,28 @@ def get_next_input_event( ocr_text = self.get_ocr_text(screenshot) #logger.info(f"ocr_text=\n{ocr_text}") + # identify what the window is + summarized_ocr = summarize(ocr_text, word_count=1, split=False) + + index = self.screenshots.index(screenshot) + if index != 0: + last_screenshot = self.screenshots[index - 1] + last_summarized_ocr = summarize(last_screenshot, word_count=1, split=False) + + # check if the last screenshot and current screenshot picture the same image + synonyms_screenshot = set(wordnet.synsets(summarized_ocr)) + synonyms_last_screenshot = set(wordnet.synsets(last_summarized_ocr)) + + common_synonyms = list(synonyms_screenshot & synonyms_last_screenshot) + + # may want to change the number of required common synonyms + if len(common_synonyms) > 0: + window_changed = "True" + else: + window_changed = "False" + else: + window_changed = "False" + event_strs = [ f"<{event}>" for event in self.recording.input_events @@ -52,7 +77,7 @@ def get_next_input_event( f"<{completion}>" for completion in self.result_history ] - prompt = " ".join(event_strs + history_strs) + prompt = " ".join(event_strs + history_strs + summarized_ocr) N = max(0, len(prompt) - MAX_INPUT_SIZE) prompt = prompt[N:] logger.info(f"{prompt=}") diff --git a/puterbot/strategies/llm_mixin.py b/puterbot/strategies/llm_mixin.py index d88864e0f..4fb5fd762 100644 --- a/puterbot/strategies/llm_mixin.py +++ b/puterbot/strategies/llm_mixin.py @@ -40,13 +40,13 @@ def get_completion( max_tokens: int, ): max_input_size = self.max_input_size - if max_input_size and len(prompt) > max_input_size: + if max_input_size and len(prompt) - 1 > max_input_size: logger.warning( - f"Truncating from {len(prompt)=} to {max_input_size=}" + f"Truncating from {len(prompt) - 1=} to {max_input_size=}" ) prompt = prompt[max_input_size:] logger.debug(f"{prompt=} {max_tokens=}") - input_tokens = self.tokenizer(prompt, return_tensors="pt") + input_tokens = self.tokenizer(prompt[1:], return_tensors="pt") pad_token_id = self.tokenizer.eos_token_id attention_mask = input_tokens["attention_mask"] output_tokens = self.model.generate(