Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
max-muoto committed Sep 8, 2024
1 parent cddc76a commit e9f0b50
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
1 change: 1 addition & 0 deletions pgtrigger/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def _inject_pgtrigger_ignore(execute, sql, params, many, context):
"""
if _can_inject_variable(context["cursor"], sql):
serialized_ignore = "{" + ",".join(_ignore.value) + "}"
sql = sql.decode() if isinstance(sql, bytes) else sql
sql = f"SELECT set_config('pgtrigger.ignore', %s, true); {sql}"
params = [serialized_ignore, *(params or ())]

Expand Down
34 changes: 34 additions & 0 deletions pgtrigger/tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django.db import IntegrityError, connection, transaction

import pgtrigger
import pgtrigger.utils
from pgtrigger.tests import models, utils


Expand Down Expand Up @@ -225,3 +226,36 @@ def test_custom_db_table_ignore():
with pgtrigger.ignore("tests.CustomTableName:protect_delete"):
deletion_protected_model.delete()
assert not models.CustomTableName.objects.exists()


@pytest.mark.skipif(
pgtrigger.utils.psycopg_maj_version == 3, reason="Psycopg2 preserves entire query"
)
@pytest.mark.django_db
@pytest.mark.parametrize(
"sql, params",
[
("select count(*) from auth_user where id = %s", (1,)),
("select count(*) from auth_user", ()),
(b"select count(*) from auth_user where id = %s", (1,)),
(b"select count(*) from auth_user", ()),
],
)
def test_inject_trigger_ignore(settings, mocker, sql, params):
settings.DEBUG = True
expected_sql_base = "SELECT set_config('pgtrigger.ignore', '{ignored_triggers}', true)"
# Order isn't deterministic, so we need to check for either order.
expected_sql_1 = expected_sql_base.format(
ignored_triggers=r"{tests_testtrigger:pgtrigger_protect_delete_b7483,pgtrigger_protect_delete_b7483}"
)
expected_sql_2 = expected_sql_base.format(
ignored_triggers=r"{pgtrigger_protect_delete_b7483,tests_testtrigger:pgtrigger_protect_delete_b7483}"
)
with pgtrigger.ignore("tests.TestTriggerProxy:protect_delete"):
with connection.cursor() as cursor:
cursor.execute(sql, params)
query = connection.queries[-1]

assert query["sql"].startswith(expected_sql_1) or query["sql"].startswith(
expected_sql_2
)

0 comments on commit e9f0b50

Please sign in to comment.