Skip to content

Commit

Permalink
Merge pull request #3785 from sheppard/authtoken-import
Browse files Browse the repository at this point in the history
don't import authtoken model until needed
  • Loading branch information
tomchristie committed Jan 5, 2016
2 parents dceb686 + 4f40714 commit 37f7b76
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 22 deletions.
15 changes: 11 additions & 4 deletions rest_framework/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from django.utils.translation import ugettext_lazy as _

from rest_framework import HTTP_HEADER_ENCODING, exceptions
from rest_framework.authtoken.models import Token


def get_authorization_header(request):
Expand Down Expand Up @@ -149,7 +148,14 @@ class TokenAuthentication(BaseAuthentication):
Authorization: Token 401f7ac837da42b97f613d789819ff93537bee6a
"""

model = Token
model = None

def get_model(self):
if self.model is not None:
return self.model
from rest_framework.authtoken.models import Token
return Token

"""
A custom token model may be used, but must have the following properties.
Expand Down Expand Up @@ -179,9 +185,10 @@ def authenticate(self, request):
return self.authenticate_credentials(token)

def authenticate_credentials(self, key):
model = self.get_model()
try:
token = self.model.objects.select_related('user').get(key=key)
except self.model.DoesNotExist:
token = model.objects.select_related('user').get(key=key)
except model.DoesNotExist:
raise exceptions.AuthenticationFailed(_('Invalid token.'))

if not token.user.is_active:
Expand Down
8 changes: 0 additions & 8 deletions rest_framework/authtoken/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,6 @@ class Token(models.Model):
user = models.OneToOneField(AUTH_USER_MODEL, related_name='auth_token')
created = models.DateTimeField(auto_now_add=True)

class Meta:
# Work around for a bug in Django:
# https://code.djangoproject.com/ticket/19422
#
# Also see corresponding ticket:
# https://github.com/tomchristie/django-rest-framework/issues/705
abstract = 'rest_framework.authtoken' not in settings.INSTALLED_APPS

def save(self, *args, **kwargs):
if not self.key:
self.key = self.generate_key()
Expand Down
49 changes: 39 additions & 10 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from django.conf.urls import include, url
from django.contrib.auth.models import User
from django.db import models
from django.http import HttpResponse
from django.test import TestCase
from django.utils import six
Expand All @@ -25,6 +26,15 @@
factory = APIRequestFactory()


class CustomToken(models.Model):
key = models.CharField(max_length=40, primary_key=True)
user = models.OneToOneField(User)


class CustomTokenAuthentication(TokenAuthentication):
model = CustomToken


class MockView(APIView):
permission_classes = (permissions.IsAuthenticated,)

Expand All @@ -42,6 +52,7 @@ def put(self, request):
url(r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
url(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
url(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
url(r'^customtoken/$', MockView.as_view(authentication_classes=[CustomTokenAuthentication])),
url(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')),
]
Expand Down Expand Up @@ -142,9 +153,11 @@ def test_post_form_session_auth_failing(self):
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)


class TokenAuthTests(TestCase):
class BaseTokenAuthTests(object):
"""Token authentication"""
urls = 'tests.test_authentication'
model = None
path = None

def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True)
Expand All @@ -154,54 +167,65 @@ def setUp(self):
self.user = User.objects.create_user(self.username, self.email, self.password)

self.key = 'abcd1234'
self.token = Token.objects.create(key=self.key, user=self.user)
self.token = self.model.objects.create(key=self.key, user=self.user)

def test_post_form_passing_token_auth(self):
"""Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
auth = 'Token ' + self.key
response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK)

def test_fail_post_form_passing_nonexistent_token_auth(self):
# use a nonexistent token key
auth = 'Token wxyz6789'
response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

def test_fail_post_form_passing_invalid_token_auth(self):
# add an 'invalid' unicode character
auth = 'Token ' + self.key + "¸"
response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

def test_post_json_passing_token_auth(self):
"""Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key
response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
response = self.csrf_client.post(self.path, {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK)

def test_post_json_makes_one_db_query(self):
"""Ensure that authenticating a user using a token performs only one DB query"""
auth = "Token " + self.key

def func_to_test():
return self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
return self.csrf_client.post(self.path, {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)

self.assertNumQueries(1, func_to_test)

def test_post_form_failing_token_auth(self):
"""Ensure POSTing form over token auth without correct credentials fails"""
response = self.csrf_client.post('/token/', {'example': 'example'})
response = self.csrf_client.post(self.path, {'example': 'example'})
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

def test_post_json_failing_token_auth(self):
"""Ensure POSTing json over token auth without correct credentials fails"""
response = self.csrf_client.post('/token/', {'example': 'example'}, format='json')
response = self.csrf_client.post(self.path, {'example': 'example'}, format='json')
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)


class TokenAuthTests(BaseTokenAuthTests, TestCase):
model = Token
path = '/token/'

def test_token_has_auto_assigned_key_if_none_provided(self):
"""Ensure creating a token with no key will auto-assign a key"""
self.token.delete()
token = Token.objects.create(user=self.user)
token = self.model.objects.create(user=self.user)
self.assertTrue(bool(token.key))

def test_generate_key_returns_string(self):
"""Ensure generate_key returns a string"""
token = Token()
token = self.model()
key = token.generate_key()
self.assertTrue(isinstance(key, six.string_types))

Expand Down Expand Up @@ -236,6 +260,11 @@ def test_token_login_form(self):
self.assertEqual(response.data['token'], self.key)


class CustomTokenAuthTests(BaseTokenAuthTests, TestCase):
model = CustomToken
path = '/customtoken/'


class IncorrectCredentialsTests(TestCase):
def test_incorrect_credentials(self):
"""
Expand Down

1 comment on commit 37f7b76

@miclovich
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:) looked for this the whole night! Awesome.

Please sign in to comment.