diff --git a/openadapt/db.py b/openadapt/db.py index df476996f..c5e5c6d4c 100644 --- a/openadapt/db.py +++ b/openadapt/db.py @@ -72,89 +72,90 @@ def copy_recording_data( str: The URL or path of the target database. """ try: - src_metadata = MetaData() - tgt_metadata = MetaData() - - @event.listens_for(src_metadata, "column_reflect") - def genericize_datatypes(inspector, tablename, column_dict): - column_dict["type"] = column_dict["type"].as_generic(allow_nulltype=True) - - src_conn = source_engine.connect() - tgt_conn = target_engine.connect() - tgt_metadata.reflect(bind=target_engine) - src_metadata.reflect(bind=source_engine) - - # Drop all tables in target database - for table in reversed(tgt_metadata.sorted_tables): - if table.name not in exclude_tables: - logger.info("Dropping table =", table.name) - table.drop(bind=target_engine) - - tgt_metadata.clear() - tgt_metadata.reflect(bind=target_engine) - src_metadata.reflect(bind=source_engine) - - # Create all tables in target database - for table in src_metadata.sorted_tables: - if table.name not in exclude_tables: - table.create(bind=target_engine) - - # Refresh metadata before copying data - tgt_metadata.clear() - tgt_metadata.reflect(bind=target_engine) - - # Get the source recording table - src_recording_table = src_metadata.tables["recording"] - tgt_recording_table = tgt_metadata.tables["recording"] - - # Select the recording with the given recording_id from the source - src_select = src_recording_table.select().where( - src_recording_table.c.id == recording_id - ) - src_recording = src_conn.execute(src_select).fetchone() - - # Insert the recording into the target recording table - tgt_insert = tgt_recording_table.insert().values(src_recording) - tgt_conn.execute(tgt_insert) - - # Get the timestamp from the source recording - src_timestamp = src_recording["timestamp"] - - # Copy data from tables with the same timestamp - for table in src_metadata.sorted_tables: - if ( - table.name not in exclude_tables - and "recording_timestamp" in table.columns.keys() - ): - # Select data from source table with the same timestamp - src_select = table.select().where( - table.c.recording_timestamp == src_timestamp + with source_engine.connect() as src_conn, target_engine.connect() as tgt_conn: + src_metadata = MetaData() + tgt_metadata = MetaData() + + @event.listens_for(src_metadata, "column_reflect") + def genericize_datatypes(inspector, tablename, column_dict): + column_dict["type"] = column_dict["type"].as_generic( + allow_nulltype=True + ) + + tgt_metadata.reflect(bind=target_engine) + src_metadata.reflect(bind=source_engine) + + # Drop all tables in target database (except excluded tables) + for table in reversed(tgt_metadata.sorted_tables): + if table.name not in exclude_tables: + logger.info("Dropping table =", table.name) + table.drop(bind=target_engine) + + tgt_metadata.clear() + tgt_metadata.reflect(bind=target_engine) + src_metadata.reflect(bind=source_engine) + + # Create all tables in target database (except excluded tables) + for table in src_metadata.sorted_tables: + if table.name not in exclude_tables: + table.create(bind=target_engine) + + # Refresh metadata before copying data + tgt_metadata.clear() + tgt_metadata.reflect(bind=target_engine) + + # Get the source recording table + src_recording_table = src_metadata.tables["recording"] + tgt_recording_table = tgt_metadata.tables["recording"] + + # Select the recording with the given recording_id from the source + src_select = src_recording_table.select().where( + src_recording_table.c.id == recording_id + ) + src_recording = src_conn.execute(src_select).fetchone() + + # Insert the recording into the target recording table + tgt_conn.execute(tgt_recording_table.insert().values(src_recording)) + + # Get the timestamp from the source recording + src_timestamp = src_recording["timestamp"] + + # Copy data from tables with the same timestamp + for table in src_metadata.sorted_tables: + if ( + table.name not in exclude_tables + and "recording_timestamp" in table.columns.keys() + ): + # Select data from source table with the same timestamp + src_select = table.select().where( + table.c.recording_timestamp == src_timestamp + ) + src_rows = src_conn.execute(src_select).fetchall() + + # Insert data into target table + tgt_table = tgt_metadata.tables[table.name] + for row in src_rows: + tgt_insert = tgt_table.insert().values(**row._asdict()) + tgt_conn.execute(tgt_insert) + + # Copy data from alembic_version table + src_alembic_version_table = src_metadata.tables["alembic_version"] + tgt_alembic_version_table = tgt_metadata.tables["alembic_version"] + src_alembic_version_select = src_alembic_version_table.select() + src_alembic_version_data = src_conn.execute( + src_alembic_version_select + ).fetchall() + for row in src_alembic_version_data: + tgt_alembic_version_insert = tgt_alembic_version_table.insert().values( + row ) - src_rows = src_conn.execute(src_select).fetchall() - - # Insert data into target table - tgt_table = tgt_metadata.tables[table.name] - for row in src_rows: - tgt_insert = tgt_table.insert().values(**row._asdict()) - tgt_conn.execute(tgt_insert) - - # Copy data from alembic_version table - src_alembic_version_table = src_metadata.tables["alembic_version"] - tgt_alembic_version_table = tgt_metadata.tables["alembic_version"] - src_alembic_version_select = src_alembic_version_table.select() - src_alembic_version_data = src_conn.execute( - src_alembic_version_select - ).fetchall() - for row in src_alembic_version_data: - tgt_alembic_version_insert = tgt_alembic_version_table.insert().values(row) - tgt_conn.execute(tgt_alembic_version_insert) - - tgt_conn.commit() - src_conn.close() - tgt_conn.close() + tgt_conn.execute(tgt_alembic_version_insert) + + # Commit the transaction + tgt_conn.commit() + except Exception as exc: # Perform cleanup - tgt_conn.close() db_file_path = target_engine.url.database if db_file_path and os.path.exists(db_file_path): os.remove(db_file_path)