Skip to content

Commit

Permalink
Centralize SQL execution
Browse files (browse the repository at this point in the history)
  • Loading branch information
krysal committed Apr 16, 2024
1 parent f863b52 commit 377c8f1
Showing 1 changed file with 33 additions and 33 deletions.
66 changes: 33 additions & 33 deletions catalog/dags/maintenance/add_license_url.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -34,41 +34,43 @@
logger = logging.getLogger(__name__)


def get_null_counts(
dag_task: AbstractOperator,
def run_sql(
sql: str,
log_sql: bool = True,
method: str = "get_records",
handler: callable = None,
autocommit: bool = False,
postgres_conn_id: str = POSTGRES_CONN_ID,
) -> int:
dag_task: AbstractOperator = None,
):
postgres = PostgresHook(
postgres_conn_id=postgres_conn_id,
default_statement_timeout=PostgresHook.get_execution_timeout(dag_task),
log_sql=log_sql,
)
nulls_count = postgres.get_first(
dedent("SELECT COUNT(*) from image WHERE meta_data->>'license_url' IS NULL")
)[0]
return nulls_count
if method == "get_records":
return postgres.get_records(sql)
elif method == "get_first":
return postgres.get_first(sql)
else:
return postgres.run(sql, autocommit=autocommit, handler=handler)


@task
def get_license_groups(
dag_task: AbstractOperator = None, postgres_conn_id: str = POSTGRES_CONN_ID
) -> list[tuple[str, str]]:
def get_license_groups(dag_task: AbstractOperator = None) -> list[tuple[str, str]]:
"""
Get license groups of rows that don't have a `license_url` in their
`meta_data` field.
:return: List of (license, version) tuples.
"""
postgres = PostgresHook(
postgres_conn_id=postgres_conn_id,
default_statement_timeout=PostgresHook.get_execution_timeout(dag_task),
)

select_query = dedent("""
SELECT license, license_version, count(identifier)
FROM image WHERE meta_data->>'license_url' IS NULL
GROUP BY license, license_version
""")
license_groups = postgres.get_records(select_query)
license_groups = run_sql(select_query, dag_task=dag_task)

total_nulls = sum(group[2] for group in license_groups)
licenses_detailed = "\n".join(
Expand All @@ -89,39 +91,30 @@ def get_license_groups(
return [(group[0], group[1]) for group in license_groups]


@task
@task(max_active_tis_per_dag=1)
def update_license_url(
license_group: tuple[str, str],
batch_size: int,
dag_task: AbstractOperator = None,
postgres_conn_id: str = POSTGRES_CONN_ID,
) -> int:
"""
Add license_url to meta_data batching all records with the same license.
:param license_group: tuple of license and version
:param batch_size: number of records to update in one update statement
:param dag_task: automatically passed by Airflow, used to set the execution timeout
:param postgres_conn_id: Postgres connection id
:param dag_task: automatically passed by Airflow, used to set the execution timeout.
"""

postgres = PostgresHook(
postgres_conn_id=postgres_conn_id,
default_statement_timeout=PostgresHook.get_execution_timeout(dag_task),
)
license_, version = license_group
*_, license_url = get_license_info_from_license_pair(license_, version)
license_url_dict = {"license_url": license_url}
total_updated = 0

if license_url is None:
logger.warning(f"No license pair ({license_}, {version}) in the license map.")
return 0

logging.info(
f"Will update license_url in `meta_data` for {license_} {version}"
f"to {license_url}."
f"Will add `license_url` in `meta_data` for records with license "
f"{license_} {version} to {license_url}."
)
license_url_dict = {"license_url": license_url}

# Merge existing metadata with the new license_url
update_query = dedent(
Expand All @@ -138,11 +131,16 @@ def update_license_url(
);
"""
)

total_updated = 0
updated_count = 1
while updated_count:
updated_count = postgres.run(
update_query, autocommit=True, handler=RETURN_ROW_COUNT
updated_count = run_sql(
update_query,
log_sql=total_updated == 0,
method="run",
handler=RETURN_ROW_COUNT,
autocommit=True,
dag_task=dag_task,
)
total_updated += updated_count
logger.info(f"Updated {total_updated} rows with {license_url}.")
Expand All @@ -160,7 +158,9 @@ def final_report(updated, dag_task: AbstractOperator = None):
:param dag_task: automatically passed by Airflow, used to set the execution timeout.
"""
total_updated = sum(updated)
null_counts = get_null_counts(dag_task)
query = "SELECT COUNT(*) from image WHERE meta_data->>'license_url' IS NULL"
null_counts = run_sql(query, method="get_first", dag_task=dag_task)[0]

message = f"""
`{DAG_ID}` DAG run completed. Updated {total_updated} records with `license_url` in the
`meta_data` field. {null_counts} records left pending.
Expand Down

0 comments on commit 377c8f1

Please sign in to comment.