diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index 8a295c03e9..42d67c2d3c 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -18,13 +18,17 @@ def validate(self, attrs): if user: if not user.is_active: msg = _('User account is disabled.') - raise serializers.ValidationError(msg) + code = 'authorization' + raise serializers.ValidationError(msg, code=code) else: msg = _('Unable to log in with provided credentials.') - raise serializers.ValidationError(msg) + code = 'authorization' + raise serializers.ValidationError(msg, code=code) + else: msg = _('Must include "username" and "password".') - raise serializers.ValidationError(msg) + code = 'authorization' + raise serializers.ValidationError(msg, code=code) attrs['user'] = user return attrs diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 8447a9dedc..5f1b27c766 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -7,6 +7,7 @@ from __future__ import unicode_literals import math +from collections import namedtuple from django.utils import six from django.utils.encoding import force_text @@ -58,6 +59,13 @@ def __str__(self): return self.detail +def build_error_from_django_validation_error(exc_info): + code = getattr(exc_info, 'code', None) or 'invalid' + return [ + ErrorDetails(msg, code) + for msg in exc_info.messages + ] + # The recommended style for using `ValidationError` is to keep it namespaced # under `serializers`, in order to minimize potential confusion with Django's # built in `ValidationError`. For example: @@ -65,15 +73,61 @@ def __str__(self): # from rest_framework import serializers # raise serializers.ValidationError('Value was invalid') +ErrorDetails = namedtuple('ErrorDetails', ['message', 'code']) + + class ValidationError(APIException): status_code = status.HTTP_400_BAD_REQUEST + code = None - def __init__(self, detail): - # For validation errors the 'detail' key is always required. - # The details should always be coerced to a list if not already. - if not isinstance(detail, dict) and not isinstance(detail, list): - detail = [detail] - self.detail = _force_text_recursive(detail) + def __init__(self, detail, code=None): + if code: + self.full_details = ErrorDetails(detail, code) + else: + self.full_details = detail + + if not isinstance(self.full_details, dict) \ + and not isinstance(self.full_details, list): + self.full_details = [self.full_details] + self.full_details = _force_text_recursive(self.full_details) + + self.detail = detail + if isinstance(self.full_details, list): + if isinstance(self.full_details, ReturnList): + self.detail = ReturnList( + serializer=self.full_details.serializer) + else: + self.detail = [] + for full_detail in self.full_details: + if isinstance(full_detail, ErrorDetails): + self.detail.append(full_detail.message) + elif isinstance(full_detail, dict): + if not full_detail: + self.detail.append(full_detail) + for key, value in full_detail.items(): + if isinstance(value, list): + self.detail.append( + {key: [item.message] + if isinstance(item, ErrorDetails) + else [item] for item in value}) + elif isinstance(full_detail, list): + self.detail.extend(full_detail) + else: + self.detail.append(full_detail) + elif isinstance(self.full_details, dict): + if isinstance(self.full_details, ReturnDict): + self.detail = ReturnDict( + serializer=self.full_details.serializer) + else: + self.detail = {} + for field_name, full_detail in self.full_details.items(): + if isinstance(full_detail, list): + self.detail[field_name] = [ + item.message if isinstance(item, ErrorDetails) else item + for item in full_detail + ] + else: + self.detail[field_name] = full_detail def __str__(self): return six.text_type(self.detail) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 8541bc43a0..1f588e990a 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -31,7 +31,9 @@ MinValueValidator, duration_string, parse_duration, unicode_repr, unicode_to_repr ) -from rest_framework.exceptions import ValidationError +from rest_framework.exceptions import ( + ValidationError, build_error_from_django_validation_error +) from rest_framework.settings import api_settings from rest_framework.utils import html, humanize_datetime, representation @@ -501,9 +503,9 @@ def run_validators(self, value): # attempting to accumulate a list of errors. if isinstance(exc.detail, dict): raise - errors.extend(exc.detail) + errors.append(exc.full_details) except DjangoValidationError as exc: - errors.extend(exc.messages) + errors.extend(build_error_from_django_validation_error(exc)) if errors: raise ValidationError(errors) @@ -541,7 +543,7 @@ def fail(self, key, **kwargs): msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) raise AssertionError(msg) message_string = msg.format(**kwargs) - raise ValidationError(message_string) + raise ValidationError(message_string, code=key) @cached_property def root(self): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 99d36a8a54..77c5cd182f 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -23,6 +23,9 @@ from rest_framework.compat import DurationField as ModelDurationField from rest_framework.compat import JSONField as ModelJSONField from rest_framework.compat import postgres_fields, unicode_to_repr +from rest_framework.exceptions import ( + ErrorDetails, build_error_from_django_validation_error +) from rest_framework.utils import model_meta from rest_framework.utils.field_mapping import ( ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs, @@ -299,8 +302,9 @@ def get_validation_error_detail(exc): # inside your codebase, but we handle Django's validation # exception class as well for simpler compat. # Eg. Calling Model.clean() explicitly inside Serializer.validate() + error = build_error_from_django_validation_error(exc) return { - api_settings.NON_FIELD_ERRORS_KEY: list(exc.messages) + api_settings.NON_FIELD_ERRORS_KEY: error } elif isinstance(exc.detail, dict): # If errors may be a dict we use the standard {key: list of values}. @@ -422,8 +426,9 @@ def to_internal_value(self, data): message = self.error_messages['invalid'].format( datatype=type(data).__name__ ) + code = 'invalid' raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] + api_settings.NON_FIELD_ERRORS_KEY: [ErrorDetails(message, code)] }) ret = OrderedDict() @@ -440,7 +445,8 @@ def to_internal_value(self, data): except ValidationError as exc: errors[field.field_name] = exc.detail except DjangoValidationError as exc: - errors[field.field_name] = list(exc.messages) + error = build_error_from_django_validation_error(exc) + errors[field.field_name] = error except SkipField: pass else: @@ -575,14 +581,16 @@ def to_internal_value(self, data): message = self.error_messages['not_a_list'].format( input_type=type(data).__name__ ) + code = 'not_a_list' raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] + api_settings.NON_FIELD_ERRORS_KEY: [ErrorDetails(message, code)] }) if not self.allow_empty and len(data) == 0: message = self.error_messages['empty'] + code = 'empty_not_allowed' raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] + api_settings.NON_FIELD_ERRORS_KEY: [ErrorDetails(message, code)] }) ret = [] diff --git a/rest_framework/validators.py b/rest_framework/validators.py index a21f67e60e..07c8eb464e 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -11,7 +11,7 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework.compat import unicode_to_repr -from rest_framework.exceptions import ValidationError +from rest_framework.exceptions import ErrorDetails, ValidationError from rest_framework.utils.representation import smart_repr @@ -60,7 +60,7 @@ def __call__(self, value): queryset = self.filter_queryset(value, queryset) queryset = self.exclude_current_instance(queryset) if queryset.exists(): - raise ValidationError(self.message) + raise ValidationError(self.message, code='unique') def __repr__(self): return unicode_to_repr('<%s(queryset=%s)>' % ( @@ -100,8 +100,9 @@ def enforce_required_fields(self, attrs): if self.instance is not None: return + code = 'required' missing = { - field_name: self.missing_message + field_name: ErrorDetails(self.missing_message, code) for field_name in self.fields if field_name not in attrs } @@ -147,7 +148,9 @@ def __call__(self, attrs): ] if None not in checked_values and queryset.exists(): field_names = ', '.join(self.fields) - raise ValidationError(self.message.format(field_names=field_names)) + message = self.message.format(field_names=field_names) + code = 'unique' + raise ValidationError(ErrorDetails(message, code=code)) def __repr__(self): return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % ( @@ -184,8 +187,9 @@ def enforce_required_fields(self, attrs): The `UniqueForValidator` classes always force an implied 'required' state on the fields they are applied to. """ + code = 'required' missing = { - field_name: self.missing_message + field_name: ErrorDetails(self.missing_message, code) for field_name in [self.field, self.date_field] if field_name not in attrs } @@ -211,7 +215,8 @@ def __call__(self, attrs): queryset = self.exclude_current_instance(attrs, queryset) if queryset.exists(): message = self.message.format(date_field=self.date_field) - raise ValidationError({self.field: message}) + code = 'unique' + raise ValidationError({self.field: ErrorDetails(message, code)}) def __repr__(self): return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % ( diff --git a/tests/test_validation_error.py b/tests/test_validation_error.py new file mode 100644 index 0000000000..7d2ec1f8d3 --- /dev/null +++ b/tests/test_validation_error.py @@ -0,0 +1,74 @@ +from django.test import TestCase + +from rest_framework import serializers, status +from rest_framework.decorators import api_view +from rest_framework.response import Response +from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory +from rest_framework.views import APIView + +factory = APIRequestFactory() + + +class ExampleSerializer(serializers.Serializer): + char = serializers.CharField() + integer = serializers.IntegerField() + + +class ErrorView(APIView): + def get(self, request, *args, **kwargs): + ExampleSerializer(data={}).is_valid(raise_exception=True) + + +@api_view(['GET']) +def error_view(request): + ExampleSerializer(data={}).is_valid(raise_exception=True) + + +class TestValidationErrorWithCode(TestCase): + def setUp(self): + self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER + + def exception_handler(exc, request): + return_errors = {} + for field_name, errors in exc.detail.items(): + return_errors[field_name] = [] + for message, code in errors: + return_errors[field_name].append({ + 'code': code, + 'message': message + }) + + return Response(return_errors, status=status.HTTP_400_BAD_REQUEST) + + api_settings.EXCEPTION_HANDLER = exception_handler + + self.expected_response_data = { + 'char': [{ + 'message': 'This field is required.', + 'code': 'required', + }], + 'integer': [{ + 'message': 'This field is required.', + 'code': 'required' + }], + } + + def tearDown(self): + api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER + + def test_class_based_view_exception_handler(self): + view = ErrorView.as_view() + + request = factory.get('/', content_type='application/json') + response = view(request) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.data, self.expected_response_data) + + def test_function_based_view_exception_handler(self): + view = error_view + + request = factory.get('/', content_type='application/json') + response = view(request) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.data, self.expected_response_data)