Skip to content

Commit

Permalink
Fix UniqueTogetherValidator with field sources (encode#7086)
Browse files Browse the repository at this point in the history
* Add failing tests for unique_together+source

* Fix UniqueTogetherValidator source handling

* Fix read-only+default+source handling

* Update test to use functional serializer

* Test UniqueTogetherValidator error+source
  • Loading branch information
rpkilby authored and Pierre Chiquet committed Mar 24, 2020
1 parent f96ba9c commit 9fc8963
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 12 deletions.
2 changes: 1 addition & 1 deletion rest_framework/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def _read_only_defaults(self):
default = field.get_default()
except SkipField:
continue
defaults[field.field_name] = default
defaults[field.source] = default

return defaults

Expand Down
18 changes: 12 additions & 6 deletions rest_framework/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def enforce_required_fields(self, attrs, serializer):
missing_items = {
field_name: self.missing_message
for field_name in self.fields
if field_name not in attrs
if serializer.fields[field_name].source not in attrs
}
if missing_items:
raise ValidationError(missing_items, code='required')
Expand All @@ -115,17 +115,23 @@ def filter_queryset(self, attrs, queryset, serializer):
"""
Filter the queryset to all instances matching the given attributes.
"""
# field names => field sources
sources = [
serializer.fields[field_name].source
for field_name in self.fields
]

# If this is an update, then any unprovided field should
# have it's value set based on the existing instance attribute.
if serializer.instance is not None:
for field_name in self.fields:
if field_name not in attrs:
attrs[field_name] = getattr(serializer.instance, field_name)
for source in sources:
if source not in attrs:
attrs[source] = getattr(serializer.instance, source)

# Determine the filter keyword arguments and filter the queryset.
filter_kwargs = {
field_name: attrs[field_name]
for field_name in self.fields
source: attrs[source]
for source in sources
}
return qs_filter(queryset, **filter_kwargs)

Expand Down
49 changes: 44 additions & 5 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,49 @@ class Meta:
]
}

def test_read_only_fields_with_default_and_source(self):
class ReadOnlySerializer(serializers.ModelSerializer):
name = serializers.CharField(source='race_name', default='test', read_only=True)

class Meta:
model = UniquenessTogetherModel
fields = ['name', 'position']
validators = [
UniqueTogetherValidator(
queryset=UniquenessTogetherModel.objects.all(),
fields=['name', 'position']
)
]

serializer = ReadOnlySerializer(data={'position': 1})
assert serializer.is_valid(raise_exception=True)

def test_writeable_fields_with_source(self):
class WriteableSerializer(serializers.ModelSerializer):
name = serializers.CharField(source='race_name')

class Meta:
model = UniquenessTogetherModel
fields = ['name', 'position']
validators = [
UniqueTogetherValidator(
queryset=UniquenessTogetherModel.objects.all(),
fields=['name', 'position']
)
]

serializer = WriteableSerializer(data={'name': 'test', 'position': 1})
assert serializer.is_valid(raise_exception=True)

# Validation error should use seriazlier field name, not source
serializer = WriteableSerializer(data={'position': 1})
assert not serializer.is_valid()
assert serializer.errors == {
'name': [
'This field is required.'
]
}

def test_allow_explict_override(self):
"""
Ensure validators can be explicitly removed..
Expand Down Expand Up @@ -357,13 +400,9 @@ class MockQueryset:
def filter(self, **kwargs):
self.called_with = kwargs

class MockSerializer:
def __init__(self, instance):
self.instance = instance

data = {'race_name': 'bar'}
queryset = MockQueryset()
serializer = MockSerializer(instance=self.instance)
serializer = UniquenessTogetherSerializer(instance=self.instance)
validator = UniqueTogetherValidator(queryset, fields=('race_name',
'position'))
validator.filter_queryset(attrs=data, queryset=queryset, serializer=serializer)
Expand Down

0 comments on commit 9fc8963

Please sign in to comment.