Skip to content

Commit

Permalink
Account for sql.SQL and sql.Composed Objects (#177)
Browse files Browse the repository at this point in the history
  • Loading branch information
max-muoto authored Nov 11, 2024
1 parent f3525da commit f43c92e
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 5 deletions.
25 changes: 21 additions & 4 deletions pgtrigger/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,20 @@

if utils.psycopg_maj_version == 2:
import psycopg2.extensions
import psycopg2.sql as psycopg_sql
elif utils.psycopg_maj_version == 3:
import psycopg.pq
import psycopg.sql as psycopg_sql
else:
raise AssertionError

if TYPE_CHECKING:
from django.db.backends.utils import CursorWrapper
from typing_extensions import TypeAlias

from pgtrigger import Timing

_Query: "TypeAlias" = "str | bytes | psycopg_sql.SQL | psycopg_sql.Composed"

# All triggers currently being ignored
_ignore = threading.local()
Expand All @@ -32,12 +38,23 @@
_schema = threading.local()


def _is_concurrent_statement(sql: str | bytes) -> bool:
def _query_to_str(query: _Query, cursor: CursorWrapper) -> str:
if isinstance(query, str):
return query
elif isinstance(query, bytes):
return query.decode()
elif isinstance(query, (psycopg_sql.SQL, psycopg_sql.Composed)):
return query.as_string(cursor.connection)
else: # pragma: no cover
raise TypeError(f"Unsupported query type: {type(query)}")


def _is_concurrent_statement(sql: _Query, cursor: CursorWrapper) -> bool:
"""
True if the sql statement is concurrent and cannot be ran in a transaction
"""
sql = _query_to_str(sql, cursor)
sql = sql.strip().lower() if sql else ""
sql = sql.decode() if isinstance(sql, bytes) else sql
return sql.startswith("create") and "concurrently" in sql


Expand Down Expand Up @@ -72,7 +89,7 @@ def _can_inject_variable(cursor, sql):
"""
return (
not getattr(cursor, "name", None)
and not _is_concurrent_statement(sql)
and not _is_concurrent_statement(sql, cursor)
and not _is_transaction_errored(cursor)
)

Expand All @@ -92,7 +109,7 @@ def _inject_pgtrigger_ignore(execute, sql, params, many, context):
"""
if _can_inject_variable(context["cursor"], sql):
serialized_ignore = "{" + ",".join(_ignore.value) + "}"
sql = sql.decode() if isinstance(sql, bytes) else sql
sql = _query_to_str(sql, context["cursor"])
sql = f"SELECT set_config('pgtrigger.ignore', %s, true); {sql}"
params = [serialized_ignore, *(params or ())]

Expand Down
28 changes: 27 additions & 1 deletion pgtrigger/tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@
import pgtrigger.utils
from pgtrigger.tests import models, utils

if pgtrigger.utils.psycopg_maj_version == 3:
from psycopg.sql import SQL, Identifier
elif pgtrigger.utils.psycopg_maj_version == 2:
from psycopg2.sql import SQL, Identifier
else:
raise AssertionError


@pytest.mark.django_db
def test_schema():
Expand Down Expand Up @@ -255,7 +262,26 @@ def test_inject_trigger_ignore(settings, mocker, sql, params):
with connection.cursor() as cursor:
cursor.execute(sql, params)
query = connection.queries[-1]

assert query["sql"].startswith(expected_sql_1) or query["sql"].startswith(
expected_sql_2
)


@pytest.mark.django_db
def test_test_trigger_ignore_psycopg_sql_objects():
"""Verify that native psycopg SQL objects are handled correctly when ignoring triggers."""
# Test with a `sql.SQL` object
with pgtrigger.ignore("tests.TestTrigger:protect_misc_insert"), connection.cursor() as cursor:
cursor.execute(
SQL(
"INSERT INTO tests_testtrigger (field, int_field, dt_field)"
"VALUES ('misc_insert', 1, now())",
)
)
# Test with a `sql.Composed` object (built through formatting)
with pgtrigger.ignore("tests.TestTrigger:protect_misc_insert"), connection.cursor() as cursor:
cursor.execute(
SQL(
"INSERT INTO {table} (field, int_field, dt_field) VALUES ('misc_insert', 1, now())"
).format(table=Identifier("tests_testtrigger"))
)

0 comments on commit f43c92e

Please sign in to comment.