Skip to content

Commit

Permalink
Add DurationField
Browse files Browse the repository at this point in the history
  • Loading branch information
ticosax committed Jun 1, 2015
1 parent a0f66ff commit ea2543a
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 4 deletions.
10 changes: 10 additions & 0 deletions docs/api-guide/fields.md
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,16 @@ Corresponds to `django.db.models.fields.TimeField`

Format strings may either be [Python strftime formats][strftime] which explicitly specify the format, or the special string `'iso-8601'`, which indicates that [ISO 8601][iso8601] style times should be used. (eg `'12:34:56.000000'`)

## DurationField

A Duration representation.
Corresponds to `django.db.models.fields.Duration` for Django>=1.8
otherwise to `django.db.models.fields.BigIntegerField`.
The representation is a string following this format `'[DD] [HH:[MM:]]ss[.uuuuuu]'`


**Signature:** `DurationField()`

---

# Choice selection fields
Expand Down
78 changes: 78 additions & 0 deletions rest_framework/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

# flake8: noqa
from __future__ import unicode_literals
import datetime
from datetime import timedelta
import re
from django.core.exceptions import ImproperlyConfigured
from django.conf import settings
from django.utils.encoding import force_text
Expand Down Expand Up @@ -258,3 +261,78 @@ def apply_markdown(text):
SHORT_SEPARATORS = (b',', b':')
LONG_SEPARATORS = (b', ', b': ')
INDENT_SEPARATORS = (b',', b': ')


if django.VERSION >= (1, 8):
from django.utils.dateparse import parse_duration
from django.utils.duration import duration_string
from django.db.models import DurationField
else:
from django.db.models import BigIntegerField

class DurationField(BigIntegerField):
def get_db_prep_value(self, value, connection, prepared=False):
if value is None:
return None
return total_seconds(value) * 1000000


# Backported from django 1.8
standard_duration_re = re.compile(
r'^'
r'(?:(?P<days>-?\d+) )?'
r'((?:(?P<hours>\d+):)(?=\d+:\d+))?'
r'(?:(?P<minutes>\d+):)?'
r'(?P<seconds>\d+)'
r'(?:\.(?P<microseconds>\d{1,6})\d{0,6})?'
r'$'
)

# Support the sections of ISO 8601 date representation that are accepted by
# timedelta
iso8601_duration_re = re.compile(
r'^P'
r'(?:(?P<days>\d+(.\d+)?)D)?'
r'(?:T'
r'(?:(?P<hours>\d+(.\d+)?)H)?'
r'(?:(?P<minutes>\d+(.\d+)?)M)?'
r'(?:(?P<seconds>\d+(.\d+)?)S)?'
r')?'
r'$'
)

def parse_duration(value):
"""Parses a duration string and returns a datetime.timedelta.
The preferred format for durations in Django is '%d %H:%M:%S.%f'.
Also supports ISO 8601 representation.
"""
match = standard_duration_re.match(value)
if not match:
match = iso8601_duration_re.match(value)
if match:
kw = match.groupdict()
if kw.get('microseconds'):
kw['microseconds'] = kw['microseconds'].ljust(6, unicode_to_repr('0'))
kw = dict((k, float(v)) for k, v in six.iteritems(kw) if v is not None)
return datetime.timedelta(**kw)

def duration_string(duration):
days = duration.days
seconds = duration.seconds
microseconds = duration.microseconds

minutes = seconds // 60
seconds = seconds % 60

hours = minutes // 60
minutes = minutes % 60

string = '{0:02d}:{1:02d}:{2:02d}'.format(hours, minutes, seconds)
if days:
string = '{0} '.format(days) + string
if microseconds:
string += '.{0:06d}'.format(microseconds)

return string
19 changes: 18 additions & 1 deletion rest_framework/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from rest_framework.compat import (
EmailValidator, MinValueValidator, MaxValueValidator,
MinLengthValidator, MaxLengthValidator, URLValidator, OrderedDict,
unicode_repr, unicode_to_repr
unicode_repr, unicode_to_repr, parse_duration, duration_string,
)
from rest_framework.exceptions import ValidationError
from rest_framework.settings import api_settings
Expand Down Expand Up @@ -1003,6 +1003,23 @@ def to_representation(self, value):
return value.strftime(self.format)


class DurationField(Field):
default_error_messages = {
'invalid': _('Duration has wrong format. Use one of these formats instead: {format}.'),
}

def to_internal_value(self, value):
if isinstance(value, datetime.timedelta):
return value
parsed = parse_duration(value)
if parsed is not None:
return parsed
self.fail('invalid', format='[DD] [HH:[MM:]]ss[.uuuuuu]')

def to_representation(self, value):
return duration_string(value)


# Choice types...

class ChoiceField(Field):
Expand Down
8 changes: 6 additions & 2 deletions rest_framework/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
from django.db.models.fields import FieldDoesNotExist, Field as DjangoModelField
from django.db.models import query
from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import postgres_fields, unicode_to_repr
from rest_framework.compat import (
postgres_fields,
unicode_to_repr,
DurationField as ModelDurationField,
)
from rest_framework.utils import model_meta
from rest_framework.utils.field_mapping import (
get_url_kwargs, get_field_kwargs,
Expand All @@ -42,7 +46,6 @@
from rest_framework.relations import * # NOQA
from rest_framework.fields import * # NOQA


# We assume that 'validators' are intended for the child serializer,
# rather than the parent serializer.
LIST_SERIALIZER_KWARGS = (
Expand Down Expand Up @@ -716,6 +719,7 @@ class ModelSerializer(Serializer):
models.DateField: DateField,
models.DateTimeField: DateTimeField,
models.DecimalField: DecimalField,
ModelDurationField: DurationField,
models.EmailField: EmailField,
models.Field: ModelField,
models.FileField: FileField,
Expand Down
13 changes: 13 additions & 0 deletions tests/test_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from datetime import timedelta

import django
from django.db import connection
import pytest


@pytest.mark.skipif(django.VERSION >= (1, 8),
reason='Django1.8+ have native DurationField')
def test_duration_field():
from rest_framework.compat import DurationField
delta = timedelta(seconds=1)
assert DurationField().get_db_prep_value(delta, connection) == 1 * 1e6
20 changes: 20 additions & 0 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,26 @@ class TestNoOutputFormatTimeField(FieldValues):
field = serializers.TimeField(format=None)


class TestDurationField(FieldValues):
"""
Valid and invalid values for `DurationField`.
"""
valid_inputs = {
'13': datetime.timedelta(seconds=13),
'3 08:32:01.000123': datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123),
'08:01': datetime.timedelta(minutes=8, seconds=1),
datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123): datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123),
}
invalid_inputs = {
'abc': ['Duration has wrong format. Use one of these formats instead: [DD] [HH:[MM:]]ss[.uuuuuu].'],
'3 08:32 01.123': ['Duration has wrong format. Use one of these formats instead: [DD] [HH:[MM:]]ss[.uuuuuu].'],
}
outputs = {
datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123): '3 08:32:01.000123',
}
field = serializers.DurationField()


# Choice types...

class TestChoiceField(FieldValues):
Expand Down
4 changes: 3 additions & 1 deletion tests/test_model_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from django.test import TestCase
from django.utils import six
from rest_framework import serializers
from rest_framework.compat import unicode_repr
from rest_framework.compat import unicode_repr, DurationField as ModelDurationField


def dedent(blocktext):
Expand Down Expand Up @@ -45,6 +45,7 @@ class RegularFieldsModel(models.Model):
date_field = models.DateField()
datetime_field = models.DateTimeField()
decimal_field = models.DecimalField(max_digits=3, decimal_places=1)
duration_field = ModelDurationField()
email_field = models.EmailField(max_length=100)
float_field = models.FloatField()
integer_field = models.IntegerField()
Expand Down Expand Up @@ -138,6 +139,7 @@ class Meta:
date_field = DateField()
datetime_field = DateTimeField()
decimal_field = DecimalField(decimal_places=1, max_digits=3)
duration_field = DurationField()
email_field = EmailField(max_length=100)
float_field = FloatField()
integer_field = IntegerField()
Expand Down

0 comments on commit ea2543a

Please sign in to comment.