Skip to content

Commit

Permalink
Merge pull request #223 from Mustaballer/share-magic-wormhole
Browse files Browse the repository at this point in the history
  • Loading branch information
Mustaballer authored Dec 12, 2023
2 parents ec58755 + b79f3b5 commit cf81f1a
Show file tree
Hide file tree
Showing 13 changed files with 1,273 additions and 209 deletions.
2 changes: 2 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
OPENAI_API_KEY=<set your api key>
DB_FNAME=openadapt.db
35 changes: 32 additions & 3 deletions assets/fixtures.sql
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,35 @@
-- Insert sample recordings
INSERT INTO recording (timestamp, monitor_width, monitor_height, double_click_interval_seconds, double_click_distance_pixels, platform, task_description)
VALUES
('2023-06-28 10:15:00', 1920, 1080, 0.5, 10, 'Windows', 'Task 1'),
('2023-06-29 14:30:00', 1366, 768, 0.3, 8, 'Mac', 'Task 2'),
('2023-06-30 09:45:00', 2560, 1440, 0.7, 12, 'Linux', 'Task 3');
(1689889605.9053426, 1920, 1080, 0.5, 4, 'win32', 'type l');

-- Insert sample action_events
INSERT INTO action_event (name, timestamp, recording_timestamp, screenshot_timestamp, window_event_timestamp, mouse_x, mouse_y, mouse_dx, mouse_dy, mouse_button_name, mouse_pressed, key_name, key_char, key_vk, canonical_key_name, canonical_key_char, canonical_key_vk, parent_id, element_state)
VALUES
('press', 1690049582.7713714, 1689889605.9053426, 1690049582.7686925, 1690049556.2166219, NULL, NULL, NULL, NULL, NULL, NULL, NULL, 'l', '76', NULL, 'l', NULL, NULL, 'null'),
('release', 1690049582.826782, 1689889605.9053426, 1690049582.7686925, 1690049556.2166219, NULL, NULL, NULL, NULL, NULL, NULL, NULL, 'l', '76', NULL, 'l', NULL, NULL, 'null');

-- Insert sample screenshots
INSERT INTO screenshot (recording_timestamp, timestamp, png_data)
VALUES
(1689889605.9053426, 1690042711.774856, x'89504E470D0A1A0A0000000D49484452000000010000000108060000009077BF8A0000000A4944415408D7636000000005000000008D2B4233000000000049454E44AE426082');
-- PNG data represents a 1x1 pixel image with a white pixel

-- Insert sample window_events
INSERT INTO window_event (recording_timestamp, timestamp, state, title, left, top, width, height, window_id)
VALUES
(1689889605.9053426, 1690042703.7413292, '{"title": "recording.txt - openadapt - Visual Studio Code", "left": -9, "top": -9, "width": 1938, "height": 1048, "meta": {"class_name": "Chrome_WidgetWin_1", "control_id": 0, "rectangle": {"left": 0, "top": 0, "right": 1920, "bottom": 1030}, "is_visible": true, "is_enabled": true, "control_count": 0}}', 'recording.txt - openadapt - Visual Studio Code', -9, -9, 1938, 1048, '0');

-- Insert sample performance_stats
INSERT INTO performance_stat (recording_timestamp, event_type, start_time, end_time, window_id)
VALUES
(1689889605.9053426, 'action', 1690042703, 1690042711, 1),
(1689889605.9053426, 'action', 1690042712, 1690042720, 1);
-- Add more rows as needed for additional performance_stats

-- Insert sample memory_stats
INSERT INTO memory_stat (recording_timestamp, memory_usage_bytes, timestamp)
VALUES
(1689889605.9053426, 524288, 1690042703),
(1689889605.9053426, 1048576, 1690042711);
-- Add more rows as needed for additional memory_stats
33 changes: 31 additions & 2 deletions openadapt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import multiprocessing
import os
import pathlib
import shutil

from dotenv import load_dotenv
from loguru import logger
Expand Down Expand Up @@ -128,6 +129,11 @@
] + SPECIAL_CHAR_STOP_SEQUENCES

ENV_FILE_PATH = ".env"
ENV_EXAMPLE_FILE_PATH = ".env.example"

# Create .env file if it doesn't exist
if not os.path.isfile(ENV_FILE_PATH):
shutil.copy(ENV_EXAMPLE_FILE_PATH, ENV_FILE_PATH)


def getenv_fallback(var_name: str) -> str:
Expand All @@ -137,7 +143,7 @@ def getenv_fallback(var_name: str) -> str:
var_name (str): The name of the environment variable.
Returns:
The value of the environment variable or the fallback default value.
str: The value of the environment variable or the default value if not found.
Raises:
ValueError: If the environment variable is not defined.
Expand Down Expand Up @@ -192,11 +198,34 @@ def persist_env(var_name: str, val: str, env_file_path: str = ENV_FILE_PATH) ->
locals()[key] = val

ROOT_DIRPATH = pathlib.Path(__file__).parent.parent.resolve()
DB_FPATH = ROOT_DIRPATH / DB_FNAME # type: ignore # noqa
DATA_DIRECTORY_PATH = ROOT_DIRPATH / "data"
RECORDING_DIRECTORY_PATH = DATA_DIRECTORY_PATH / "recordings"
# TODO: clarify why this is necessary (see share.py)
if DB_FNAME == "openadapt.db": # noqa
DB_FPATH = ROOT_DIRPATH / DB_FNAME # noqa
else:
DB_FPATH = RECORDING_DIRECTORY_PATH / DB_FNAME # noqa
DB_URL = f"sqlite:///{DB_FPATH}"
DIRNAME_PERFORMANCE_PLOTS = "performance"


def set_db_url(db_fname: str) -> None:
    """Point the module-level database globals at the given database file.

    Args:
        db_fname (str): The database file name.
    """
    # TODO: pass these in as parameters, whose default values are the globals
    global DB_FNAME, DB_FPATH, DB_URL
    DB_FNAME = db_fname
    # The default database lives at the project root; any other file name is
    # resolved under the recordings directory (see share.py).
    base_dir = ROOT_DIRPATH if db_fname == "openadapt.db" else RECORDING_DIRECTORY_PATH
    DB_FPATH = base_dir / db_fname
    DB_URL = f"sqlite:///{DB_FPATH}"
    logger.info(f"{DB_URL=}")


def obfuscate(val: str, pct_reveal: float = 0.1, char: str = "*") -> str:
"""Obfuscates a value by replacing a portion of characters.
Expand Down
43 changes: 43 additions & 0 deletions openadapt/custom_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Module for log message filtering, excluding strings & limiting warnings."""

from collections import defaultdict
import time

from openadapt import config

MESSAGE_TIMESTAMPS = defaultdict(list)

# TODO: move utils.configure_logging to here


def filter_log_messages(data: dict) -> bool:
    """Decide whether a log message should be emitted.

    Messages containing any string from ``config.MESSAGES_TO_FILTER`` are
    rate-limited: once more than ``config.MAX_NUM_WARNINGS_PER_SECOND``
    occurrences land within ``config.WARNING_SUPPRESSION_PERIOD`` seconds,
    further occurrences are suppressed.

    Args:
        data: The log message data from a loguru logger.

    Returns:
        bool: True if the log message should be emitted, False if suppressed.
    """
    # TODO: ultimately, we want to fix the underlying issues, but for now,
    # we can ignore these messages
    message_text = data["message"]
    for filtered_msg in config.MESSAGES_TO_FILTER:
        if filtered_msg not in message_text:
            continue
        if config.MAX_NUM_WARNINGS_PER_SECOND <= 0:
            continue
        now = time.time()
        MESSAGE_TIMESTAMPS[filtered_msg].append(now)

        # Keep only timestamps within the suppression window
        recent = [
            ts
            for ts in MESSAGE_TIMESTAMPS[filtered_msg]
            if now - ts <= config.WARNING_SUPPRESSION_PERIOD
        ]

        if len(recent) > config.MAX_NUM_WARNINGS_PER_SECOND:
            return False

        MESSAGE_TIMESTAMPS[filtered_msg] = recent

    return True
2 changes: 2 additions & 0 deletions openadapt/db/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
"""Package for interacting with the OpenAdapt database."""

from .db import export_recording # noqa: F401
9 changes: 9 additions & 0 deletions openadapt/db/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,15 @@ def get_latest_recording() -> Recording:
return db.query(Recording).order_by(sa.desc(Recording.timestamp)).limit(1).first()


def get_recording_by_id(recording_id: int) -> Recording:
    """Get a recording by its id.

    Args:
        recording_id (int): The id of the recording to fetch.

    Returns:
        Recording: The recording with the given id, or None if no match exists
        (``Query.first()`` returns None on an empty result).
    """
    return db.query(Recording).filter_by(id=recording_id).first()


def get_recording(timestamp: int) -> Recording:
"""Get a recording by timestamp.
Expand Down
148 changes: 145 additions & 3 deletions openadapt/db/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,20 @@
Module: db.py
"""

from typing import Any
import os
import time

from dictalchemy import DictableModel
from loguru import logger
from sqlalchemy import create_engine, event
from sqlalchemy.engine import reflection
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.schema import MetaData
import sqlalchemy as sa

from openadapt.config import DB_ECHO, DB_URL
from openadapt import config

NAMING_CONVENTION = {
"ix": "ix_%(column_0_label)s",
Expand Down Expand Up @@ -41,8 +48,8 @@ def __repr__(self) -> str:
def get_engine() -> sa.engine:
"""Create and return a database engine."""
engine = sa.create_engine(
DB_URL,
echo=DB_ECHO,
config.DB_URL,
echo=config.DB_ECHO,
)
return engine

Expand All @@ -68,3 +75,138 @@ def get_base(engine: sa.engine) -> sa.engine:
engine = get_engine()
Base = get_base(engine)
Session = sessionmaker(bind=engine)


def copy_recording_data(
    source_engine: sa.engine,
    target_engine: sa.engine,
    recording_id: int,
    exclude_tables: tuple = (),
) -> str:
    """Copy a specific recording from the source database to the target database.

    The target database is rebuilt from the source schema (excluded tables are
    left untouched), then the selected recording row — plus every row in tables
    sharing its ``recording_timestamp`` — is copied over, along with the
    ``alembic_version`` table so migration state is preserved.

    Args:
        source_engine (create_engine): SQLAlchemy engine for the source database.
        target_engine (create_engine): SQLAlchemy engine for the target database.
        recording_id (int): The ID of the recording to copy.
        exclude_tables (tuple, optional): Tables excluded from copying. Defaults to ().

    Returns:
        str: The URL or path of the target database, or "" on failure (any
        partially-written target file is removed).
    """
    try:
        with source_engine.connect() as src_conn, target_engine.connect() as tgt_conn:
            src_metadata = MetaData()
            tgt_metadata = MetaData()

            @event.listens_for(src_metadata, "column_reflect")
            def genericize_datatypes(
                inspector: reflection.Inspector,
                tablename: str,
                column_dict: dict[str, Any],
            ) -> None:
                # Replace dialect-specific column types with generic ones so
                # the schema can be recreated cleanly on the target database.
                column_dict["type"] = column_dict["type"].as_generic(
                    allow_nulltype=True
                )

            tgt_metadata.reflect(bind=target_engine)
            src_metadata.reflect(bind=source_engine)

            # Drop all tables in target database (except excluded tables)
            for table in reversed(tgt_metadata.sorted_tables):
                if table.name not in exclude_tables:
                    # f-string fix: loguru ignores positional args without a
                    # "{}" placeholder, so the previous
                    # logger.info("Dropping table =", table.name)
                    # never actually logged the table name.
                    logger.info(f"Dropping table: {table.name}")
                    table.drop(bind=target_engine)

            tgt_metadata.clear()
            tgt_metadata.reflect(bind=target_engine)
            src_metadata.reflect(bind=source_engine)

            # Create all tables in target database (except excluded tables)
            for table in src_metadata.sorted_tables:
                if table.name not in exclude_tables:
                    table.create(bind=target_engine)

            # Refresh metadata before copying data
            tgt_metadata.clear()
            tgt_metadata.reflect(bind=target_engine)

            # Get the source recording table
            src_recording_table = src_metadata.tables["recording"]
            tgt_recording_table = tgt_metadata.tables["recording"]

            # Select the recording with the given recording_id from the source
            src_select = src_recording_table.select().where(
                src_recording_table.c.id == recording_id
            )
            src_recording = src_conn.execute(src_select).fetchone()
            if src_recording is None:
                # Fail fast with a clear message instead of an opaque
                # TypeError on the insert below; the except block performs
                # the cleanup and returns "".
                raise ValueError(f"No recording found with id={recording_id}")

            # Insert the recording into the target recording table
            tgt_conn.execute(tgt_recording_table.insert().values(src_recording))

            # Get the timestamp from the source recording
            # NOTE(review): string indexing on a Row relies on legacy
            # SQLAlchemy behavior; row._mapping["timestamp"] is the 2.0-safe
            # form — confirm against the pinned SQLAlchemy version.
            src_timestamp = src_recording["timestamp"]

            # Copy data from tables with the same timestamp
            for table in src_metadata.sorted_tables:
                if (
                    table.name not in exclude_tables
                    and "recording_timestamp" in table.columns.keys()
                ):
                    # Select data from source table with the same timestamp
                    src_select = table.select().where(
                        table.c.recording_timestamp == src_timestamp
                    )
                    src_rows = src_conn.execute(src_select).fetchall()

                    # Insert data into target table
                    tgt_table = tgt_metadata.tables[table.name]
                    for row in src_rows:
                        tgt_insert = tgt_table.insert().values(**row._asdict())
                        tgt_conn.execute(tgt_insert)

            # Copy data from alembic_version table so the exported database
            # carries the same migration revision as the source.
            src_alembic_version_table = src_metadata.tables["alembic_version"]
            tgt_alembic_version_table = tgt_metadata.tables["alembic_version"]
            src_alembic_version_select = src_alembic_version_table.select()
            src_alembic_version_data = src_conn.execute(
                src_alembic_version_select
            ).fetchall()
            for row in src_alembic_version_data:
                tgt_alembic_version_insert = tgt_alembic_version_table.insert().values(
                    row
                )
                tgt_conn.execute(tgt_alembic_version_insert)

            # Commit the transaction
            tgt_conn.commit()

    except Exception as exc:
        # Perform cleanup: remove the partially-written target database file
        db_file_path = target_engine.url.database
        if db_file_path and os.path.exists(db_file_path):
            os.remove(db_file_path)
        logger.exception(exc)
        return ""

    return target_engine.url.database


def export_recording(recording_id: int) -> str:
    """Export a recording by its ID to a new SQLite database.

    Args:
        recording_id (int): The ID of the recording to export.

    Returns:
        str: The file path of the new database with timestamp, or "" if the
        copy failed (see copy_recording_data).
    """
    timestamp = int(time.time())
    db_fname = f"recording_{recording_id}_{timestamp}.db"
    # Ensure the recordings directory exists; otherwise SQLite cannot create
    # the database file and the export fails with an unrelated-looking error.
    os.makedirs(config.RECORDING_DIRECTORY_PATH, exist_ok=True)
    target_path = config.RECORDING_DIRECTORY_PATH / db_fname
    target_db_url = f"sqlite:///{target_path}"

    target_engine = create_engine(target_db_url, future=True)
    try:
        db_file_path = copy_recording_data(engine, target_engine, recording_id)
    finally:
        # Release the connection pool / file handle held by the per-export
        # engine; previously it was leaked on every call.
        target_engine.dispose()
    return db_file_path
Loading

0 comments on commit cf81f1a

Please sign in to comment.