Skip to content

Commit

Permalink
Fixed #29725 -- Removed unnecessary join in QuerySet.count() and exists() on a many to many relation.
Browse files Browse the repository at this point in the history

Co-Authored-By: Shiwei Chen <[email protected]>
  • Loading branch information
2 people authored and felixxm committed Feb 16, 2024
1 parent 0d8fbe2 commit 66e47ac
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 10 deletions.
53 changes: 49 additions & 4 deletions django/db/models/fields/related_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class Child(Model):
router,
transaction,
)
from django.db.models import Q, Window, signals
from django.db.models import Manager, Q, Window, signals
from django.db.models.functions import RowNumber
from django.db.models.lookups import GreaterThan, LessThanOrEqual
from django.db.models.query import QuerySet
Expand Down Expand Up @@ -1121,16 +1121,22 @@ def _apply_rel_filters(self, queryset):
queryset._defer_next_filter = True
return queryset._next_is_sticky().filter(**self.core_filters)

def get_prefetch_cache(self):
    """
    Return the prefetched result stored on the instance for this relation,
    or None when nothing is cached under ``self.prefetch_cache_name``.
    """
    try:
        prefetched = self.instance._prefetched_objects_cache
        cached = prefetched[self.prefetch_cache_name]
    except (AttributeError, KeyError):
        # Either the instance carries no prefetch cache at all, or the cache
        # has no entry for this relation.
        cached = None
    return cached

def _remove_prefetched_objects(self):
    """
    Drop this relation's entry from the instance's prefetch cache, if one
    exists. A missing cache or missing entry is silently ignored.
    """
    try:
        prefetched = self.instance._prefetched_objects_cache
        prefetched.pop(self.prefetch_cache_name)
    except (AttributeError, KeyError):
        # Nothing to clear from cache.
        return

def get_queryset(self):
    """
    Return the prefetched objects for this relation when they are cached
    on the instance; otherwise return the parent manager's queryset with
    the filters from ``_apply_rel_filters()`` applied.
    """
    # get_prefetch_cache() already copes with a missing cache attribute or
    # a missing key, so no separate try/except lookup is needed here.
    if (cache := self.get_prefetch_cache()) is not None:
        return cache
    queryset = super().get_queryset()
    return self._apply_rel_filters(queryset)

Expand Down Expand Up @@ -1195,6 +1201,45 @@ def get_prefetch_querysets(self, instances, querysets=None):
False,
)

@property
def constrained_target(self):
    """
    Return a queryset on the through model limited to rows pointing at
    ``self.instance``, or None when the through table alone cannot be
    trusted to answer the query.
    """
    target_field = self.target_field
    # If the through relation's target field's foreign integrity is
    # enforced, the query can be performed solely against the through
    # table as the INNER JOIN'ing against target table is unnecessary.
    if not target_field.db_constraint:
        return None
    alias = router.db_for_read(self.through, instance=self.instance)
    if not connections[alias].features.supports_foreign_keys:
        return None
    through_manager = self.through._base_manager.db_manager(
        alias, hints={"instance": self.instance}
    )
    lookups = {self.source_field_name: self.instance.pk}
    if target_field.null:
        # Nullable target rows must be excluded as well as they would have
        # been filtered out from an INNER JOIN.
        lookups[f"{self.target_field_name}__isnull"] = False
    return through_manager.filter(**lookups)

def exists(self):
    """
    Return whether any related object exists, answering from the through
    table alone when ``constrained_target`` provides a queryset.
    """
    # The shortcut is only taken for the stock Manager (presumably because a
    # custom manager may restrict the rows -- confirm against the enclosing
    # factory) and only when nothing has been prefetched for this relation.
    if superclass is Manager and self.get_prefetch_cache() is None:
        through_rows = self.constrained_target
        if through_rows is not None:
            return through_rows.exists()
    return super().exists()

def count(self):
    """
    Return the number of related objects, answering from the through
    table alone when ``constrained_target`` provides a queryset.
    """
    # The shortcut is only taken for the stock Manager (presumably because a
    # custom manager may restrict the rows -- confirm against the enclosing
    # factory) and only when nothing has been prefetched for this relation.
    if superclass is Manager and self.get_prefetch_cache() is None:
        through_rows = self.constrained_target
        if through_rows is not None:
            return through_rows.count()
    return super().count()

def add(self, *objs, through_defaults=None):
self._remove_prefetched_objects()
db = router.db_for_write(self.through, instance=self.instance)
Expand Down
12 changes: 12 additions & 0 deletions tests/many_to_many/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,15 @@ class InheritedArticleA(AbstractArticle):

class InheritedArticleB(AbstractArticle):
pass


class NullableTargetArticle(models.Model):
    """
    Article related to Publication through NullablePublicationThrough, an
    explicit through model.
    """

    headline = models.CharField(max_length=100)
    publications = models.ManyToManyField(
        Publication,
        through="NullablePublicationThrough",
    )


class NullablePublicationThrough(models.Model):
    """
    Through model linking NullableTargetArticle to Publication; the
    publication side is nullable.
    """

    article = models.ForeignKey(
        NullableTargetArticle,
        on_delete=models.CASCADE,
    )
    publication = models.ForeignKey(
        Publication,
        on_delete=models.CASCADE,
        null=True,
    )
96 changes: 90 additions & 6 deletions tests/many_to_many/tests.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from unittest import mock

from django.db import transaction
from django.db import connection, transaction
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.utils.deprecation import RemovedInDjango60Warning

from .models import Article, InheritedArticleA, InheritedArticleB, Publication, User
from .models import (
Article,
InheritedArticleA,
InheritedArticleB,
NullablePublicationThrough,
NullableTargetArticle,
Publication,
User,
)


class ManyToManyTests(TestCase):
Expand Down Expand Up @@ -558,10 +566,16 @@ def test_inherited_models_selects(self):
def test_custom_default_manager_exists_count(self):
    # count() and exists() on the related manager must agree with the same
    # call on .all(). Arguments are evaluated left to right, so the first
    # captured query is the related-manager call; it is expected to keep
    # its JOIN here.
    a5 = Article.objects.create(headline="deleted")
    a5.publications.add(self.p2)
    with self.assertNumQueries(2) as ctx:
        self.assertEqual(
            self.p2.article_set.count(), self.p2.article_set.all().count()
        )
    self.assertIn("JOIN", ctx.captured_queries[0]["sql"])
    with self.assertNumQueries(2) as ctx:
        self.assertEqual(
            self.p3.article_set.exists(), self.p3.article_set.all().exists()
        )
    self.assertIn("JOIN", ctx.captured_queries[0]["sql"])

def test_get_prefetch_queryset_warning(self):
articles = Article.objects.all()
Expand All @@ -582,3 +596,73 @@ def test_get_prefetch_querysets_invalid_querysets_length(self):
instances=articles,
querysets=[Publication.objects.all(), Publication.objects.all()],
)


class ManyToManyQueryTests(TestCase):
    """
    count() and exists() on a many-to-many related manager are optimized to
    query the through table only, without joining against the related table.
    The optimization applies to the case where there are no filters.
    """

    @classmethod
    def setUpTestData(cls):
        # One article without publications, and one whose only through row
        # has a NULL publication.
        cls.article = Article.objects.create(
            headline="Django lets you build Web apps easily"
        )
        nullable_article = NullableTargetArticle.objects.create(
            headline="The python is good"
        )
        cls.nullable_target_article = nullable_article
        NullablePublicationThrough.objects.create(
            article=nullable_article, publication=None
        )

    @skipUnlessDBFeature("supports_foreign_keys")
    def test_count_join_optimization(self):
        with self.assertNumQueries(1) as captured:
            self.article.publications.count()
        first_sql = captured.captured_queries[0]["sql"]
        self.assertNotIn("JOIN", first_sql)

        with self.assertNumQueries(1) as captured:
            self.article.publications.count()
        second_sql = captured.captured_queries[0]["sql"]
        self.assertNotIn("JOIN", second_sql)
        # A through row with a NULL target must not be counted.
        self.assertEqual(self.nullable_target_article.publications.count(), 0)

    def test_count_join_optimization_disabled(self):
        no_fk_support = mock.patch.object(
            connection.features, "supports_foreign_keys", False
        )
        with no_fk_support, self.assertNumQueries(1) as captured:
            self.article.publications.count()

        sql = captured.captured_queries[0]["sql"]
        self.assertIn("JOIN", sql)

    @skipUnlessDBFeature("supports_foreign_keys")
    def test_exists_join_optimization(self):
        with self.assertNumQueries(1) as captured:
            self.article.publications.exists()
        first_sql = captured.captured_queries[0]["sql"]
        self.assertNotIn("JOIN", first_sql)

        self.article.publications.prefetch_related()
        with self.assertNumQueries(1) as captured:
            self.article.publications.exists()
        second_sql = captured.captured_queries[0]["sql"]
        self.assertNotIn("JOIN", second_sql)
        # A through row with a NULL target must not make exists() true.
        self.assertIs(self.nullable_target_article.publications.exists(), False)

    def test_exists_join_optimization_disabled(self):
        no_fk_support = mock.patch.object(
            connection.features, "supports_foreign_keys", False
        )
        with no_fk_support, self.assertNumQueries(1) as captured:
            self.article.publications.exists()

        sql = captured.captured_queries[0]["sql"]
        self.assertIn("JOIN", sql)

    def test_prefetch_related_no_queries_optimization_disabled(self):
        # With the relation prefetched, neither call may hit the database.
        prefetched_articles = Article.objects.prefetch_related("publications")
        article = prefetched_articles.get()
        with self.assertNumQueries(0):
            article.publications.count()
        with self.assertNumQueries(0):
            article.publications.exists()

0 comments on commit 66e47ac

Please sign in to comment.