From 3b76653111dc1db43c366717f3285502cb6ee31e Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Mon, 2 Dec 2024 11:54:50 +0000 Subject: [PATCH] Fix check_constraint_condition fixer for GIS models module Fixes #513. --- CHANGELOG.rst | 4 +++ .../fixers/check_constraint_condition.py | 11 +++++-- .../fixers/test_check_constraint_condition.py | 30 +++++++++++++++++++ 3 files changed, 43 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9078e090..e69ed3c4 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,10 @@ Changelog ========= +* Fix ``check_constraint_condition`` fixer to work when ``django.contrib.gis.db.models`` is used to import ``CheckConstraint``. + + `Issue #513 `__. + 1.22.1 (2024-10-11) ------------------- diff --git a/src/django_upgrade/fixers/check_constraint_condition.py b/src/django_upgrade/fixers/check_constraint_condition.py index 8dc9acc1..7135cbaf 100644 --- a/src/django_upgrade/fixers/check_constraint_condition.py +++ b/src/django_upgrade/fixers/check_constraint_condition.py @@ -34,14 +34,21 @@ def visit_Call( ( isinstance(node.func, ast.Name) and node.func.id == "CheckConstraint" - and "CheckConstraint" in state.from_imports["django.db.models"] + and ( + "CheckConstraint" in state.from_imports["django.db.models"] + or "CheckConstraint" + in state.from_imports["django.contrib.gis.db.models"] + ) ) or ( isinstance(node.func, ast.Attribute) and node.func.attr == "CheckConstraint" and isinstance(node.func.value, ast.Name) and node.func.value.id == "models" - and "models" in state.from_imports["django.db"] + and ( + "models" in state.from_imports["django.db"] + or "models" in state.from_imports["django.contrib.gis.db"] + ) ) ) and (kwarg_names := {k.arg for k in node.keywords}) diff --git a/tests/fixers/test_check_constraint_condition.py b/tests/fixers/test_check_constraint_condition.py index 204805ca..000a7e5b 100644 --- a/tests/fixers/test_check_constraint_condition.py +++ b/tests/fixers/test_check_constraint_condition.py @@ -86,6 +86,21 @@ def test_success_name(): ) +def test_success_name_gis(): + check_transformed( + """\ + from django.contrib.gis.db.models import CheckConstraint + + CheckConstraint(check=Q(id=1)) + """, + """\ + from django.contrib.gis.db.models import CheckConstraint + + CheckConstraint(condition=Q(id=1)) + """, + ) + + def test_success_attr(): check_transformed( """\ @@ -101,6 +116,21 @@ def test_success_attr(): ) +def test_success_attr_gis(): + check_transformed( + """\ + from django.contrib.gis.db import models + + models.CheckConstraint(check=models.Q(id=1)) + """, + """\ + from django.contrib.gis.db import models + + models.CheckConstraint(condition=models.Q(id=1)) + """, + ) + + def test_success_other_args(): check_transformed( """\