Skip to content

Commit

Permalink
fix(db): Database access refactor (#676)
Browse files Browse the repository at this point in the history
* feat: Remove global sessions, and introduce read only sessions for cases where no writing is required

* refactor: Rename db to session

* feat: Raise exceptions if commit/write/delete is attempted on a read-only session

* feat: Add tests for the read only session

* chore: lint using flake8

* rename test_database -> db_engine

---------

Co-authored-by: Richard Abrich <[email protected]>
  • Loading branch information
KIRA009 and abrichr authored May 28, 2024
1 parent b438c9c commit 987f6ac
Show file tree
Hide file tree
Showing 24 changed files with 483 additions and 327 deletions.
8 changes: 4 additions & 4 deletions experiments/imagesimilarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Callable
import time

from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from PIL import Image, ImageOps
from skimage.metrics import structural_similarity as ssim
from sklearn.manifold import MDS
Expand All @@ -12,8 +12,7 @@
import matplotlib.pyplot as plt
import numpy as np

from openadapt.db import crud

from openadapt.session import crud

SHOW_SSIM = False

Expand Down Expand Up @@ -290,7 +289,8 @@ def display_distance_matrix_with_images(

def main() -> None:
"""Main function to process images and display similarity metrics."""
recording = crud.get_latest_recording()
session = crud.get_new_session(read_only=True)
recording = crud.get_latest_recording(session)
action_events = recording.processed_action_events
images = [action_event.screenshot.cropped_image for action_event in action_events]

Expand Down
3 changes: 0 additions & 3 deletions openadapt/app/cards.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from openadapt.app.objects.local_file_picker import LocalFilePicker
from openadapt.app.util import get_scrub, set_dark, set_scrub, sync_switch
from openadapt.db.crud import new_session
from openadapt.record import record


Expand Down Expand Up @@ -146,7 +145,6 @@ def quick_record(
) -> None:
"""Run a recording session."""
global record_proc
new_session()
task_description = task_description or datetime.now().strftime("%d/%m/%Y %H:%M:%S")
record_proc.start(
record,
Expand Down Expand Up @@ -204,7 +202,6 @@ def begin() -> None:
ui.notify(
f"Recording {name}... Press CTRL + C in terminal window to cancel",
)
new_session()
global record_proc
record_proc.start(
record,
Expand Down
97 changes: 46 additions & 51 deletions openadapt/app/dashboard/api/recordings.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,29 +32,27 @@ def attach_routes(self) -> APIRouter:
@staticmethod
def get_recordings() -> dict[str, list[Recording]]:
"""Get all recordings."""
session = crud.get_new_session()
session = crud.get_new_session(read_only=True)
recordings = crud.get_all_recordings(session)
return {"recordings": recordings}

@staticmethod
def get_scrubbed_recordings() -> dict[str, list[Recording]]:
"""Get all scrubbed recordings."""
session = crud.get_new_session()
session = crud.get_new_session(read_only=True)
recordings = crud.get_all_scrubbed_recordings(session)
return {"recordings": recordings}

@staticmethod
async def start_recording() -> dict[str, str]:
def start_recording() -> dict[str, str | int]:
"""Start a recording session."""
await crud.acquire_db_lock()
cards.quick_record()
return {"message": "Recording started"}
return {"message": "Recording started", "status": 200}

@staticmethod
def stop_recording() -> dict[str, str]:
"""Stop a recording session."""
cards.stop_record()
crud.release_db_lock()
return {"message": "Recording stopped"}

@staticmethod
Expand All @@ -69,48 +67,45 @@ def recording_detail_route(self) -> None:
async def get_recording_detail(websocket: WebSocket, recording_id: int) -> None:
"""Get a specific recording and its action events."""
await websocket.accept()
session = crud.get_new_session()
with session:
recording = crud.get_recording_by_id(recording_id, session)

await websocket.send_json(
{"type": "recording", "value": recording.asdict()}
)

action_events = get_events(recording, session=session)

await websocket.send_json(
{"type": "num_events", "value": len(action_events)}
)

def convert_to_str(event_dict: dict) -> dict:
"""Convert the keys to strings."""
if "key" in event_dict:
event_dict["key"] = str(event_dict["key"])
if "canonical_key" in event_dict:
event_dict["canonical_key"] = str(event_dict["canonical_key"])
if "reducer_names" in event_dict:
event_dict["reducer_names"] = list(event_dict["reducer_names"])
if "children" in event_dict:
for child_event in event_dict["children"]:
convert_to_str(child_event)

for action_event in action_events:
event_dict = row2dict(action_event)
try:
image = display_event(action_event)
width, height = image.size
image = image2utf8(image)
except Exception:
logger.info("Failed to display event")
image = None
width, height = 0, 0
event_dict["screenshot"] = image
event_dict["dimensions"] = {"width": width, "height": height}

convert_to_str(event_dict)
await websocket.send_json(
{"type": "action_event", "value": event_dict}
)

await websocket.close()
session = crud.get_new_session(read_only=True)
recording = crud.get_recording_by_id(session, recording_id)

await websocket.send_json(
{"type": "recording", "value": recording.asdict()}
)

action_events = get_events(session, recording)

await websocket.send_json(
{"type": "num_events", "value": len(action_events)}
)

def convert_to_str(event_dict: dict) -> dict:
"""Convert the keys to strings."""
if "key" in event_dict:
event_dict["key"] = str(event_dict["key"])
if "canonical_key" in event_dict:
event_dict["canonical_key"] = str(event_dict["canonical_key"])
if "reducer_names" in event_dict:
event_dict["reducer_names"] = list(event_dict["reducer_names"])
if "children" in event_dict:
for child_event in event_dict["children"]:
convert_to_str(child_event)

for action_event in action_events:
event_dict = row2dict(action_event)
try:
image = display_event(action_event)
width, height = image.size
image = image2utf8(image)
except Exception:
logger.info("Failed to display event")
image = None
width, height = 0, 0
event_dict["screenshot"] = image
event_dict["dimensions"] = {"width": width, "height": height}

convert_to_str(event_dict)
await websocket.send_json({"type": "action_event", "value": event_dict})

await websocket.close()
2 changes: 0 additions & 2 deletions openadapt/app/dashboard/api/scrubbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from fastapi.responses import StreamingResponse

from openadapt.config import config
from openadapt.db import crud
from openadapt.privacy.providers import ScrubProvider
from openadapt.scrub import get_scrubbing_process, scrub

Expand Down Expand Up @@ -62,7 +61,6 @@ async def scrub_recording(recording_id: int, provider_id: str) -> dict[str, str]
}
if provider_id not in ScrubProvider.get_available_providers():
return {"message": "Provider not supported", "status": "failed"}
await crud.acquire_db_lock()
scrub(recording_id, provider_id, release_lock=True)
scrubbing_proc = get_scrubbing_process()
while not scrubbing_proc.is_running():
Expand Down
6 changes: 3 additions & 3 deletions openadapt/app/dashboard/components/Shell/Shell.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
'use client'

import { AppShell, Burger, Image, Text } from '@mantine/core'
import { AppShell, Box, Burger, Image, Text } from '@mantine/core'
import React from 'react'
import { Navbar } from '../Navbar'
import { useDisclosure } from '@mantine/hooks'
Expand Down Expand Up @@ -30,12 +30,12 @@ export const Shell = ({ children }: Props) => {
hiddenFrom="sm"
size="sm"
/>
<Text className="h-full flex items-center px-5 gap-x-2">
<Box className="h-full flex items-center px-5 gap-x-2">
<Image src={logo.src} alt="OpenAdapt" w={40} />
<Text>
OpenAdapt.AI
</Text>
</Text>
</Box>
</AppShell.Header>

<AppShell.Navbar>
Expand Down
14 changes: 12 additions & 2 deletions openadapt/app/tray.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ def __init__(self) -> None:

self.app.setQuitOnLastWindowClosed(False)

# since the lock is a file, delete it when starting the app so that
# new instances can start even if the previous one crashed
crud.release_db_lock(raise_exception=False)

# currently required for pyqttoast
# TODO: remove once https://github.com/niklashenning/pyqt-toast/issues/9
# is addressed
Expand Down Expand Up @@ -379,7 +383,12 @@ def _delete(self, recording: Recording) -> None:
"""
dialog = ConfirmDeleteDialog(recording.task_description)
if dialog.exec_():
crud.delete_recording(recording.timestamp)
if not crud.acquire_db_lock():
self.show_toast("Failed to delete recording. Try again later.")
return
with crud.get_new_session(read_and_write=True) as session:
crud.delete_recording(session, recording)
crud.release_db_lock()
self.show_toast("Recording deleted.")
self.populate_menus()

Expand Down Expand Up @@ -413,7 +422,8 @@ def populate_menu(self, menu: QMenu, action: Callable, action_type: str) -> None
action (Callable): The function to call when the menu item is clicked.
action_type (str): The type of action to perform ["visualize", "replay"]
"""
recordings = crud.get_all_recordings()
session = crud.get_new_session(read_only=True)
recordings = crud.get_all_recordings(session)

self.recording_actions[action_type] = []

Expand Down
9 changes: 5 additions & 4 deletions openadapt/app/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import click

from openadapt.config import config
from openadapt.db.crud import get_latest_recording, get_recording
from openadapt.db import crud
from openadapt.events import get_events
from openadapt.utils import (
EMPTY,
Expand Down Expand Up @@ -141,18 +141,19 @@ def main(timestamp: str) -> None:
configure_logging(logger, LOG_LEVEL)

ui_dark = ui.dark_mode(config.VISUALIZE_DARK_MODE)
session = crud.get_new_session(read_only=True)

if timestamp is None:
recording = get_latest_recording()
recording = crud.get_latest_recording(session)
else:
recording = get_recording(timestamp)
recording = crud.get_recording(session, timestamp)

if SCRUB:
scrub.scrub_text(recording.task_description)
logger.debug(f"{recording=}")

meta = {}
action_events = get_events(recording, process=PROCESS_EVENTS, meta=meta)
action_events = get_events(session, recording, process=PROCESS_EVENTS, meta=meta)
event_dicts = rows2dicts(action_events)

if SCRUB:
Expand Down
Loading

0 comments on commit 987f6ac

Please sign in to comment.