Skip to content

Commit

Permalink
Account for bytes SQL (#171)
Browse files Browse the repository at this point in the history
* Account for bytes sql

* revert

* Simplify

* Fix

* Add test
  • Loading branch information
max-muoto authored Sep 8, 2024
1 parent 487ba68 commit 6eaa768
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
4 changes: 3 additions & 1 deletion pgtrigger/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@
_schema = threading.local()


def _is_concurrent_statement(sql):
def _is_concurrent_statement(sql: str | bytes) -> bool:
"""
True if the sql statement is concurrent and cannot be ran in a transaction
"""
sql = sql.strip().lower() if sql else ""
sql = sql.decode() if isinstance(sql, bytes) else sql
return sql.startswith("create") and "concurrently" in sql


Expand Down Expand Up @@ -91,6 +92,7 @@ def _inject_pgtrigger_ignore(execute, sql, params, many, context):
"""
if _can_inject_variable(context["cursor"], sql):
serialized_ignore = "{" + ",".join(_ignore.value) + "}"
sql = sql.decode() if isinstance(sql, bytes) else sql
sql = f"SELECT set_config('pgtrigger.ignore', %s, true); {sql}"
params = [serialized_ignore, *(params or ())]

Expand Down
34 changes: 34 additions & 0 deletions pgtrigger/tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django.db import IntegrityError, connection, transaction

import pgtrigger
import pgtrigger.utils
from pgtrigger.tests import models, utils


Expand Down Expand Up @@ -225,3 +226,36 @@ def test_custom_db_table_ignore():
with pgtrigger.ignore("tests.CustomTableName:protect_delete"):
deletion_protected_model.delete()
assert not models.CustomTableName.objects.exists()


@pytest.mark.skipif(
pgtrigger.utils.psycopg_maj_version == 3, reason="Psycopg2 preserves entire query"
)
@pytest.mark.django_db
@pytest.mark.parametrize(
"sql, params",
[
("select count(*) from auth_user where id = %s", (1,)),
("select count(*) from auth_user", ()),
(b"select count(*) from auth_user where id = %s", (1,)),
(b"select count(*) from auth_user", ()),
],
)
def test_inject_trigger_ignore(settings, mocker, sql, params):
settings.DEBUG = True
expected_sql_base = "SELECT set_config('pgtrigger.ignore', '{ignored_triggers}', true)"
# Order isn't deterministic, so we need to check for either order.
expected_sql_1 = expected_sql_base.format(
ignored_triggers=r"{tests_testtrigger:pgtrigger_protect_delete_b7483,pgtrigger_protect_delete_b7483}"
)
expected_sql_2 = expected_sql_base.format(
ignored_triggers=r"{pgtrigger_protect_delete_b7483,tests_testtrigger:pgtrigger_protect_delete_b7483}"
)
with pgtrigger.ignore("tests.TestTriggerProxy:protect_delete"):
with connection.cursor() as cursor:
cursor.execute(sql, params)
query = connection.queries[-1]

assert query["sql"].startswith(expected_sql_1) or query["sql"].startswith(
expected_sql_2
)

0 comments on commit 6eaa768

Please sign in to comment.