Skip to content

Commit

Permalink
modify export_sql to use paramerterized queries to prevent sql injection
Browse files Browse the repository at this point in the history
  • Loading branch information
Mustaballer committed Jun 23, 2023
1 parent 0a0208e commit 9dc1850
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 8 deletions.
1 change: 1 addition & 0 deletions openadapt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
...
"""

import multiprocessing
import os
import pathlib
Expand Down
34 changes: 26 additions & 8 deletions openadapt/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def export_sql(recording_id):
recording_id (int): The ID of the recording.
Returns:
str: The SQL statements to insert the recording into the output file.
str: The SQL statement to insert the recording into the output file.
"""
engine = sa.create_engine(config.DB_URL)
Session = sessionmaker(bind=engine)
Expand All @@ -139,15 +139,30 @@ def export_sql(recording_id):
recording = get_recording_by_id(recording_id)

if recording:
sql = f"INSERT INTO recording VALUES ({recording.id}, {recording.timestamp}, {recording.monitor_width}, {recording.monitor_height}, {recording.double_click_interval_seconds}, {recording.double_click_distance_pixels}, '{recording.platform}', '{recording.task_description}')"
sql = """
INSERT INTO recording
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
"""
values = (
recording.id,
recording.timestamp,
recording.monitor_width,
recording.monitor_height,
recording.double_click_interval_seconds,
recording.double_click_distance_pixels,
recording.platform,
recording.task_description
)

logger.info(f"Recording with ID {recording_id} exported successfully.")
else:
sql = ""
logger.info(f"No recording found with ID {recording_id}.")

return sql
return sql, values


def create_db(recording_id, sql):
def create_db(recording_id, sql, values):
"""Create a new database and import the recording data.
Args:
Expand All @@ -161,7 +176,10 @@ def create_db(recording_id, sql):
db_fname = f"recording_{recording_id}.db"

t = time.time()
shutil.copyfile(config.ENV_FILE_PATH, f"{config.ENV_FILE_PATH}-{t}")
source_file_path = config.ENV_FILE_PATH
target_file_path = f"{config.ENV_FILE_PATH}-{t}"
logger.info(f"source_file_path={source_file_path}, target_file_path={target_file_path}")
shutil.copyfile(source_file_path, target_file_path)
config.set_db_url(db_fname)

with open(config.ENV_FILE_PATH, "r") as f:
Expand All @@ -177,7 +195,7 @@ def create_db(recording_id, sql):
db.engine = engine

with engine.begin() as connection:
connection.execute(sql)
connection.execute(sql, values)

db_file_path = config.DB_FPATH.resolve()

Expand Down Expand Up @@ -205,8 +223,8 @@ def export_recording(recording_id):
Returns:
str: The file path of the new database.
"""
sql = export_sql(recording_id)
t, db_file_path = create_db(recording_id, sql)
sql, values = export_sql(recording_id)
t, db_file_path = create_db(recording_id, sql, values)
restore_db(t)
return db_file_path

Expand Down
1 change: 1 addition & 0 deletions openadapt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

EMPTY = (None, [], {}, "")


def get_now_dt_str(dt_format=config.DT_FMT):
"""
Get the current date and time as a formatted string.
Expand Down

0 comments on commit 9dc1850

Please sign in to comment.