Skip to content

Commit

Permalink
resolve #441
Browse files Browse the repository at this point in the history
  • Loading branch information
Mustaballer committed Aug 2, 2023
1 parent c3173f8 commit 0123aae
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 49 deletions.
25 changes: 22 additions & 3 deletions openadapt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,30 @@ def persist_env(var_name: str, val: str, env_file_path: str = ENV_FILE_PATH) ->
locals()[key] = val

ROOT_DIRPATH = pathlib.Path(__file__).parent.parent.resolve()
DB_FPATH = ROOT_DIRPATH / DB_FNAME # type: ignore # noqa
DB_URL = f"sqlite:///{DB_FPATH}"
DIRNAME_PERFORMANCE_PLOTS = "performance"
DATA_DIRECTORY_PATH = ROOT_DIRPATH / "data"
RECORDING_DIRECTORY_PATH = DATA_DIRECTORY_PATH / "recordings"
if DB_FNAME == "openadapt.db": # noqa
DB_FPATH = ROOT_DIRPATH / DB_FNAME # noqa
else:
DB_FPATH = RECORDING_DIRECTORY_PATH / DB_FNAME # noqa
DB_URL = f"sqlite:///{DB_FPATH}"
DIRNAME_PERFORMANCE_PLOTS = "performance"


def set_db_url(db_fname: str) -> None:
"""Set the database URL based on the given database file name.
Args:
db_fname (str): The database file name.
"""
global DB_FNAME, DB_FPATH, DB_URL
DB_FNAME = db_fname
if DB_FNAME == "openadapt.db": # noqa
DB_FPATH = ROOT_DIRPATH / DB_FNAME # noqa
else:
DB_FPATH = RECORDING_DIRECTORY_PATH / DB_FNAME # noqa
DB_URL = f"sqlite:///{DB_FPATH}"
logger.info(f"{DB_URL=}")


def obfuscate(val: str, pct_reveal: float = 0.1, char: str = "*") -> str:
Expand Down
43 changes: 26 additions & 17 deletions openadapt/share.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@

from loguru import logger
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
import fire
import sqlalchemy as sa

from openadapt import config, db, models, utils, visualize
from openadapt import config, db, utils, visualize

LOG_LEVEL = "INFO"
utils.configure_logging(logger, LOG_LEVEL)
Expand Down Expand Up @@ -139,30 +138,40 @@ def visualize_recording(db_name: str) -> None:
db_name (str): The name of the SQLite database containing the recording.
Raises:
sqlalchemy.exc.OperationalError: If there is an error accessing the database.
Exception: If there is an error accessing the database.
"""
# Determine the recording path based on the database name
if db_name == "openadapt.db":
recording_path = os.path.join(config.ROOT_DIRPATH, db_name)
else:
recording_path = os.path.join(config.RECORDING_DIRECTORY_PATH, db_name)

recording_url = f"sqlite:///{recording_path}"

engine = create_engine(recording_url, future=True)
Session = sessionmaker(bind=engine)
session = Session()
# Save the old value of DB_FNAME
old_val = os.getenv("DB_FNAME")

# Call visualize.main() passing the recording object
recording = (
session.query(models.Recording)
.order_by(sa.desc(models.Recording.timestamp))
.limit(1)
.first()
)
# Update the environment variable DB_FNAME and persist it
config.persist_env("DB_FNAME", db_name)
os.environ["DB_FNAME"] = db_name

# Call the main function from visualize.py and pass the recording object
visualize.main(recording)
# Set the database URL
config.set_db_url(db_name)

session.close()
engine = create_engine(recording_url)

try:
with Session(engine) as session:
os.system("alembic upgrade head")
# Visualize the recording
visualize.main(session)
except Exception as exc:
# Handle any exceptions that may occur during visualization
logger.exception(exc)
finally:
# Restore the old value of DB_FNAME in case of exceptions or at the end
os.environ["DB_FNAME"] = old_val
config.persist_env("DB_FNAME", old_val)


# Create a command-line interface using python-fire and utils.get_functions
Expand Down
15 changes: 8 additions & 7 deletions openadapt/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from loguru import logger
from tqdm import tqdm

from openadapt import config, models
from openadapt import config, crud
from openadapt.crud import get_latest_recording
from openadapt.events import get_events
from openadapt.utils import (
Expand Down Expand Up @@ -182,15 +182,16 @@ def dict2html(


@logger.catch
def main(recording: models.Recording = None) -> None:
def main(session=None) -> None:
"""Main function to generate an HTML report for a recording."""
if session is not None:
crud.db = session
configure_logging(logger, LOG_LEVEL)

if recording is None:
recording = get_latest_recording()
if SCRUB:
scrub.scrub_text(recording.task_description)
logger.debug(f"{recording=}")
recording = get_latest_recording()
if SCRUB:
scrub.scrub_text(recording.task_description)
logger.debug(f"{recording=}")

meta = {}
action_events = get_events(recording, process=PROCESS_EVENTS, meta=meta)
Expand Down
22 changes: 0 additions & 22 deletions tests/openadapt/test_share.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,25 +121,3 @@ def test_receive_recording() -> None:

# Verify that the zip file has been deleted
assert not os.path.exists(temp_zip_path)


# Test visualize_recording function
def test_visualize_recording(setup_database: engine) -> None:
"""Tests the visualize_recording function.
This test calls the function being tested with the "recording.db" created from
the setup_database fixture and asserts that the session object
was closed after calling the function.
Args:
setup_database: The setup_database fixture from the testing environment.
Returns:
None
"""
# Call the function being tested
share.visualize_recording("recording.db")

# Assert that the session object was closed after calling the function
# Here we are checking if the engine is disposed after calling the function
assert not hasattr(share.visualize_recording, "engine")

0 comments on commit 0123aae

Please sign in to comment.