Skip to content

Commit

Permalink
Fix serializer multiple inheritance bug (encode#6980)
Browse files Browse the repository at this point in the history
* Expand declared filtering tests

- Test declared filter ordering
- Test multiple inheritance

* Fix serializer multiple inheritance bug

* Improve field order test to check for field types
  • Loading branch information
rpkilby authored and sigvef committed Dec 3, 2022
1 parent d4ac9f6 commit d8a13b7
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 12 deletions.
28 changes: 16 additions & 12 deletions rest_framework/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,18 +298,22 @@ def _get_declared_fields(cls, bases, attrs):
if isinstance(obj, Field)]
fields.sort(key=lambda x: x[1]._creation_counter)

# If this class is subclassing another Serializer, add that Serializer's
# fields. Note that we loop over the bases in *reverse*. This is necessary
# in order to maintain the correct order of fields.
for base in reversed(bases):
if hasattr(base, '_declared_fields'):
fields = [
(field_name, obj) for field_name, obj
in base._declared_fields.items()
if field_name not in attrs
] + fields

return OrderedDict(fields)
# Ensures a base class field doesn't override cls attrs, and maintains
# field precedence when inheriting multiple parents. e.g. if there is a
# class C(A, B), and A and B both define 'field', use 'field' from A.
known = set(attrs)

def visit(name):
known.add(name)
return name

base_fields = [
(visit(name), f)
for base in bases if hasattr(base, '_declared_fields')
for name, f in base._declared_fields.items() if name not in known
]

return OrderedDict(base_fields + fields)

def __new__(cls, name, bases, attrs):
attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs)
Expand Down
50 changes: 50 additions & 0 deletions tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,3 +682,53 @@ class Grandchild(Child):
assert len(Parent().get_fields()) == 2
assert len(Child().get_fields()) == 2
assert len(Grandchild().get_fields()) == 2

def test_multiple_inheritance(self):
class A(serializers.Serializer):
field = serializers.CharField()

class B(serializers.Serializer):
field = serializers.IntegerField()

class TestSerializer(A, B):
pass

fields = {
name: type(f) for name, f
in TestSerializer()._declared_fields.items()
}
assert fields == {
'field': serializers.CharField,
}

def test_field_ordering(self):
class Base(serializers.Serializer):
f1 = serializers.CharField()
f2 = serializers.CharField()

class A(Base):
f3 = serializers.IntegerField()

class B(serializers.Serializer):
f3 = serializers.CharField()
f4 = serializers.CharField()

class TestSerializer(A, B):
f2 = serializers.IntegerField()
f5 = serializers.CharField()

fields = {
name: type(f) for name, f
in TestSerializer()._declared_fields.items()
}

# `IntegerField`s should be the 'winners' in field name conflicts
# - `TestSerializer.f2` should override `Base.F2`
# - `A.f3` should override `B.f3`
assert fields == {
'f1': serializers.CharField,
'f2': serializers.IntegerField,
'f3': serializers.IntegerField,
'f4': serializers.CharField,
'f5': serializers.CharField,
}

0 comments on commit d8a13b7

Please sign in to comment.