Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass a context in depth of serialize calls (API changed) #41

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ Complex Example
x = serpy.Field(call=True)
plus = serpy.MethodField()

def get_plus(self, obj):
def get_plus(self, obj, context):
return obj.y + obj.z

f = Foo()
Expand Down Expand Up @@ -170,6 +170,32 @@ Inheritance Example
ABSerializer(f).data
# {'a': 1, 'b': 2}

Context Example
---------------

Context is just an object passed to each getter. It can be of any type.

.. code-block:: python

import serpy

class ContextedField(serpy.Field):

def to_value(self, value, context):
return context.do_something(value)


class ASerializer(serpy.Serializer):
a = serpy.ContextedField()
b = serpy.MethodVield()

def get_b(self, value, context):
user = context.request.user
return value if user.is_authenticated else None

f = Foo()
ASerializer(f, context=context).data

License
=======
serpy is free software distributed under the terms of the MIT license. See the
Expand Down
4 changes: 2 additions & 2 deletions docs/custom-fields.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ adds 5 to every value it serializes, do:
.. code-block:: python

class Add5Field(serpy.Field):
def to_value(self, value):
def to_value(self, value, context):
return value + 5

Then to use it:
Expand All @@ -34,7 +34,7 @@ every serialized value has a ``'.'`` in it:
.. code-block:: python

class ValidateDotField(serpy.Field):
def to_value(self, value):
def to_value(self, value, context):
if '.' not in value:
raise ValidationError('no dot!')
return value
Expand Down
27 changes: 20 additions & 7 deletions serpy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, attr=None, call=False, label=None, required=True):
self.label = label
self.required = required

def to_value(self, value):
def to_value(self, value, context):
"""Transform the serialized value.

Override this method to clean and validate values serialized by this
Expand All @@ -41,6 +41,7 @@ def to_value(self, value):
return int(value)

:param value: The value fetched from the object being serialized.
:param context: A context received from caller.
"""
return value
to_value._serpy_base_implementation = True
Expand Down Expand Up @@ -78,22 +79,34 @@ def as_getter(self, serializer_field_name, serializer_cls):

class StrField(Field):
"""A :class:`Field` that converts the value to a string."""
to_value = staticmethod(six.text_type)

@staticmethod
def to_value(value, context):
return six.text_type(value)


class IntField(Field):
"""A :class:`Field` that converts the value to an integer."""
to_value = staticmethod(int)

@staticmethod
def to_value(value, context):
return int(value)


class FloatField(Field):
"""A :class:`Field` that converts the value to a float."""
to_value = staticmethod(float)

@staticmethod
def to_value(value, context):
return float(value)


class BoolField(Field):
"""A :class:`Field` that converts the value to a boolean."""
to_value = staticmethod(bool)

@staticmethod
def to_value(value, context):
return bool(value)


class MethodField(Field):
Expand All @@ -106,10 +119,10 @@ class FooSerializer(Serializer):
plus = MethodField()
minus = MethodField('do_minus')

def get_plus(self, foo_obj):
def get_plus(self, foo_obj, context):
return foo_obj.bar + foo_obj.baz

def do_minus(self, foo_obj):
def do_minus(self, foo_obj, context):
return foo_obj.bar - foo_obj.baz

foo = Foo(bar=5, baz=10)
Expand Down
21 changes: 12 additions & 9 deletions serpy/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ class FooSerializer(Serializer):
:param instance: The object or objects to serialize.
:param bool many: If ``instance`` is a collection of objects, set ``many``
to ``True`` to serialize to a list.
:param context: Currently unused parameter for compatability with Django
REST Framework serializers.
:param context: Parameter for compatibility with Django
REST Framework serializers. Is passed as second argument to ``to_value``
and getters methods.
"""
#: The default getter used if :meth:`Field.as_getter` returns None.
default_getter = operator.attrgetter
Expand All @@ -97,31 +98,33 @@ def __init__(self, instance=None, many=False, data=None, context=None,

super(Serializer, self).__init__(**kwargs)
self.instance = instance
self.context = context
self.many = many
self._data = None

def _serialize(self, instance, fields):
def _serialize(self, instance, fields, context):
v = {}
for name, getter, to_value, call, required, pass_self in fields:
if pass_self:
result = getter(self, instance)
# MethodField
result = getter(self, instance, context)
else:
result = getter(instance)
if required or result is not None:
if call:
result = result()
if to_value:
result = to_value(result)
result = to_value(result, context)
v[name] = result

return v

def to_value(self, instance):
def to_value(self, instance, context):
fields = self._compiled_fields
if self.many:
serialize = self._serialize
return [serialize(o, fields) for o in instance]
return self._serialize(instance, fields)
return [serialize(o, fields, context) for o in instance]
return self._serialize(instance, fields, context)

@property
def data(self):
Expand All @@ -131,7 +134,7 @@ def data(self):
"""
# Cache the data for next time .data is called.
if self._data is None:
self._data = self.to_value(self.instance)
self._data = self.to_value(self.instance, self.context)
return self._data


Expand Down
30 changes: 15 additions & 15 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@
class TestFields(unittest.TestCase):

def test_to_value_noop(self):
self.assertEqual(Field().to_value(5), 5)
self.assertEqual(Field().to_value('a'), 'a')
self.assertEqual(Field().to_value(None), None)
self.assertEqual(Field().to_value(5, None), 5)
self.assertEqual(Field().to_value('a', None), 'a')
self.assertEqual(Field().to_value(None, None), None)

def test_as_getter_none(self):
self.assertEqual(Field().as_getter(None, None), None)

def test_is_to_value_overridden(self):
class TransField(Field):
def to_value(self, value):
def to_value(self, value, context):
return value

field = Field()
Expand All @@ -28,26 +28,26 @@ def to_value(self, value):

def test_str_field(self):
field = StrField()
self.assertEqual(field.to_value('a'), 'a')
self.assertEqual(field.to_value(5), '5')
self.assertEqual(field.to_value('a', None), 'a')
self.assertEqual(field.to_value(5, None), '5')

def test_bool_field(self):
field = BoolField()
self.assertTrue(field.to_value(True))
self.assertFalse(field.to_value(False))
self.assertTrue(field.to_value(1))
self.assertFalse(field.to_value(0))
self.assertTrue(field.to_value(True, None))
self.assertFalse(field.to_value(False, None))
self.assertTrue(field.to_value(1, None))
self.assertFalse(field.to_value(0, None))

def test_int_field(self):
field = IntField()
self.assertEqual(field.to_value(5), 5)
self.assertEqual(field.to_value(5.4), 5)
self.assertEqual(field.to_value('5'), 5)
self.assertEqual(field.to_value(5, None), 5)
self.assertEqual(field.to_value(5.4, None), 5)
self.assertEqual(field.to_value('5', None), 5)

def test_float_field(self):
field = FloatField()
self.assertEqual(field.to_value(5.2), 5.2)
self.assertEqual(field.to_value('5.5'), 5.5)
self.assertEqual(field.to_value(5.2, None), 5.2)
self.assertEqual(field.to_value('5.5', None), 5.5)

def test_method_field(self):
class FakeSerializer(object):
Expand Down
16 changes: 9 additions & 7 deletions tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,20 @@ class BSerializer(Serializer):
self.assertEqual(BSerializer(b).data['b']['a'], 3)

def test_serializer_method_field(self):
context = {5: 5, 9: 9}

class ASerializer(Serializer):
a = MethodField()
b = MethodField('add_9')

def get_a(self, obj):
return obj.a + 5
def get_a(self, obj, context):
return obj.a + context[5]

def add_9(self, obj):
return obj.a + 9
def add_9(self, obj, context):
return obj.a + context[9]

a = Obj(a=2)
data = ASerializer(a).data
data = ASerializer(a, context=context).data
self.assertEqual(data['a'], 7)
self.assertEqual(data['b'], 11)

Expand Down Expand Up @@ -142,7 +144,7 @@ class ASerializer(Serializer):

def test_custom_field(self):
class Add5Field(Field):
def to_value(self, value):
def to_value(self, value, context):
return value + 5

class ASerializer(Serializer):
Expand Down Expand Up @@ -178,7 +180,7 @@ class ASerializer(Serializer):
context = StrField(label="@context")
content = MethodField(label="@content")

def get_content(self, obj):
def get_content(self, obj, context):
return obj.content

o = Obj(context="http://foo/bar/baz/", content="http://baz/bar/foo/")
Expand Down