diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 9e73ef6324..0ca90873eb 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -10,7 +10,6 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework import HTTP_HEADER_ENCODING, exceptions -from rest_framework.authtoken.models import Token def get_authorization_header(request): @@ -149,7 +148,14 @@ class TokenAuthentication(BaseAuthentication): Authorization: Token 401f7ac837da42b97f613d789819ff93537bee6a """ - model = Token + model = None + + def get_model(self): + if self.model is not None: + return self.model + from rest_framework.authtoken.models import Token + return Token + """ A custom token model may be used, but must have the following properties. @@ -179,9 +185,10 @@ def authenticate(self, request): return self.authenticate_credentials(token) def authenticate_credentials(self, key): + model = self.get_model() try: - token = self.model.objects.select_related('user').get(key=key) - except self.model.DoesNotExist: + token = model.objects.select_related('user').get(key=key) + except model.DoesNotExist: raise exceptions.AuthenticationFailed(_('Invalid token.')) if not token.user.is_active: diff --git a/rest_framework/authtoken/models.py b/rest_framework/authtoken/models.py index b329ee65f8..2fef61e53b 100644 --- a/rest_framework/authtoken/models.py +++ b/rest_framework/authtoken/models.py @@ -21,14 +21,6 @@ class Token(models.Model): user = models.OneToOneField(AUTH_USER_MODEL, related_name='auth_token') created = models.DateTimeField(auto_now_add=True) - class Meta: - # Work around for a bug in Django: - # https://code.djangoproject.com/ticket/19422 - # - # Also see corresponding ticket: - # https://github.com/tomchristie/django-rest-framework/issues/705 - abstract = 'rest_framework.authtoken' not in settings.INSTALLED_APPS - def save(self, *args, **kwargs): if not self.key: self.key = self.generate_key() diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 91434124ef..285a3210ce 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -6,6 +6,7 @@ from django.conf.urls import include, url from django.contrib.auth.models import User +from django.db import models from django.http import HttpResponse from django.test import TestCase from django.utils import six @@ -25,6 +26,15 @@ factory = APIRequestFactory() +class CustomToken(models.Model): + key = models.CharField(max_length=40, primary_key=True) + user = models.OneToOneField(User) + + +class CustomTokenAuthentication(TokenAuthentication): + model = CustomToken + + class MockView(APIView): permission_classes = (permissions.IsAuthenticated,) @@ -42,6 +52,7 @@ def put(self, request): url(r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), url(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), url(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), + url(r'^customtoken/$', MockView.as_view(authentication_classes=[CustomTokenAuthentication])), url(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')), ] @@ -142,9 +153,11 @@ def test_post_form_session_auth_failing(self): self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) -class TokenAuthTests(TestCase): +class BaseTokenAuthTests(object): """Token authentication""" urls = 'tests.test_authentication' + model = None + path = None def setUp(self): self.csrf_client = APIClient(enforce_csrf_checks=True) @@ -154,24 +167,30 @@ def setUp(self): self.user = User.objects.create_user(self.username, self.email, self.password) self.key = 'abcd1234' - self.token = Token.objects.create(key=self.key, user=self.user) + self.token = self.model.objects.create(key=self.key, user=self.user) def test_post_form_passing_token_auth(self): """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF""" auth = 'Token ' + self.key - response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, status.HTTP_200_OK) + def test_fail_post_form_passing_nonexistent_token_auth(self): + # use a nonexistent token key + auth = 'Token wxyz6789' + response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + def test_fail_post_form_passing_invalid_token_auth(self): # add an 'invalid' unicode character auth = 'Token ' + self.key + "ΒΈ" - response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_post_json_passing_token_auth(self): """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" auth = "Token " + self.key - response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post(self.path, {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_json_makes_one_db_query(self): @@ -179,29 +198,34 @@ def test_post_json_makes_one_db_query(self): auth = "Token " + self.key def func_to_test(): - return self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) + return self.csrf_client.post(self.path, {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) self.assertNumQueries(1, func_to_test) def test_post_form_failing_token_auth(self): """Ensure POSTing form over token auth without correct credentials fails""" - response = self.csrf_client.post('/token/', {'example': 'example'}) + response = self.csrf_client.post(self.path, {'example': 'example'}) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_post_json_failing_token_auth(self): """Ensure POSTing json over token auth without correct credentials fails""" - response = self.csrf_client.post('/token/', {'example': 'example'}, format='json') + response = self.csrf_client.post(self.path, {'example': 'example'}, format='json') self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + +class TokenAuthTests(BaseTokenAuthTests, TestCase): + model = Token + path = '/token/' + def test_token_has_auto_assigned_key_if_none_provided(self): """Ensure creating a token with no key will auto-assign a key""" self.token.delete() - token = Token.objects.create(user=self.user) + token = self.model.objects.create(user=self.user) self.assertTrue(bool(token.key)) def test_generate_key_returns_string(self): """Ensure generate_key returns a string""" - token = Token() + token = self.model() key = token.generate_key() self.assertTrue(isinstance(key, six.string_types)) @@ -236,6 +260,11 @@ def test_token_login_form(self): self.assertEqual(response.data['token'], self.key) +class CustomTokenAuthTests(BaseTokenAuthTests, TestCase): + model = CustomToken + path = '/customtoken/' + + class IncorrectCredentialsTests(TestCase): def test_incorrect_credentials(self): """