Skip to content

Commit

Permalink
refactor copy_recording_data
Browse files Browse the repository at this point in the history
  • Loading branch information
Mustaballer committed Jul 22, 2023
1 parent 9b7f7af commit 786e063
Showing 1 changed file with 81 additions and 80 deletions.
161 changes: 81 additions & 80 deletions openadapt/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,89 +72,90 @@ def copy_recording_data(
str: The URL or path of the target database.
"""
try:
src_metadata = MetaData()
tgt_metadata = MetaData()

@event.listens_for(src_metadata, "column_reflect")
def genericize_datatypes(inspector, tablename, column_dict):
column_dict["type"] = column_dict["type"].as_generic(allow_nulltype=True)

src_conn = source_engine.connect()
tgt_conn = target_engine.connect()
tgt_metadata.reflect(bind=target_engine)
src_metadata.reflect(bind=source_engine)

# Drop all tables in target database
for table in reversed(tgt_metadata.sorted_tables):
if table.name not in exclude_tables:
logger.info("Dropping table =", table.name)
table.drop(bind=target_engine)

tgt_metadata.clear()
tgt_metadata.reflect(bind=target_engine)
src_metadata.reflect(bind=source_engine)

# Create all tables in target database
for table in src_metadata.sorted_tables:
if table.name not in exclude_tables:
table.create(bind=target_engine)

# Refresh metadata before copying data
tgt_metadata.clear()
tgt_metadata.reflect(bind=target_engine)

# Get the source recording table
src_recording_table = src_metadata.tables["recording"]
tgt_recording_table = tgt_metadata.tables["recording"]

# Select the recording with the given recording_id from the source
src_select = src_recording_table.select().where(
src_recording_table.c.id == recording_id
)
src_recording = src_conn.execute(src_select).fetchone()

# Insert the recording into the target recording table
tgt_insert = tgt_recording_table.insert().values(src_recording)
tgt_conn.execute(tgt_insert)

# Get the timestamp from the source recording
src_timestamp = src_recording["timestamp"]

# Copy data from tables with the same timestamp
for table in src_metadata.sorted_tables:
if (
table.name not in exclude_tables
and "recording_timestamp" in table.columns.keys()
):
# Select data from source table with the same timestamp
src_select = table.select().where(
table.c.recording_timestamp == src_timestamp
with source_engine.connect() as src_conn, target_engine.connect() as tgt_conn:
src_metadata = MetaData()
tgt_metadata = MetaData()

@event.listens_for(src_metadata, "column_reflect")
def genericize_datatypes(inspector, tablename, column_dict):
column_dict["type"] = column_dict["type"].as_generic(
allow_nulltype=True
)

tgt_metadata.reflect(bind=target_engine)
src_metadata.reflect(bind=source_engine)

# Drop all tables in target database (except excluded tables)
for table in reversed(tgt_metadata.sorted_tables):
if table.name not in exclude_tables:
logger.info("Dropping table =", table.name)
table.drop(bind=target_engine)

tgt_metadata.clear()
tgt_metadata.reflect(bind=target_engine)
src_metadata.reflect(bind=source_engine)

# Create all tables in target database (except excluded tables)
for table in src_metadata.sorted_tables:
if table.name not in exclude_tables:
table.create(bind=target_engine)

# Refresh metadata before copying data
tgt_metadata.clear()
tgt_metadata.reflect(bind=target_engine)

# Get the source recording table
src_recording_table = src_metadata.tables["recording"]
tgt_recording_table = tgt_metadata.tables["recording"]

# Select the recording with the given recording_id from the source
src_select = src_recording_table.select().where(
src_recording_table.c.id == recording_id
)
src_recording = src_conn.execute(src_select).fetchone()

# Insert the recording into the target recording table
tgt_conn.execute(tgt_recording_table.insert().values(src_recording))

# Get the timestamp from the source recording
src_timestamp = src_recording["timestamp"]

# Copy data from tables with the same timestamp
for table in src_metadata.sorted_tables:
if (
table.name not in exclude_tables
and "recording_timestamp" in table.columns.keys()
):
# Select data from source table with the same timestamp
src_select = table.select().where(
table.c.recording_timestamp == src_timestamp
)
src_rows = src_conn.execute(src_select).fetchall()

# Insert data into target table
tgt_table = tgt_metadata.tables[table.name]
for row in src_rows:
tgt_insert = tgt_table.insert().values(**row._asdict())
tgt_conn.execute(tgt_insert)

# Copy data from alembic_version table
src_alembic_version_table = src_metadata.tables["alembic_version"]
tgt_alembic_version_table = tgt_metadata.tables["alembic_version"]
src_alembic_version_select = src_alembic_version_table.select()
src_alembic_version_data = src_conn.execute(
src_alembic_version_select
).fetchall()
for row in src_alembic_version_data:
tgt_alembic_version_insert = tgt_alembic_version_table.insert().values(
row
)
src_rows = src_conn.execute(src_select).fetchall()

# Insert data into target table
tgt_table = tgt_metadata.tables[table.name]
for row in src_rows:
tgt_insert = tgt_table.insert().values(**row._asdict())
tgt_conn.execute(tgt_insert)

# Copy data from alembic_version table
src_alembic_version_table = src_metadata.tables["alembic_version"]
tgt_alembic_version_table = tgt_metadata.tables["alembic_version"]
src_alembic_version_select = src_alembic_version_table.select()
src_alembic_version_data = src_conn.execute(
src_alembic_version_select
).fetchall()
for row in src_alembic_version_data:
tgt_alembic_version_insert = tgt_alembic_version_table.insert().values(row)
tgt_conn.execute(tgt_alembic_version_insert)

tgt_conn.commit()
src_conn.close()
tgt_conn.close()
tgt_conn.execute(tgt_alembic_version_insert)

# Commit the transaction
tgt_conn.commit()

except Exception as exc:
# Perform cleanup
tgt_conn.close()
db_file_path = target_engine.url.database
if db_file_path and os.path.exists(db_file_path):
os.remove(db_file_path)
Expand Down

0 comments on commit 786e063

Please sign in to comment.