Skip to content

Commit

Permalink
Merge pull request #23 from MLDSAI/refactor/screenshots
Browse files Browse the repository at this point in the history
refactor screenshots
  • Loading branch information
abrichr authored Apr 19, 2023
2 parents 09e36f8 + 0657e43 commit da13e5f
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 18 deletions.
28 changes: 26 additions & 2 deletions puterbot/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from loguru import logger
from pynput import keyboard
from PIL import Image, ImageChops
import numpy as np
import sqlalchemy as sa

from puterbot.db import Base
from puterbot.utils import take_screenshot


class Recording(Base):
Expand Down Expand Up @@ -136,6 +138,9 @@ class Screenshot(Base):
png_data = sa.Column(sa.LargeBinary)
# TODO: replace prev with prev_timestamp?

# TODO: convert to png_data on save
sct_img = None

prev = None
_image = None
_diff = None
Expand All @@ -144,8 +149,17 @@ class Screenshot(Base):
@property
def image(self):
if not self._image:
buffer = io.BytesIO(self.png_data)
self._image = Image.open(buffer)
if self.sct_img:
self._image = Image.frombytes(
"RGB",
self.sct_img.size,
self.sct_img.bgra,
"raw",
"BGRX",
)
else:
buffer = io.BytesIO(self.png_data)
self._image = Image.open(buffer)
return self._image

@property
Expand All @@ -161,6 +175,16 @@ def diff_mask(self):
self._diff_mask = self._diff.convert("1")
return self._diff_mask

@property
def array(self):
return np.array(self.image)

@classmethod
def take_screenshot(cls):
sct_img = take_screenshot()
screenshot = Screenshot(sct_img=sct_img)
return screenshot


class WindowEvent(Base):
__tablename__ = "window_event"
Expand Down
9 changes: 4 additions & 5 deletions puterbot/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
import mss.base
import numpy as np

from puterbot.models import Recording, InputEvent
from puterbot.models import InputEvent, Recording, Screenshot
from puterbot.playback import play_input_event
from puterbot.utils import get_screenshot


MAX_FRAME_TIMES = 1000
Expand All @@ -34,15 +33,15 @@ def __init__(
@abstractmethod
def get_next_input_event(
self,
screenshot: mss.base.ScreenShot,
screenshot: Screenshot,
) -> InputEvent:
pass

def run(self):
keyboard_controller = keyboard.Controller()
mouse_controller = mouse.Controller()
while True:
screenshot = get_screenshot()
screenshot = Screenshot.take_screenshot()
self.screenshots.append(screenshot)
try:
input_event = self.get_next_input_event(screenshot)
Expand All @@ -63,7 +62,7 @@ def log_fps(self):
dts = np.diff(self.frame_times)
if len(dts) > 1:
mean_dt = np.mean(dts)
fps = len(dts) / mean_dt
fps = 1 / mean_dt
logger.info(f"{fps=:.2f}")
if len(self.frame_times) > self.max_frame_times:
self.frame_times.pop(0)
11 changes: 3 additions & 8 deletions puterbot/strategies/ocr_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@ class MyReplayStrategy(OCRReplayStrategyMixin):
from PIL import Image
from rapidocr_onnxruntime import RapidOCR
from sklearn.cluster import DBSCAN
import mss.base
import numpy as np
import pandas as pd

from puterbot.models import Recording
from puterbot.models import Recording, Screenshot
from puterbot.strategies.base import BaseReplayStrategy


Expand All @@ -39,14 +38,10 @@ def __init__(

def get_text(
self,
screenshot: mss.base.ScreenShot
screenshot: Screenshot
):
# TOOD: improve performance
image = Image.frombytes(
"RGB", screenshot.size, screenshot.bgra, "raw", "BGRX"
)
arr = np.array(image)
result, elapse = self.ocr(arr)
result, elapse = self.ocr(screenshot.array)
#det_elapse, cls_elapse, rec_elapse = elapse
#all_elapse = det_elapse + cls_elapse + rec_elapse
logger.debug(f"{result=}")
Expand Down
6 changes: 3 additions & 3 deletions puterbot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,12 +363,12 @@ def evenly_spaced(arr, N):
return [val for idx, val in enumerate(arr) if idx in idxs]


def get_screenshot() -> mss.base.ScreenShot:
def take_screenshot() -> mss.base.ScreenShot:
with mss.mss() as sct:
# monitor 0 is all in one
monitor = sct.monitors[0]
screenshot = sct.grab(monitor)
return screenshot
sct_img = sct.grab(monitor)
return sct_img


def get_strategy_class_by_name():
Expand Down

0 comments on commit da13e5f

Please sign in to comment.