Skip to content
This repository has been archived by the owner on Aug 4, 2023. It is now read-only.

Use Python to group items by license to speed up the query #1045

Merged
merged 4 commits into from
Mar 16, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 69 additions & 41 deletions openverse_catalog/dags/maintenance/add_license_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
the `meta_data` column are updated, the DAG will only run the first and the
last step, logging the statistics.
"""
import json
import logging
from datetime import timedelta
from textwrap import dedent
Expand All @@ -17,8 +18,9 @@
from airflow.models import DAG
from airflow.models.abstractoperator import AbstractOperator
from airflow.operators.python import BranchPythonOperator, PythonOperator
from airflow.utils.trigger_rule import TriggerRule
from common.constants import DAG_DEFAULT_ARGS, POSTGRES_CONN_ID, XCOM_PULL_TEMPLATE
from common.licenses.constants import get_reverse_license_path_map
from common.licenses import get_license_info_from_license_pair
from common.loader.sql import RETURN_ROW_COUNT
from common.slack import send_message
from common.sql import PostgresHook
Expand Down Expand Up @@ -65,7 +67,7 @@ def get_statistics(
return next_task


def update_license_url(postgres_conn_id: str, task: AbstractOperator) -> dict[str, int]:
def update_license_url(postgres_conn_id: str, task: AbstractOperator) -> str | None:
"""Add license_url to meta_data batching all records with the same license.
:param task: automatically passed by Airflow, used to set the execution timeout
:param postgres_conn_id: Postgres connection id
Expand All @@ -76,60 +78,85 @@ def update_license_url(postgres_conn_id: str, task: AbstractOperator) -> dict[st
postgres_conn_id=postgres_conn_id,
default_statement_timeout=PostgresHook.get_execution_timeout(task),
)
license_map = get_reverse_license_path_map()

total_count = 0
total_counts = {}
for license_items, path in license_map.items():
license_name, license_version = license_items
logger.info(f"Processing {license_name} {license_version}, {license_items}.")
license_url = f"{base_url}{path}/"

select_query = dedent(
f"""
SELECT identifier FROM image
WHERE (
meta_data is NULL AND license = '{license_name}'
AND license_version = '{license_version}')
"""
)
result = postgres.get_records(select_query)

if not result:
logger.info(f"No records to update with {license_url}.")
select_query = dedent(
"""
SELECT identifier, license, license_version
FROM image WHERE meta_data IS NULL;"""
)
records_with_null_in_metadata = postgres.get_records(select_query)

# Dictionary with license pair as key and list of identifiers as value
records_to_update = {}

for result in records_with_null_in_metadata:
identifier, license_, version = result
license_pair = f"{license_},{version}"
if license_pair not in records_to_update:
records_to_update[license_pair] = [identifier]
else:
records_to_update[license_pair].append(identifier)
obulat marked this conversation as resolved.
Show resolved Hide resolved

total_updated = 0
updated_by_license = {}

for license_pair, identifiers in records_to_update.items():
license_, license_version = license_pair.split(",")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're saving this as a string only to split it out again, we could probably use the license & license version pair in tuple form as the key itself, e.g. `records_to_update[license_, license_version]`. Then we can unpack it directly here too, e.g. `for (license_, license_version), identifiers in records_to_update.items()`.

license_url = get_license_info_from_license_pair(license_, license_version)[-1]
obulat marked this conversation as resolved.
Show resolved Hide resolved
if license_url is None:
logger.info(
f"No license pair {license_pair} in the license reverse path map."
)
continue
logger.info(f"{len(result)} records to update with {license_url}.")
license_url_col = {"license_url": license_url}
update_license_url_query = dedent(
logger.info(f"{len(identifiers):4} items will be updated with {license_url}")
license_url_dict = {"license_url": license_url}
update_query = dedent(
f"""
UPDATE image
SET meta_data = {Json(license_url_col)}
WHERE identifier IN ({','.join([f"'{r[0]}'" for r in result])});
SET meta_data = {Json(license_url_dict)}
WHERE identifier IN ({','.join([f"'{r}'" for r in identifiers])});
"""
)

updated_count = postgres.run(
update_license_url_query, autocommit=True, handler=RETURN_ROW_COUNT
updated_count: int = postgres.run(
update_query, autocommit=True, handler=RETURN_ROW_COUNT
)
logger.info(f"{updated_count} records updated with {license_url}.")
total_counts[license_url] = updated_count
total_count += updated_count

logger.info(f"{total_count} image records with missing license_url updated.")
for license_url, count in total_counts.items():
logger.info(f"{count} records with {license_url}.")
return total_counts
if updated_count:
updated_by_license[license_url] = updated_count
total_updated += updated_count
logger.info(f"Updated {total_updated} rows")
return json.dumps(updated_by_license)
obulat marked this conversation as resolved.
Show resolved Hide resolved


def final_report(
postgres_conn_id: str,
item_count: int,
task: AbstractOperator,
updated_by_license: str,
task: AbstractOperator = None,
):
"""Check for null in `meta_data` and send a message to Slack
with the statistics of the DAG run.

:param postgres_conn_id: Postgres connection id.
:param updated_by_license: stringified JSON with the number of records updated
for each license_url. If `update_license_url` was skipped, this will be "None".
:param task: automatically passed by Airflow, used to set the execution timeout.
"""
null_meta_data_count = get_null_counts(postgres_conn_id, task)

if updated_by_license == "None":
updated_message = "No records were updated."
else:
updated_by_license = json.loads(updated_by_license)
formatted_item_count = "".join(
[
f"{license_url}: {count} rows\n"
for license_url, count in updated_by_license.items()
]
)
updated_message = f"Update statistics:\n{formatted_item_count}"
message = f"""
Added license_url to *{item_count}* items`
`add_license_url` DAG run completed.
{updated_message}
Now, there are {null_meta_data_count} records with NULL meta_data left.
"""
send_message(
Expand Down Expand Up @@ -169,9 +196,10 @@ def final_report(
final_report = PythonOperator(
task_id=FINAL_REPORT,
python_callable=final_report,
trigger_rule=TriggerRule.ALL_DONE,
op_kwargs={
"postgres_conn_id": POSTGRES_CONN_ID,
"item_count": XCOM_PULL_TEMPLATE.format(
"updated_by_license": XCOM_PULL_TEMPLATE.format(
update_license_url.task_id, "return_value"
),
},
Expand Down