Skip to content
This repository has been archived by the owner on Aug 4, 2023. It is now read-only.

Use Python to group items by license to speed up the query #1045

Merged
merged 4 commits into from
Mar 16, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 59 additions & 39 deletions openverse_catalog/dags/maintenance/add_license_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@
last step, logging the statistics.
"""
import logging
from collections import defaultdict
from datetime import timedelta
from textwrap import dedent
from typing import Literal

from airflow.models import DAG
from airflow.models.abstractoperator import AbstractOperator
from airflow.operators.python import BranchPythonOperator, PythonOperator
from airflow.utils.trigger_rule import TriggerRule
from common.constants import DAG_DEFAULT_ARGS, POSTGRES_CONN_ID, XCOM_PULL_TEMPLATE
from common.licenses.constants import get_reverse_license_path_map
from common.licenses import get_license_info_from_license_pair
from common.loader.sql import RETURN_ROW_COUNT
from common.slack import send_message
from common.sql import PostgresHook
Expand Down Expand Up @@ -76,60 +78,77 @@ def update_license_url(postgres_conn_id: str, task: AbstractOperator) -> dict[st
postgres_conn_id=postgres_conn_id,
default_statement_timeout=PostgresHook.get_execution_timeout(task),
)
license_map = get_reverse_license_path_map()

total_count = 0
total_counts = {}
for license_items, path in license_map.items():
license_name, license_version = license_items
logger.info(f"Processing {license_name} {license_version}, {license_items}.")
license_url = f"{base_url}{path}/"
select_query = dedent(
"""
SELECT identifier, license, license_version
FROM image WHERE meta_data IS NULL;"""
)
records_with_null_in_metadata = postgres.get_records(select_query)

select_query = dedent(
f"""
SELECT identifier FROM image
WHERE (
meta_data is NULL AND license = '{license_name}'
AND license_version = '{license_version}')
"""
)
result = postgres.get_records(select_query)
# Dictionary with license pair as key and list of identifiers as value
records_to_update = defaultdict(list)

for result in records_with_null_in_metadata:
identifier, license_, version = result
records_to_update[(license_, version)].append(identifier)

total_updated = 0
updated_by_license = {}

if not result:
logger.info(f"No records to update with {license_url}.")
for (license_, version), identifiers in records_to_update.items():
*_, license_url = get_license_info_from_license_pair(license_, version)
if license_url is None:
logger.info(f"No license pair ({license_}, {version}) in the license map.")
continue
logger.info(f"{len(result)} records to update with {license_url}.")
license_url_col = {"license_url": license_url}
update_license_url_query = dedent(
logger.info(f"{len(identifiers):4} items will be updated with {license_url}")
license_url_dict = {"license_url": license_url}
update_query = dedent(
f"""
UPDATE image
SET meta_data = {Json(license_url_col)}
WHERE identifier IN ({','.join([f"'{r[0]}'" for r in result])});
SET meta_data = {Json(license_url_dict)}
WHERE identifier IN ({','.join([f"'{r}'" for r in identifiers])});
"""
)

updated_count = postgres.run(
update_license_url_query, autocommit=True, handler=RETURN_ROW_COUNT
updated_count: int = postgres.run(
update_query, autocommit=True, handler=RETURN_ROW_COUNT
)
logger.info(f"{updated_count} records updated with {license_url}.")
total_counts[license_url] = updated_count
total_count += updated_count

logger.info(f"{total_count} image records with missing license_url updated.")
for license_url, count in total_counts.items():
logger.info(f"{count} records with {license_url}.")
return total_counts
if updated_count:
updated_by_license[license_url] = updated_count
total_updated += updated_count
logger.info(f"Updated {total_updated} rows")
return updated_by_license


def final_report(
postgres_conn_id: str,
item_count: int,
task: AbstractOperator,
updated_by_license: dict[str, int] | None,
task: AbstractOperator = None,
):
"""Check for null in `meta_data` and send a message to Slack
with the statistics of the DAG run.

:param postgres_conn_id: Postgres connection id.
:param updated_by_license: stringified JSON with the number of records updated
for each license_url. If `update_license_url` was skipped, this will be "None".
:param task: automatically passed by Airflow, used to set the execution timeout.
"""
null_meta_data_count = get_null_counts(postgres_conn_id, task)

if not updated_by_license:
updated_message = "No records were updated."
else:
formatted_item_count = "".join(
[
f"{license_url}: {count} rows\n"
for license_url, count in updated_by_license.items()
]
)
updated_message = f"Update statistics:\n{formatted_item_count}"
message = f"""
Added license_url to *{item_count}* items`
`add_license_url` DAG run completed.
{updated_message}
Now, there are {null_meta_data_count} records with NULL meta_data left.
"""
send_message(
Expand All @@ -150,9 +169,9 @@ def final_report(
},
schedule_interval=None,
catchup=False,
# Use the docstring at the top of the file as md docs in the UI
doc_md=__doc__,
tags=["data_normalization"],
render_template_as_native_obj=True,
)

with dag:
Expand All @@ -169,9 +188,10 @@ def final_report(
final_report = PythonOperator(
task_id=FINAL_REPORT,
python_callable=final_report,
trigger_rule=TriggerRule.ALL_DONE,
op_kwargs={
"postgres_conn_id": POSTGRES_CONN_ID,
"item_count": XCOM_PULL_TEMPLATE.format(
"updated_by_license": XCOM_PULL_TEMPLATE.format(
update_license_url.task_id, "return_value"
),
},
Expand Down