Skip to content

Commit

Permalink
Introduce ValidationErrorMessage
Browse files Browse the repository at this point in the history
`ValidationErrorMessage` is a string-like object that holds a code
attribute.

The code attribute has been removed from ValidationError to be able
to maintain better backward compatibility.

What this means is that `ValidationError` can accept either a regular
string or a `ValidationErrorMessage` for its `detail` attribute.
  • Loading branch information
johnraz committed Apr 7, 2016
1 parent 6705a4f commit 2de95ce
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 70 deletions.
16 changes: 10 additions & 6 deletions rest_framework/authtoken/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from django.utils.translation import ugettext_lazy as _

from rest_framework import serializers
from rest_framework.exceptions import ValidationErrorMessage


class AuthTokenSerializer(serializers.Serializer):
Expand All @@ -19,20 +20,23 @@ def validate(self, attrs):
if not user.is_active:
msg = _('User account is disabled.')
raise serializers.ValidationError(
msg,
code='authorization'
ValidationErrorMessage(
msg,
code='authorization')
)
else:
msg = _('Unable to log in with provided credentials.')
raise serializers.ValidationError(
msg,
code='authorization'
ValidationErrorMessage(
msg,
code='authorization')
)
else:
msg = _('Must include "username" and "password".')
raise serializers.ValidationError(
msg,
code='authorization'
ValidationErrorMessage(
msg,
code='authorization')
)

attrs['user'] = user
Expand Down
22 changes: 14 additions & 8 deletions rest_framework/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __str__(self):
def build_error_from_django_validation_error(exc_info):
code = getattr(exc_info, 'code', None) or 'invalid'
return [
ValidationError(msg, code=code)
ValidationErrorMessage(msg, code=code)
for msg in exc_info.messages
]

Expand All @@ -73,20 +73,26 @@ def build_error_from_django_validation_error(exc_info):
# from rest_framework import serializers
# raise serializers.ValidationError('Value was invalid')

class ValidationErrorMessage(six.text_type):
code = None

def __new__(cls, string, code=None, *args, **kwargs):
self = super(ValidationErrorMessage, cls).__new__(
cls, string, *args, **kwargs)

self.code = code
return self


class ValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST

def __init__(self, detail, code=None):
def __init__(self, detail):
# For validation errors the 'detail' key is always required.
# The details should always be coerced to a list if not already.
if not isinstance(detail, dict) and not isinstance(detail, list):
detail = [detail]
elif isinstance(detail, dict) or (detail and isinstance(detail[0], ValidationError)):
assert code is None, (
'The `code` argument must not be set for compound errors.')

self.detail = detail
self.code = code
self.detail = _force_text_recursive(detail)

def __str__(self):
return six.text_type(self.detail)
Expand Down
7 changes: 4 additions & 3 deletions rest_framework/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
from rest_framework import ISO_8601
from rest_framework.compat import unicode_repr, unicode_to_repr
from rest_framework.exceptions import (
ValidationError, build_error_from_django_validation_error
ValidationError, ValidationErrorMessage,
build_error_from_django_validation_error
)
from rest_framework.settings import api_settings
from rest_framework.utils import html, humanize_datetime, representation
Expand Down Expand Up @@ -505,7 +506,7 @@ def run_validators(self, value):
# attempting to accumulate a list of errors.
if isinstance(exc.detail, dict):
raise
errors.append(ValidationError(exc.detail, code=exc.code))
errors.extend(exc.detail)
except DjangoValidationError as exc:
errors.extend(
build_error_from_django_validation_error(exc)
Expand Down Expand Up @@ -547,7 +548,7 @@ def fail(self, key, **kwargs):
msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
raise AssertionError(msg)
message_string = msg.format(**kwargs)
raise ValidationError(message_string, code=key)
raise ValidationError(ValidationErrorMessage(message_string, code=key))

@cached_property
def root(self):
Expand Down
1 change: 1 addition & 0 deletions rest_framework/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, data=None, status=None,
'`.error`. representation.'
)
raise AssertionError(msg)

self.data = data
self.template_name = template_name
self.exception = exception
Expand Down
53 changes: 9 additions & 44 deletions rest_framework/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,7 @@ def is_valid(self, raise_exception=False):
self._errors = {}

if self._errors and raise_exception:
return_errors = None
if isinstance(self._errors, list):
return_errors = ReturnList(self._errors, serializer=self)
elif isinstance(self._errors, dict):
return_errors = ReturnDict(self._errors, serializer=self)

raise ValidationError(return_errors)

raise ValidationError(self.errors)
return not bool(self._errors)

@property
Expand All @@ -250,42 +243,12 @@ def data(self):
self._data = self.get_initial()
return self._data

def _transform_to_legacy_errors(self, errors_to_transform):
# Do not mutate `errors_to_transform` here.
errors = ReturnDict(serializer=self)
for field_name, values in errors_to_transform.items():
if isinstance(values, list):
errors[field_name] = values
continue

if isinstance(values.detail, list):
errors[field_name] = []
for value in values.detail:
if isinstance(value, ValidationError):
errors[field_name].extend(value.detail)
elif isinstance(value, list):
errors[field_name].extend(value)
else:
errors[field_name].append(value)

elif isinstance(values.detail, dict):
errors[field_name] = {}
for sub_field_name, value in values.detail.items():
errors[field_name][sub_field_name] = []
for validation_error in value:
errors[field_name][sub_field_name].extend(validation_error.detail)
return errors

@property
def errors(self):
if not hasattr(self, '_errors'):
msg = 'You must call `.is_valid()` before accessing `.errors`.'
raise AssertionError(msg)

if isinstance(self._errors, list):
return map(self._transform_to_legacy_errors, self._errors)
else:
return self._transform_to_legacy_errors(self._errors)
return self._errors

@property
def validated_data(self):
Expand Down Expand Up @@ -460,7 +423,7 @@ def to_internal_value(self, data):
message = self.error_messages['invalid'].format(
datatype=type(data).__name__
)
error = ValidationError(message, code='invalid')
error = ValidationErrorMessage(message, code='invalid')
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [error]
})
Expand All @@ -477,7 +440,7 @@ def to_internal_value(self, data):
if validate_method is not None:
validated_value = validate_method(validated_value)
except ValidationError as exc:
errors[field.field_name] = exc
errors[field.field_name] = exc.detail
except DjangoValidationError as exc:
errors[field.field_name] = (
exceptions.build_error_from_django_validation_error(exc)
Expand Down Expand Up @@ -620,7 +583,7 @@ def to_internal_value(self, data):
message = self.error_messages['not_a_list'].format(
input_type=type(data).__name__
)
error = ValidationError(
error = ValidationErrorMessage(
message,
code='not_a_list'
)
Expand All @@ -629,9 +592,11 @@ def to_internal_value(self, data):
})

if not self.allow_empty and len(data) == 0:
message = self.error_messages['empty']
message = ValidationErrorMessage(
self.error_messages['empty'],
code='empty_not_allowed')
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [ValidationError(message, code='empty_not_allowed')]
api_settings.NON_FIELD_ERRORS_KEY: [message]
})

ret = []
Expand Down
22 changes: 14 additions & 8 deletions rest_framework/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from django.utils.translation import ugettext_lazy as _

from rest_framework.compat import unicode_to_repr
from rest_framework.exceptions import ValidationError
from rest_framework.exceptions import ValidationError, ValidationErrorMessage
from rest_framework.utils.representation import smart_repr


Expand Down Expand Up @@ -60,7 +60,8 @@ def __call__(self, value):
queryset = self.filter_queryset(value, queryset)
queryset = self.exclude_current_instance(queryset)
if queryset.exists():
raise ValidationError(self.message, code='unique')
raise ValidationError(ValidationErrorMessage(self.message,
code='unique'))

def __repr__(self):
return unicode_to_repr('<%s(queryset=%s)>' % (
Expand Down Expand Up @@ -101,9 +102,10 @@ def enforce_required_fields(self, attrs):
return

missing = {
field_name: ValidationError(
field_name: ValidationErrorMessage(
self.missing_message,
code='required')

for field_name in self.fields
if field_name not in attrs
}
Expand Down Expand Up @@ -149,8 +151,11 @@ def __call__(self, attrs):
]
if None not in checked_values and queryset.exists():
field_names = ', '.join(self.fields)
raise ValidationError(self.message.format(field_names=field_names),
code='unique')
raise ValidationError(
ValidationErrorMessage(
self.message.format(field_names=field_names),
code='unique')
)

def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % (
Expand Down Expand Up @@ -188,7 +193,7 @@ def enforce_required_fields(self, attrs):
'required' state on the fields they are applied to.
"""
missing = {
field_name: ValidationError(
field_name: ValidationErrorMessage(
self.missing_message,
code='required')
for field_name in [self.field, self.date_field]
Expand Down Expand Up @@ -216,8 +221,9 @@ def __call__(self, attrs):
queryset = self.exclude_current_instance(attrs, queryset)
if queryset.exists():
message = self.message.format(date_field=self.date_field)
error = ValidationError(message, code='unique')
raise ValidationError({self.field: error})
raise ValidationError({
self.field: ValidationErrorMessage(message, code='unique'),
})

def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % (
Expand Down
2 changes: 1 addition & 1 deletion rest_framework/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def exception_handler(exc, context):
headers['Retry-After'] = '%d' % exc.wait

if isinstance(exc.detail, (list, dict)):
data = exc.detail.serializer.errors
data = exc.detail
else:
data = {'detail': exc.detail}

Expand Down
74 changes: 74 additions & 0 deletions tests/test_validation_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from django.test import TestCase

from rest_framework import serializers, status
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView

factory = APIRequestFactory()


class ExampleSerializer(serializers.Serializer):
char = serializers.CharField()
integer = serializers.IntegerField()


class ErrorView(APIView):
def get(self, request, *args, **kwargs):
ExampleSerializer(data={}).is_valid(raise_exception=True)


@api_view(['GET'])
def error_view(request):
ExampleSerializer(data={}).is_valid(raise_exception=True)


class TestValidationErrorWithCode(TestCase):
def setUp(self):
self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER

def exception_handler(exc, request):
return_errors = {}
for field_name, errors in exc.detail.items():
return_errors[field_name] = []
for error in errors:
return_errors[field_name].append({
'code': error.code,
'message': error
})

return Response(return_errors, status=status.HTTP_400_BAD_REQUEST)

api_settings.EXCEPTION_HANDLER = exception_handler

self.expected_response_data = {
'char': [{
'message': 'This field is required.',
'code': 'required',
}],
'integer': [{
'message': 'This field is required.',
'code': 'required'
}],
}

def tearDown(self):
api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER

def test_class_based_view_exception_handler(self):
view = ErrorView.as_view()

request = factory.get('/', content_type='application/json')
response = view(request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data, self.expected_response_data)

def test_function_based_view_exception_handler(self):
view = error_view

request = factory.get('/', content_type='application/json')
response = view(request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data, self.expected_response_data)

0 comments on commit 2de95ce

Please sign in to comment.