Skip to content

Commit

Permalink
moved functions in crud.py to db.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Mustaballer committed Jun 26, 2023
1 parent f58fd5f commit dafca05
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 130 deletions.
124 changes: 2 additions & 122 deletions openadapt/crud.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
import os
import time
import shutil

from datetime import datetime
from loguru import logger
from sqlalchemy.orm import sessionmaker
import sqlalchemy as sa

from openadapt.db import Session, get_base, get_engine, engine
from openadapt import config
from openadapt.db import Session
from openadapt.models import (
ActionEvent,
Screenshot,
Expand Down Expand Up @@ -120,120 +113,7 @@ def get_latest_recording():


def get_recording_by_id(recording_id):
return db.query(Recording).filter_by(id=recording_id).first()


def export_sql(recording_id):
"""Export the recording data as SQL statements.
Args:
recording_id (int): The ID of the recording.
Returns:
str: The SQL statement to insert the recording into the output file.
"""
engine = sa.create_engine(config.DB_URL)
Session = sessionmaker(bind=engine)
session = Session()

recording = get_recording_by_id(recording_id)

if recording:
sql = """
INSERT INTO recording
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
"""
values = (
recording.id,
recording.timestamp,
recording.monitor_width,
recording.monitor_height,
recording.double_click_interval_seconds,
recording.double_click_distance_pixels,
recording.platform,
recording.task_description,
)

logger.info(f"Recording with ID {recording_id} exported successfully.")
else:
sql = ""
logger.info(f"No recording found with ID {recording_id}.")

return sql, values


def create_db(recording_id, sql, values):
"""Create a new database and import the recording data.
Args:
recording_id (int): The ID of the recording.
sql (str): The SQL statements to import the recording.
Returns:
tuple: A tuple containing the timestamp and the file path of the new database.
"""
db.close()
db_fname = f"recording_{recording_id}.db"

timestamp = time.time()
source_file_path = config.ENV_FILE_PATH
target_file_path = f"{config.ENV_FILE_PATH}-{timestamp}"
logger.info(
f"source_file_path={source_file_path}, target_file_path={target_file_path}"
)
shutil.copyfile(source_file_path, target_file_path)
config.set_db_url(db_fname)

with open(config.ENV_FILE_PATH, "r") as env_file:
env_file_lines = [
f"DB_FNAME={db_fname}\n"
if env_file_line.startswith("DB_FNAME")
else env_file_line
for env_file_line in env_file.readlines()
]

with open(config.ENV_FILE_PATH, "w") as env_file:
env_file.writelines(env_file_lines)

engine = sa.create_engine(config.DB_URL)
Session = sessionmaker(bind=engine)
session = Session()
os.system("alembic upgrade head")
db.engine = engine

with engine.begin() as connection:
connection.execute(sql, values)

db_file_path = config.DB_FPATH.resolve()

return timestamp, db_file_path


def restore_db(timestamp):
"""Restore the database to a previous state.
Args:
timestamp (float): The timestamp associated with the backup file.
"""
backup_file = f"{config.ENV_FILE_PATH}-{timestamp}"
shutil.copyfile(backup_file, config.ENV_FILE_PATH)
config.set_db_url("openadapt.db")
db.engine = get_engine()


def export_recording(recording_id):
"""Export a recording by creating a new database, importing the recording, and then restoring the previous state.
Args:
recording_id (int): The ID of the recording to export.
Returns:
str: The file path of the new database.
"""
sql, values = export_sql(recording_id)
timestamp, db_file_path = create_db(recording_id, sql, values)
restore_db(timestamp)
return db_file_path
return db.query(Recording).filter_by(id=recording_id).first().all()


def get_recording(timestamp):
Expand Down
130 changes: 124 additions & 6 deletions openadapt/db.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import os
import time
import shutil

import sqlalchemy as sa

from loguru import logger
from dictalchemy import DictableModel
from sqlalchemy.orm import sessionmaker
from sqlalchemy.schema import MetaData
from sqlalchemy.ext.declarative import declarative_base

from openadapt.config import DB_ECHO, DB_URL
from openadapt.utils import EMPTY, row2dict
from openadapt import config, utils


NAMING_CONVENTION = {
Expand All @@ -24,16 +29,16 @@ class BaseModel(DictableModel):
def __repr__(self):
params = ", ".join(
f"{k}={v!r}" # !r converts value to string using repr (adds quotes)
for k, v in row2dict(self, follow=False).items()
if v not in EMPTY
for k, v in utils.row2dict(self, follow=False).items()
if v not in utils.EMPTY
)
return f"{self.__class__.__name__}({params})"


def get_engine():
engine = sa.create_engine(
DB_URL,
echo=DB_ECHO,
config.DB_URL,
echo=config.DB_ECHO,
)
return engine

Expand All @@ -51,3 +56,116 @@ def get_base(engine):
engine = get_engine()
Base = get_base(engine)
Session = sessionmaker(bind=engine)

def export_sql(recording_id):
from openadapt.crud import get_recording_by_id # to avoid circular import
"""Export the recording data as SQL statements.
Args:
recording_id (int): The ID of the recording.
Returns:
str: The SQL statement to insert the recording into the output file.
"""
engine = sa.create_engine(config.DB_URL)
Session = sessionmaker(bind=engine)
session = Session()

recording = get_recording_by_id(recording_id)

if recording:
sql = """
INSERT INTO recording
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
"""
values = (
recording.id,
recording.timestamp,
recording.monitor_width,
recording.monitor_height,
recording.double_click_interval_seconds,
recording.double_click_distance_pixels,
recording.platform,
recording.task_description,
)

logger.info(f"Recording with ID {recording_id} exported successfully.")
else:
sql = ""
logger.info(f"No recording found with ID {recording_id}.")

return sql, values


def create_db(recording_id, sql, values):
"""Create a new database and import the recording data.
Args:
recording_id (int): The ID of the recording.
sql (str): The SQL statements to import the recording.
Returns:
tuple: A tuple containing the timestamp and the file path of the new database.
"""
# engine.close()
db_fname = f"recording_{recording_id}.db"

timestamp = time.time()
source_file_path = config.ENV_FILE_PATH
target_file_path = f"{config.ENV_FILE_PATH}-{timestamp}"
logger.info(
f"source_file_path={source_file_path}, target_file_path={target_file_path}"
)
shutil.copyfile(source_file_path, target_file_path)
config.set_db_url(db_fname)

with open(config.ENV_FILE_PATH, "r") as env_file:
env_file_lines = [
f"DB_FNAME={db_fname}\n"
if env_file_line.startswith("DB_FNAME")
else env_file_line
for env_file_line in env_file.readlines()
]

with open(config.ENV_FILE_PATH, "w") as env_file:
env_file.writelines(env_file_lines)

engine = sa.create_engine(config.DB_URL)
Session = sessionmaker(bind=engine)
session = Session()
os.system("alembic upgrade head")
# db.engine = engine

with engine.begin() as connection:
connection.execute(sql, values)

db_file_path = config.DB_FPATH.resolve()

return timestamp, db_file_path


def restore_db(timestamp):
"""Restore the database to a previous state.
Args:
timestamp (float): The timestamp associated with the backup file.
"""
backup_file = f"{config.ENV_FILE_PATH}-{timestamp}"
shutil.copyfile(backup_file, config.ENV_FILE_PATH)
config.set_db_url("openadapt.db")
engine = get_engine()


def export_recording(recording_id):
"""Export a recording by creating a new database, importing the recording, and then restoring the previous state.
Args:
recording_id (int): The ID of the recording to export.
Returns:
str: The file path of the new database.
"""
sql, values = export_sql(recording_id)
timestamp, db_file_path = create_db(recording_id, sql, values)
restore_db(timestamp)
return db_file_path
4 changes: 2 additions & 2 deletions openadapt/share.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from loguru import logger
import fire

from openadapt import config, crud, utils
from openadapt import config, db, utils


LOG_LEVEL = "INFO"
Expand All @@ -28,7 +28,7 @@ def export_recording_to_folder(recording_id):
Returns:
str: The path of the created zip file.
"""
recording_db_path = crud.export_recording(recording_id)
recording_db_path = db.export_recording(recording_id)

if recording_db_path:
# Create the directory if it doesn't exist
Expand Down

0 comments on commit dafca05

Please sign in to comment.