From d8a13b7d55c50bd2d34bce031be7636effb99c4c Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Thu, 12 Dec 2019 06:03:55 -0800 Subject: [PATCH] Fix serializer multiple inheritance bug (#6980) * Expand declared filtering tests - Test declared filter ordering - Test multiple inheritance * Fix serializer multiple inheritance bug * Improve field order test to check for field types --- rest_framework/serializers.py | 28 +++++++++++--------- tests/test_serializer.py | 50 +++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 12 deletions(-) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 63fab3dc363..18f4d0df686 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -298,18 +298,22 @@ def _get_declared_fields(cls, bases, attrs): if isinstance(obj, Field)] fields.sort(key=lambda x: x[1]._creation_counter) - # If this class is subclassing another Serializer, add that Serializer's - # fields. Note that we loop over the bases in *reverse*. This is necessary - # in order to maintain the correct order of fields. - for base in reversed(bases): - if hasattr(base, '_declared_fields'): - fields = [ - (field_name, obj) for field_name, obj - in base._declared_fields.items() - if field_name not in attrs - ] + fields - - return OrderedDict(fields) + # Ensures a base class field doesn't override cls attrs, and maintains + # field precedence when inheriting multiple parents. e.g. if there is a + # class C(A, B), and A and B both define 'field', use 'field' from A. + known = set(attrs) + + def visit(name): + known.add(name) + return name + + base_fields = [ + (visit(name), f) + for base in bases if hasattr(base, '_declared_fields') + for name, f in base._declared_fields.items() if name not in known + ] + + return OrderedDict(base_fields + fields) def __new__(cls, name, bases, attrs): attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs) diff --git a/tests/test_serializer.py b/tests/test_serializer.py index fab0472b941..a58c46b2d99 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -682,3 +682,53 @@ class Grandchild(Child): assert len(Parent().get_fields()) == 2 assert len(Child().get_fields()) == 2 assert len(Grandchild().get_fields()) == 2 + + def test_multiple_inheritance(self): + class A(serializers.Serializer): + field = serializers.CharField() + + class B(serializers.Serializer): + field = serializers.IntegerField() + + class TestSerializer(A, B): + pass + + fields = { + name: type(f) for name, f + in TestSerializer()._declared_fields.items() + } + assert fields == { + 'field': serializers.CharField, + } + + def test_field_ordering(self): + class Base(serializers.Serializer): + f1 = serializers.CharField() + f2 = serializers.CharField() + + class A(Base): + f3 = serializers.IntegerField() + + class B(serializers.Serializer): + f3 = serializers.CharField() + f4 = serializers.CharField() + + class TestSerializer(A, B): + f2 = serializers.IntegerField() + f5 = serializers.CharField() + + fields = { + name: type(f) for name, f + in TestSerializer()._declared_fields.items() + } + + # `IntegerField`s should be the 'winners' in field name conflicts + # - `TestSerializer.f2` should override `Base.F2` + # - `A.f3` should override `B.f3` + assert fields == { + 'f1': serializers.CharField, + 'f2': serializers.IntegerField, + 'f3': serializers.IntegerField, + 'f4': serializers.CharField, + 'f5': serializers.CharField, + }