From df0d814665f3b0dab5e3782363abca2b89160e7f Mon Sep 17 00:00:00 2001 From: Jonathan Liuti Date: Mon, 7 Dec 2015 10:48:53 +0100 Subject: [PATCH 1/2] Introduce error code for validation errors. This patch is meant to fix #3111, regarding comments made to #3137 and #3169. The `ValidationError` will now contain a `code` attribute. Before this patch, `ValidationError.detail` only contained a `dict` with values equal to a `list` of string error messages or directly a `list` containing string error messages. Now, the string error messages are replaced with `ValidationError`. This means that, depending on the case, you will not only get a string back but a all object containing both the error message and the error code, respectively `ValidationError.detail` and `ValidationError.code`. It is important to note that the `code` attribute is not relevant when the `ValidationError` represents a combination of errors and hence is `None` in such cases. The main benefit of this change is that the error message and error code are now accessible the custom exception handler and can be used to format the error response. An custom exception handler example is available in the `TestValidationErrorWithCode` test. We keep `Serializer.errors`'s return type unchanged in order to maintain backward compatibility. The error codes will only be propagated to the `exception_handler` or accessible through the `Serializer._errors` private attribute. --- rest_framework/authtoken/serializers.py | 15 ++++-- rest_framework/exceptions.py | 17 ++++++- rest_framework/fields.py | 12 +++-- rest_framework/response.py | 1 - rest_framework/serializers.py | 61 +++++++++++++++++++++---- rest_framework/validators.py | 16 +++++-- rest_framework/views.py | 2 +- 7 files changed, 100 insertions(+), 24 deletions(-) diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index df0c48b86a..ce6bb9c79a 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -18,13 +18,22 @@ def validate(self, attrs): if user: if not user.is_active: msg = _('User account is disabled.') - raise serializers.ValidationError(msg) + raise serializers.ValidationError( + msg, + code='authorization' + ) else: msg = _('Unable to log in with provided credentials.') - raise serializers.ValidationError(msg) + raise serializers.ValidationError( + msg, + code='authorization' + ) else: msg = _('Must include "username" and "password".') - raise serializers.ValidationError(msg) + raise serializers.ValidationError( + msg, + code='authorization' + ) attrs['user'] = user return attrs diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 29afaffe00..e23b7cd315 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -58,6 +58,14 @@ def __str__(self): return self.detail +def build_error_from_django_validation_error(exc_info): + code = getattr(exc_info, 'code', None) or 'invalid' + return [ + ValidationError(msg, code=code) + for msg in exc_info.messages + ] + + # The recommended style for using `ValidationError` is to keep it namespaced # under `serializers`, in order to minimize potential confusion with Django's # built in `ValidationError`. For example: @@ -68,12 +76,17 @@ def __str__(self): class ValidationError(APIException): status_code = status.HTTP_400_BAD_REQUEST - def __init__(self, detail): + def __init__(self, detail, code=None): # For validation errors the 'detail' key is always required. # The details should always be coerced to a list if not already. if not isinstance(detail, dict) and not isinstance(detail, list): detail = [detail] - self.detail = _force_text_recursive(detail) + elif isinstance(detail, dict) or (detail and isinstance(detail[0], ValidationError)): + assert code is None, ( + 'The `code` argument must not be set for compound errors.') + + self.detail = detail + self.code = code def __str__(self): return six.text_type(self.detail) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f76e4e8011..39a5e33955 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -34,7 +34,9 @@ from rest_framework.compat import ( get_remote_field, unicode_repr, unicode_to_repr, value_from_object ) -from rest_framework.exceptions import ValidationError +from rest_framework.exceptions import ( + ValidationError, build_error_from_django_validation_error +) from rest_framework.settings import api_settings from rest_framework.utils import html, humanize_datetime, representation @@ -507,9 +509,11 @@ def run_validators(self, value): # attempting to accumulate a list of errors. if isinstance(exc.detail, dict): raise - errors.extend(exc.detail) + errors.append(ValidationError(exc.detail, code=exc.code)) except DjangoValidationError as exc: - errors.extend(exc.messages) + errors.extend( + build_error_from_django_validation_error(exc) + ) if errors: raise ValidationError(errors) @@ -547,7 +551,7 @@ def fail(self, key, **kwargs): msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) raise AssertionError(msg) message_string = msg.format(**kwargs) - raise ValidationError(message_string) + raise ValidationError(message_string, code=key) @cached_property def root(self): diff --git a/rest_framework/response.py b/rest_framework/response.py index 4b863cb997..e9ceb27419 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -38,7 +38,6 @@ def __init__(self, data=None, status=None, '`.error`. representation.' ) raise AssertionError(msg) - self.data = data self.template_name = template_name self.exception = exception diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 4d1ed63aef..a4dbc64492 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -22,6 +22,7 @@ from django.utils.functional import cached_property from django.utils.translation import ugettext_lazy as _ +from rest_framework import exceptions from rest_framework.compat import JSONField as ModelJSONField from rest_framework.compat import postgres_fields, unicode_to_repr from rest_framework.utils import model_meta @@ -219,7 +220,13 @@ def is_valid(self, raise_exception=False): self._errors = {} if self._errors and raise_exception: - raise ValidationError(self.errors) + return_errors = None + if isinstance(self._errors, list): + return_errors = ReturnList(self._errors, serializer=self) + elif isinstance(self._errors, dict): + return_errors = ReturnDict(self._errors, serializer=self) + + raise ValidationError(return_errors) return not bool(self._errors) @@ -244,12 +251,42 @@ def data(self): self._data = self.get_initial() return self._data + def _transform_to_legacy_errors(self, errors_to_transform): + # Do not mutate `errors_to_transform` here. + errors = ReturnDict(serializer=self) + for field_name, values in errors_to_transform.items(): + if isinstance(values, list): + errors[field_name] = values + continue + + if isinstance(values.detail, list): + errors[field_name] = [] + for value in values.detail: + if isinstance(value, ValidationError): + errors[field_name].extend(value.detail) + elif isinstance(value, list): + errors[field_name].extend(value) + else: + errors[field_name].append(value) + + elif isinstance(values.detail, dict): + errors[field_name] = {} + for sub_field_name, value in values.detail.items(): + errors[field_name][sub_field_name] = [] + for validation_error in value: + errors[field_name][sub_field_name].extend(validation_error.detail) + return errors + @property def errors(self): if not hasattr(self, '_errors'): msg = 'You must call `.is_valid()` before accessing `.errors`.' raise AssertionError(msg) - return self._errors + + if isinstance(self._errors, list): + return map(self._transform_to_legacy_errors, self._errors) + else: + return self._transform_to_legacy_errors(self._errors) @property def validated_data(self): @@ -301,7 +338,8 @@ def get_validation_error_detail(exc): # exception class as well for simpler compat. # Eg. Calling Model.clean() explicitly inside Serializer.validate() return { - api_settings.NON_FIELD_ERRORS_KEY: list(exc.messages) + api_settings.NON_FIELD_ERRORS_KEY: + exceptions.build_error_from_django_validation_error(exc) } elif isinstance(exc.detail, dict): # If errors may be a dict we use the standard {key: list of values}. @@ -423,8 +461,9 @@ def to_internal_value(self, data): message = self.error_messages['invalid'].format( datatype=type(data).__name__ ) + error = ValidationError(message, code='invalid') raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] + api_settings.NON_FIELD_ERRORS_KEY: [error] }) ret = OrderedDict() @@ -439,9 +478,11 @@ def to_internal_value(self, data): if validate_method is not None: validated_value = validate_method(validated_value) except ValidationError as exc: - errors[field.field_name] = exc.detail + errors[field.field_name] = exc except DjangoValidationError as exc: - errors[field.field_name] = list(exc.messages) + errors[field.field_name] = ( + exceptions.build_error_from_django_validation_error(exc) + ) except SkipField: pass else: @@ -580,14 +621,18 @@ def to_internal_value(self, data): message = self.error_messages['not_a_list'].format( input_type=type(data).__name__ ) + error = ValidationError( + message, + code='not_a_list' + ) raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] + api_settings.NON_FIELD_ERRORS_KEY: [error] }) if not self.allow_empty and len(data) == 0: message = self.error_messages['empty'] raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] + api_settings.NON_FIELD_ERRORS_KEY: [ValidationError(message, code='empty_not_allowed')] }) ret = [] diff --git a/rest_framework/validators.py b/rest_framework/validators.py index ef23b9bd70..90483eeeb9 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -79,7 +79,7 @@ def __call__(self, value): queryset = self.filter_queryset(value, queryset) queryset = self.exclude_current_instance(queryset) if qs_exists(queryset): - raise ValidationError(self.message) + raise ValidationError(self.message, code='unique') def __repr__(self): return unicode_to_repr('<%s(queryset=%s)>' % ( @@ -120,7 +120,9 @@ def enforce_required_fields(self, attrs): return missing = { - field_name: self.missing_message + field_name: ValidationError( + self.missing_message, + code='required') for field_name in self.fields if field_name not in attrs } @@ -166,7 +168,8 @@ def __call__(self, attrs): ] if None not in checked_values and qs_exists(queryset): field_names = ', '.join(self.fields) - raise ValidationError(self.message.format(field_names=field_names)) + raise ValidationError(self.message.format(field_names=field_names), + code='unique') def __repr__(self): return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % ( @@ -204,7 +207,9 @@ def enforce_required_fields(self, attrs): 'required' state on the fields they are applied to. """ missing = { - field_name: self.missing_message + field_name: ValidationError( + self.missing_message, + code='required') for field_name in [self.field, self.date_field] if field_name not in attrs } @@ -230,7 +235,8 @@ def __call__(self, attrs): queryset = self.exclude_current_instance(attrs, queryset) if qs_exists(queryset): message = self.message.format(date_field=self.date_field) - raise ValidationError({self.field: message}) + error = ValidationError(message, code='unique') + raise ValidationError({self.field: error}) def __repr__(self): return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % ( diff --git a/rest_framework/views.py b/rest_framework/views.py index 15d8c6cde2..8b6f060d46 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -71,7 +71,7 @@ def exception_handler(exc, context): headers['Retry-After'] = '%d' % exc.wait if isinstance(exc.detail, (list, dict)): - data = exc.detail + data = exc.detail.serializer.errors else: data = {'detail': exc.detail} From 2bf6ee47f3987a39d83d32a1105f524fe8e72728 Mon Sep 17 00:00:00 2001 From: Jonathan Liuti Date: Wed, 16 Dec 2015 19:09:03 +0100 Subject: [PATCH 2/2] Introduce ValidationErrorMessage `ValidationErrorMessage` is a string-like object that holds a code attribute. The code attribute has been removed from ValidationError to be able to maintain better backward compatibility. `ValidationErrorMessage` is abstracted in `ValidationError`'s constructor --- rest_framework/authtoken/serializers.py | 10 ++-- rest_framework/exceptions.py | 24 +++++--- rest_framework/fields.py | 2 +- rest_framework/response.py | 1 + rest_framework/serializers.py | 54 ++++-------------- rest_framework/validators.py | 17 +++--- rest_framework/views.py | 2 +- tests/test_validation_error.py | 74 +++++++++++++++++++++++++ 8 files changed, 118 insertions(+), 66 deletions(-) create mode 100644 tests/test_validation_error.py diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index ce6bb9c79a..abaac0c223 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -20,20 +20,18 @@ def validate(self, attrs): msg = _('User account is disabled.') raise serializers.ValidationError( msg, - code='authorization' - ) + code='authorization') else: msg = _('Unable to log in with provided credentials.') raise serializers.ValidationError( msg, - code='authorization' - ) + code='authorization') + else: msg = _('Must include "username" and "password".') raise serializers.ValidationError( msg, - code='authorization' - ) + code='authorization') attrs['user'] = user return attrs diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index e23b7cd315..6e30834e6a 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -61,7 +61,7 @@ def __str__(self): def build_error_from_django_validation_error(exc_info): code = getattr(exc_info, 'code', None) or 'invalid' return [ - ValidationError(msg, code=code) + ValidationErrorMessage(msg, code=code) for msg in exc_info.messages ] @@ -73,20 +73,30 @@ def build_error_from_django_validation_error(exc_info): # from rest_framework import serializers # raise serializers.ValidationError('Value was invalid') +class ValidationErrorMessage(six.text_type): + code = None + + def __new__(cls, string, code=None, *args, **kwargs): + self = super(ValidationErrorMessage, cls).__new__( + cls, string, *args, **kwargs) + + self.code = code + return self + + class ValidationError(APIException): status_code = status.HTTP_400_BAD_REQUEST def __init__(self, detail, code=None): + # If code is there, this means we are dealing with a message. + if code and not isinstance(detail, ValidationErrorMessage): + detail = ValidationErrorMessage(detail, code=code) + # For validation errors the 'detail' key is always required. # The details should always be coerced to a list if not already. if not isinstance(detail, dict) and not isinstance(detail, list): detail = [detail] - elif isinstance(detail, dict) or (detail and isinstance(detail[0], ValidationError)): - assert code is None, ( - 'The `code` argument must not be set for compound errors.') - - self.detail = detail - self.code = code + self.detail = _force_text_recursive(detail) def __str__(self): return six.text_type(self.detail) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 39a5e33955..3962595840 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -509,7 +509,7 @@ def run_validators(self, value): # attempting to accumulate a list of errors. if isinstance(exc.detail, dict): raise - errors.append(ValidationError(exc.detail, code=exc.code)) + errors.extend(exc.detail) except DjangoValidationError as exc: errors.extend( build_error_from_django_validation_error(exc) diff --git a/rest_framework/response.py b/rest_framework/response.py index e9ceb27419..4b863cb997 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -38,6 +38,7 @@ def __init__(self, data=None, status=None, '`.error`. representation.' ) raise AssertionError(msg) + self.data = data self.template_name = template_name self.exception = exception diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index a4dbc64492..5b3ef37709 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -25,6 +25,7 @@ from rest_framework import exceptions from rest_framework.compat import JSONField as ModelJSONField from rest_framework.compat import postgres_fields, unicode_to_repr +from rest_framework.exceptions import ValidationErrorMessage from rest_framework.utils import model_meta from rest_framework.utils.field_mapping import ( ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs, @@ -220,14 +221,7 @@ def is_valid(self, raise_exception=False): self._errors = {} if self._errors and raise_exception: - return_errors = None - if isinstance(self._errors, list): - return_errors = ReturnList(self._errors, serializer=self) - elif isinstance(self._errors, dict): - return_errors = ReturnDict(self._errors, serializer=self) - - raise ValidationError(return_errors) - + raise ValidationError(self.errors) return not bool(self._errors) @property @@ -251,42 +245,12 @@ def data(self): self._data = self.get_initial() return self._data - def _transform_to_legacy_errors(self, errors_to_transform): - # Do not mutate `errors_to_transform` here. - errors = ReturnDict(serializer=self) - for field_name, values in errors_to_transform.items(): - if isinstance(values, list): - errors[field_name] = values - continue - - if isinstance(values.detail, list): - errors[field_name] = [] - for value in values.detail: - if isinstance(value, ValidationError): - errors[field_name].extend(value.detail) - elif isinstance(value, list): - errors[field_name].extend(value) - else: - errors[field_name].append(value) - - elif isinstance(values.detail, dict): - errors[field_name] = {} - for sub_field_name, value in values.detail.items(): - errors[field_name][sub_field_name] = [] - for validation_error in value: - errors[field_name][sub_field_name].extend(validation_error.detail) - return errors - @property def errors(self): if not hasattr(self, '_errors'): msg = 'You must call `.is_valid()` before accessing `.errors`.' raise AssertionError(msg) - - if isinstance(self._errors, list): - return map(self._transform_to_legacy_errors, self._errors) - else: - return self._transform_to_legacy_errors(self._errors) + return self._errors @property def validated_data(self): @@ -461,7 +425,7 @@ def to_internal_value(self, data): message = self.error_messages['invalid'].format( datatype=type(data).__name__ ) - error = ValidationError(message, code='invalid') + error = ValidationErrorMessage(message, code='invalid') raise ValidationError({ api_settings.NON_FIELD_ERRORS_KEY: [error] }) @@ -478,7 +442,7 @@ def to_internal_value(self, data): if validate_method is not None: validated_value = validate_method(validated_value) except ValidationError as exc: - errors[field.field_name] = exc + errors[field.field_name] = exc.detail except DjangoValidationError as exc: errors[field.field_name] = ( exceptions.build_error_from_django_validation_error(exc) @@ -621,7 +585,7 @@ def to_internal_value(self, data): message = self.error_messages['not_a_list'].format( input_type=type(data).__name__ ) - error = ValidationError( + error = ValidationErrorMessage( message, code='not_a_list' ) @@ -630,9 +594,11 @@ def to_internal_value(self, data): }) if not self.allow_empty and len(data) == 0: - message = self.error_messages['empty'] + message = ValidationErrorMessage( + self.error_messages['empty'], + code='empty_not_allowed') raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [ValidationError(message, code='empty_not_allowed')] + api_settings.NON_FIELD_ERRORS_KEY: [message] }) ret = [] diff --git a/rest_framework/validators.py b/rest_framework/validators.py index 90483eeeb9..3b8678a70d 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -12,7 +12,7 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework.compat import unicode_to_repr -from rest_framework.exceptions import ValidationError +from rest_framework.exceptions import ValidationError, ValidationErrorMessage from rest_framework.utils.representation import smart_repr @@ -120,9 +120,10 @@ def enforce_required_fields(self, attrs): return missing = { - field_name: ValidationError( + field_name: ValidationErrorMessage( self.missing_message, code='required') + for field_name in self.fields if field_name not in attrs } @@ -168,8 +169,9 @@ def __call__(self, attrs): ] if None not in checked_values and qs_exists(queryset): field_names = ', '.join(self.fields) - raise ValidationError(self.message.format(field_names=field_names), - code='unique') + raise ValidationError( + self.message.format(field_names=field_names), + code='unique') def __repr__(self): return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % ( @@ -207,7 +209,7 @@ def enforce_required_fields(self, attrs): 'required' state on the fields they are applied to. """ missing = { - field_name: ValidationError( + field_name: ValidationErrorMessage( self.missing_message, code='required') for field_name in [self.field, self.date_field] @@ -235,8 +237,9 @@ def __call__(self, attrs): queryset = self.exclude_current_instance(attrs, queryset) if qs_exists(queryset): message = self.message.format(date_field=self.date_field) - error = ValidationError(message, code='unique') - raise ValidationError({self.field: error}) + raise ValidationError({ + self.field: ValidationErrorMessage(message, code='unique'), + }) def __repr__(self): return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % ( diff --git a/rest_framework/views.py b/rest_framework/views.py index 8b6f060d46..15d8c6cde2 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -71,7 +71,7 @@ def exception_handler(exc, context): headers['Retry-After'] = '%d' % exc.wait if isinstance(exc.detail, (list, dict)): - data = exc.detail.serializer.errors + data = exc.detail else: data = {'detail': exc.detail} diff --git a/tests/test_validation_error.py b/tests/test_validation_error.py new file mode 100644 index 0000000000..a9d244176d --- /dev/null +++ b/tests/test_validation_error.py @@ -0,0 +1,74 @@ +from django.test import TestCase + +from rest_framework import serializers, status +from rest_framework.decorators import api_view +from rest_framework.response import Response +from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory +from rest_framework.views import APIView + +factory = APIRequestFactory() + + +class ExampleSerializer(serializers.Serializer): + char = serializers.CharField() + integer = serializers.IntegerField() + + +class ErrorView(APIView): + def get(self, request, *args, **kwargs): + ExampleSerializer(data={}).is_valid(raise_exception=True) + + +@api_view(['GET']) +def error_view(request): + ExampleSerializer(data={}).is_valid(raise_exception=True) + + +class TestValidationErrorWithCode(TestCase): + def setUp(self): + self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER + + def exception_handler(exc, request): + return_errors = {} + for field_name, errors in exc.detail.items(): + return_errors[field_name] = [] + for error in errors: + return_errors[field_name].append({ + 'code': error.code, + 'message': error + }) + + return Response(return_errors, status=status.HTTP_400_BAD_REQUEST) + + api_settings.EXCEPTION_HANDLER = exception_handler + + self.expected_response_data = { + 'char': [{ + 'message': 'This field is required.', + 'code': 'required', + }], + 'integer': [{ + 'message': 'This field is required.', + 'code': 'required' + }], + } + + def tearDown(self): + api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER + + def test_class_based_view_exception_handler(self): + view = ErrorView.as_view() + + request = factory.get('/', content_type='application/json') + response = view(request) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.data, self.expected_response_data) + + def test_function_based_view_exception_handler(self): + view = error_view + + request = factory.get('/', content_type='application/json') + response = view(request) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.data, self.expected_response_data)