Skip to content

Commit

Permalink
django-filter: added type extraction fallback for ChoiceFields #690
Browse files Browse the repository at this point in the history
  • Loading branch information
tfranzel committed Mar 28, 2022
1 parent bdb5a47 commit 6f5bd16
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 15 deletions.
56 changes: 41 additions & 15 deletions drf_spectacular/contrib/django_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name,
}
filter_method = self._get_filter_method(filterset_class, filter_field)
filter_method_hint = self._get_filter_method_hint(filter_method)
filter_choices = self._get_explicit_filter_choices(filter_field)

if has_override(filter_field, 'field') or has_override(filter_method, 'field'):
annotation = (
Expand Down Expand Up @@ -116,17 +117,24 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name,
schema = build_basic_type(OpenApiTypes.NUMBER) # TODO may be improved
else:
schema = build_basic_type(OpenApiTypes.NUMBER)
elif isinstance(filter_field, filters.ChoiceFilter):
try:
schema = self._get_schema_from_model_field(auto_schema, filter_field, model)
except Exception:
if filter_choices and is_basic_type(type(filter_choices[0])):
# fallback to type guessing from first choice element
schema = build_basic_type(type(filter_choices[0]))
else:
warn(
f'Unable to guess choice types from values, filter method\'s type hint '
f'or find "{field_name}" in model. Defaulting to string.'
)
schema = build_basic_type(OpenApiTypes.STR)
else:
# the last resort is to look up the type via the model or queryset field
# and emit a warning if we were unsuccessful.
try:
# the last resort is to lookup the type via the model or queryset field.
# first search for the field in the model as this has the least amount of
# potential side effects. Only after that fails, attempt to call
# get_queryset() to check for potential query annotations.
model_field = self._get_model_field(filter_field, model)
if not isinstance(model_field, models.Field):
qs = auto_schema.view.get_queryset()
model_field = qs.query.annotations[filter_field.field_name].field
schema = auto_schema._map_model_field(model_field, direction=None)
schema = self._get_schema_from_model_field(auto_schema, filter_field, model)
except Exception as exc: # pragma: no cover
warn(
f'Exception raised while trying resolve model field for django-filter '
Expand All @@ -139,12 +147,9 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name,
schema.pop('readOnly', None)
# enrich schema with additional info from filter_field
enum = schema.pop('enum', None)
if 'choices' in filter_field.extra:
if callable(filter_field.extra['choices']):
# choices function may utilize the DB, so refrain from actually calling it.
enum = None
else:
enum = [c for c, _ in filter_field.extra['choices']]
# explicit filter choices may disable enum retrieved from model
if filter_choices is not None:
enum = filter_choices
if enum:
schema['enum'] = sorted(enum, key=str)

Expand Down Expand Up @@ -207,8 +212,29 @@ def _get_filter_method_hint(self, filter_method):
except: # noqa: E722
return _NoHint

def _get_explicit_filter_choices(self, filter_field):
if 'choices' not in filter_field.extra:
return None
elif callable(filter_field.extra['choices']):
# choices function may utilize the DB, so refrain from actually calling it.
return []
else:
return [c for c, _ in filter_field.extra['choices']]

def _get_model_field(self, filter_field, model):
if not filter_field.field_name:
return None
path = filter_field.field_name.split('__')
return follow_field_source(model, path, emit_warnings=False)

def _get_schema_from_model_field(self, auto_schema, filter_field, model):
# Has potential to throw exceptions. Needs to be wrapped in try/except!
#
# first search for the field in the model as this has the least amount of
# potential side effects. Only after that fails, attempt to call
# get_queryset() to check for potential query annotations.
model_field = self._get_model_field(filter_field, model)
if not isinstance(model_field, models.Field):
qs = auto_schema.view.get_queryset()
model_field = qs.query.annotations[filter_field.field_name].field
return auto_schema._map_model_field(model_field, direction=None)
7 changes: 7 additions & 0 deletions tests/contrib/test_django_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,15 @@ class ProductFilter(FilterSet):

def get_choices(*args, **kwargs):
return (('A', 'aaa'),)

cat_callable = ChoiceFilter(field_name="category", choices=get_choices)

# will guess type from choices as a last resort
untyped_choice_field_method_with_explicit_choices = ChoiceFilter(
method="filter_method_untyped",
choices=[(1, 'one')],
)

class Meta:
model = Product
fields = [
Expand Down
6 changes: 6 additions & 0 deletions tests/contrib/test_django_filters.yml
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,12 @@ paths:
schema:
type: number
format: float
- in: query
name: untyped_choice_field_method_with_explicit_choices
schema:
type: integer
enum:
- 1
tags:
- products
security:
Expand Down

0 comments on commit 6f5bd16

Please sign in to comment.