From 378b04eeaa4a6be3b174374746e6b126746d7635 Mon Sep 17 00:00:00 2001 From: Daniel Hahler Date: Wed, 10 Aug 2016 16:19:56 +0200 Subject: [PATCH] Fix handling of ALLOWED_VERSIONS and no DEFAULT_VERSION (#4370) When only `ALLOWED_VERSIONS` but no `DEFAULT_VERSION` is specified, a version should be enforced. --- docs/api-guide/versioning.md | 4 +- rest_framework/versioning.py | 3 +- tests/test_versioning.py | 82 ++++++++++++++++++++++++++++++++---- 3 files changed, 78 insertions(+), 11 deletions(-) diff --git a/docs/api-guide/versioning.md b/docs/api-guide/versioning.md index 54aa0170dc..29672c96ea 100644 --- a/docs/api-guide/versioning.md +++ b/docs/api-guide/versioning.md @@ -71,8 +71,8 @@ You can also set the versioning scheme on an individual view. Typically you won' The following settings keys are also used to control versioning: * `DEFAULT_VERSION`. The value that should be used for `request.version` when no versioning information is present. Defaults to `None`. -* `ALLOWED_VERSIONS`. If set, this value will restrict the set of versions that may be returned by the versioning scheme, and will raise an error if the provided version if not in this set. Note that the value used for the `DEFAULT_VERSION` setting is always considered to be part of the `ALLOWED_VERSIONS` set. Defaults to `None`. -* `VERSION_PARAM`. The string that should used for any versioning parameters, such as in the media type or URL query parameters. Defaults to `'version'`. +* `ALLOWED_VERSIONS`. If set, this value will restrict the set of versions that may be returned by the versioning scheme, and will raise an error if the provided version is not in this set. Note that the value used for the `DEFAULT_VERSION` setting is always considered to be part of the `ALLOWED_VERSIONS` set (unless it is `None`). Defaults to `None`. +* `VERSION_PARAM`. The string that should be used for any versioning parameters, such as in the media type or URL query parameters. Defaults to `'version'`. You can also set your versioning class plus those three values on a per-view or a per-viewset basis by defining your own versioning scheme and using the `default_version`, `allowed_versions` and `version_param` class variables. For example, if you want to use `URLPathVersioning`: diff --git a/rest_framework/versioning.py b/rest_framework/versioning.py index f533ef580d..e5524afe8a 100644 --- a/rest_framework/versioning.py +++ b/rest_framework/versioning.py @@ -30,7 +30,8 @@ def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, * def is_allowed_version(self, version): if not self.allowed_versions: return True - return (version == self.default_version) or (version in self.allowed_versions) + return ((version is not None and version == self.default_version) or + (version in self.allowed_versions)) class AcceptHeaderVersioning(BaseVersioning): diff --git a/tests/test_versioning.py b/tests/test_versioning.py index edac11e62c..195f3fec10 100644 --- a/tests/test_versioning.py +++ b/tests/test_versioning.py @@ -44,14 +44,34 @@ def get(self, request, *args, **kwargs): return Response({'url': reverse('another', request=request)}) -class RequestInvalidVersionView(APIView): +class AllowedVersionsView(RequestVersionView): def determine_version(self, request, *args, **kwargs): scheme = self.versioning_class() scheme.allowed_versions = ('v1', 'v2') return (scheme.determine_version(request, *args, **kwargs), scheme) - def get(self, request, *args, **kwargs): - return Response({'version': request.version}) + +class AllowedAndDefaultVersionsView(RequestVersionView): + def determine_version(self, request, *args, **kwargs): + scheme = self.versioning_class() + scheme.allowed_versions = ('v1', 'v2') + scheme.default_version = 'v2' + return (scheme.determine_version(request, *args, **kwargs), scheme) + + +class AllowedWithNoneVersionsView(RequestVersionView): + def determine_version(self, request, *args, **kwargs): + scheme = self.versioning_class() + scheme.allowed_versions = ('v1', 'v2', None) + return (scheme.determine_version(request, *args, **kwargs), scheme) + + +class AllowedWithNoneAndDefaultVersionsView(RequestVersionView): + def determine_version(self, request, *args, **kwargs): + scheme = self.versioning_class() + scheme.allowed_versions = ('v1', 'v2', None) + scheme.default_version = 'v2' + return (scheme.determine_version(request, *args, **kwargs), scheme) factory = APIRequestFactory() @@ -219,7 +239,7 @@ class FakeResolverMatch: class TestInvalidVersion: def test_invalid_query_param_versioning(self): scheme = versioning.QueryParameterVersioning - view = RequestInvalidVersionView.as_view(versioning_class=scheme) + view = AllowedVersionsView.as_view(versioning_class=scheme) request = factory.get('/endpoint/?version=v3') response = view(request) @@ -228,7 +248,7 @@ def test_invalid_query_param_versioning(self): @override_settings(ALLOWED_HOSTS=['*']) def test_invalid_host_name_versioning(self): scheme = versioning.HostNameVersioning - view = RequestInvalidVersionView.as_view(versioning_class=scheme) + view = AllowedVersionsView.as_view(versioning_class=scheme) request = factory.get('/endpoint/', HTTP_HOST='v3.example.org') response = view(request) @@ -236,7 +256,7 @@ def test_invalid_host_name_versioning(self): def test_invalid_accept_header_versioning(self): scheme = versioning.AcceptHeaderVersioning - view = RequestInvalidVersionView.as_view(versioning_class=scheme) + view = AllowedVersionsView.as_view(versioning_class=scheme) request = factory.get('/endpoint/', HTTP_ACCEPT='application/json; version=v3') response = view(request) @@ -244,7 +264,7 @@ def test_invalid_accept_header_versioning(self): def test_invalid_url_path_versioning(self): scheme = versioning.URLPathVersioning - view = RequestInvalidVersionView.as_view(versioning_class=scheme) + view = AllowedVersionsView.as_view(versioning_class=scheme) request = factory.get('/v3/endpoint/') response = view(request, version='v3') @@ -255,7 +275,7 @@ class FakeResolverMatch: namespace = 'v3' scheme = versioning.NamespaceVersioning - view = RequestInvalidVersionView.as_view(versioning_class=scheme) + view = AllowedVersionsView.as_view(versioning_class=scheme) request = factory.get('/v3/endpoint/') request.resolver_match = FakeResolverMatch @@ -263,6 +283,52 @@ class FakeResolverMatch: assert response.status_code == status.HTTP_404_NOT_FOUND +class TestAllowedAndDefaultVersion: + def test_missing_without_default(self): + scheme = versioning.AcceptHeaderVersioning + view = AllowedVersionsView.as_view(versioning_class=scheme) + + request = factory.get('/endpoint/', HTTP_ACCEPT='application/json') + response = view(request) + assert response.status_code == status.HTTP_406_NOT_ACCEPTABLE + + def test_missing_with_default(self): + scheme = versioning.AcceptHeaderVersioning + view = AllowedAndDefaultVersionsView.as_view(versioning_class=scheme) + + request = factory.get('/endpoint/', HTTP_ACCEPT='application/json') + response = view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'version': 'v2'} + + def test_with_default(self): + scheme = versioning.AcceptHeaderVersioning + view = AllowedAndDefaultVersionsView.as_view(versioning_class=scheme) + + request = factory.get('/endpoint/', + HTTP_ACCEPT='application/json; version=v2') + response = view(request) + assert response.status_code == status.HTTP_200_OK + + def test_missing_without_default_but_none_allowed(self): + scheme = versioning.AcceptHeaderVersioning + view = AllowedWithNoneVersionsView.as_view(versioning_class=scheme) + + request = factory.get('/endpoint/', HTTP_ACCEPT='application/json') + response = view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'version': None} + + def test_missing_with_default_and_none_allowed(self): + scheme = versioning.AcceptHeaderVersioning + view = AllowedWithNoneAndDefaultVersionsView.as_view(versioning_class=scheme) + + request = factory.get('/endpoint/', HTTP_ACCEPT='application/json') + response = view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'version': 'v2'} + + class TestHyperlinkedRelatedField(URLPatternsTestCase): included = [ url(r'^namespaced/(?P\d+)/$', dummy_pk_view, name='namespaced'),