Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

don't import authtoken model until needed #3785

Merged
merged 4 commits into from
Jan 5, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions rest_framework/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from django.utils.translation import ugettext_lazy as _

from rest_framework import HTTP_HEADER_ENCODING, exceptions
from rest_framework.authtoken.models import Token


def get_authorization_header(request):
Expand Down Expand Up @@ -149,7 +148,14 @@ class TokenAuthentication(BaseAuthentication):
Authorization: Token 401f7ac837da42b97f613d789819ff93537bee6a
"""

model = Token
model = None

def get_model(self):
if self.model is not None:
return self.model
from rest_framework.authtoken.models import Token
return Token

"""
A custom token model may be used, but must have the following properties.

Expand Down Expand Up @@ -179,9 +185,10 @@ def authenticate(self, request):
return self.authenticate_credentials(token)

def authenticate_credentials(self, key):
model = self.get_model()
try:
token = self.model.objects.select_related('user').get(key=key)
except self.model.DoesNotExist:
token = model.objects.select_related('user').get(key=key)
except model.DoesNotExist:
raise exceptions.AuthenticationFailed(_('Invalid token.'))

if not token.user.is_active:
Expand Down
8 changes: 0 additions & 8 deletions rest_framework/authtoken/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,6 @@ class Token(models.Model):
user = models.OneToOneField(AUTH_USER_MODEL, related_name='auth_token')
created = models.DateTimeField(auto_now_add=True)

class Meta:
# Work around for a bug in Django:
# https://code.djangoproject.com/ticket/19422
#
# Also see corresponding ticket:
# https://github.com/tomchristie/django-rest-framework/issues/705
abstract = 'rest_framework.authtoken' not in settings.INSTALLED_APPS

def save(self, *args, **kwargs):
if not self.key:
self.key = self.generate_key()
Expand Down
49 changes: 39 additions & 10 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from django.conf.urls import include, url
from django.contrib.auth.models import User
from django.db import models
from django.http import HttpResponse
from django.test import TestCase
from django.utils import six
Expand All @@ -25,6 +26,15 @@
factory = APIRequestFactory()


class CustomToken(models.Model):
key = models.CharField(max_length=40, primary_key=True)
user = models.OneToOneField(User)


class CustomTokenAuthentication(TokenAuthentication):
model = CustomToken


class MockView(APIView):
permission_classes = (permissions.IsAuthenticated,)

Expand All @@ -42,6 +52,7 @@ def put(self, request):
url(r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
url(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
url(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
url(r'^customtoken/$', MockView.as_view(authentication_classes=[CustomTokenAuthentication])),
url(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')),
]
Expand Down Expand Up @@ -142,9 +153,11 @@ def test_post_form_session_auth_failing(self):
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)


class TokenAuthTests(TestCase):
class BaseTokenAuthTests(object):
"""Token authentication"""
urls = 'tests.test_authentication'
model = None
path = None

def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True)
Expand All @@ -154,54 +167,65 @@ def setUp(self):
self.user = User.objects.create_user(self.username, self.email, self.password)

self.key = 'abcd1234'
self.token = Token.objects.create(key=self.key, user=self.user)
self.token = self.model.objects.create(key=self.key, user=self.user)

def test_post_form_passing_token_auth(self):
"""Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
auth = 'Token ' + self.key
response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK)

def test_fail_post_form_passing_nonexistent_token_auth(self):
# use a nonexistent token key
auth = 'Token wxyz6789'
response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

def test_fail_post_form_passing_invalid_token_auth(self):
# add an 'invalid' unicode character
auth = 'Token ' + self.key + "¸"
response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

def test_post_json_passing_token_auth(self):
"""Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key
response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
response = self.csrf_client.post(self.path, {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK)

def test_post_json_makes_one_db_query(self):
"""Ensure that authenticating a user using a token performs only one DB query"""
auth = "Token " + self.key

def func_to_test():
return self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
return self.csrf_client.post(self.path, {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)

self.assertNumQueries(1, func_to_test)

def test_post_form_failing_token_auth(self):
"""Ensure POSTing form over token auth without correct credentials fails"""
response = self.csrf_client.post('/token/', {'example': 'example'})
response = self.csrf_client.post(self.path, {'example': 'example'})
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

def test_post_json_failing_token_auth(self):
"""Ensure POSTing json over token auth without correct credentials fails"""
response = self.csrf_client.post('/token/', {'example': 'example'}, format='json')
response = self.csrf_client.post(self.path, {'example': 'example'}, format='json')
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)


class TokenAuthTests(BaseTokenAuthTests, TestCase):
model = Token
path = '/token/'

def test_token_has_auto_assigned_key_if_none_provided(self):
"""Ensure creating a token with no key will auto-assign a key"""
self.token.delete()
token = Token.objects.create(user=self.user)
token = self.model.objects.create(user=self.user)
self.assertTrue(bool(token.key))

def test_generate_key_returns_string(self):
"""Ensure generate_key returns a string"""
token = Token()
token = self.model()
key = token.generate_key()
self.assertTrue(isinstance(key, six.string_types))

Expand Down Expand Up @@ -236,6 +260,11 @@ def test_token_login_form(self):
self.assertEqual(response.data['token'], self.key)


class CustomTokenAuthTests(BaseTokenAuthTests, TestCase):
model = CustomToken
path = '/customtoken/'


class IncorrectCredentialsTests(TestCase):
def test_incorrect_credentials(self):
"""
Expand Down