From c0510fe6ae83eb023c13a56c2f8455dcc7b1e75e Mon Sep 17 00:00:00 2001 From: Situphen Date: Sun, 1 Oct 2023 17:01:16 +0200 Subject: [PATCH] =?UTF-8?q?Mise=20=C3=A0=20jour=20de=20django-oauth-toolki?= =?UTF-8?q?t?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | 2 +- zds/api/utils.py | 26 +++++++++++ zds/gallery/api/tests.py | 20 +++------ zds/member/api/tests.py | 82 +++++++++-------------------------- zds/mp/api/tests.py | 41 ++++++------------ zds/notification/api/tests.py | 5 +-- 6 files changed, 69 insertions(+), 107 deletions(-) create mode 100644 zds/api/utils.py diff --git a/requirements.txt b/requirements.txt index 594c8b6f3a..5f3a8f01b8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ requests==2.31.0 # Api dependencies django-cors-headers==4.3.1 django-filter==23.5 -django-oauth-toolkit==1.7.0 +django-oauth-toolkit==2.3.0 djangorestframework==3.14.0 drf-extensions==0.7.1 dry-rest-permissions==0.1.10 diff --git a/zds/api/utils.py b/zds/api/utils.py new file mode 100644 index 0000000000..728d8968d3 --- /dev/null +++ b/zds/api/utils.py @@ -0,0 +1,26 @@ +from oauth2_provider.models import Application, AccessToken + +CLEARTEXT_SECRET = "abcdefghijklmnopqrstuvwxyz1234567890" + + +def authenticate_oauth2_client(client, user, password): + oauth2_client = Application.objects.create( + user=user, + client_type=Application.CLIENT_CONFIDENTIAL, + authorization_grant_type=Application.GRANT_PASSWORD, + client_secret=CLEARTEXT_SECRET, + ) + oauth2_client.save() + + client.post( + "/oauth2/token/", + { + "client_id": oauth2_client.client_id, + "client_secret": CLEARTEXT_SECRET, + "username": user.username, + "password": password, + "grant_type": "password", + }, + ) + access_token = AccessToken.objects.get(user=user) + client.credentials(HTTP_AUTHORIZATION=f"Bearer {access_token}") diff --git a/zds/gallery/api/tests.py b/zds/gallery/api/tests.py index 3859d7f216..7083d03d79 100644 --- a/zds/gallery/api/tests.py +++ b/zds/gallery/api/tests.py @@ -9,10 +9,10 @@ from rest_framework.test import APITestCase, APIClient from rest_framework_extensions.settings import extensions_api_settings +from zds.api.utils import authenticate_oauth2_client from zds.gallery.tests.factories import UserGalleryFactory, GalleryFactory, ImageFactory from zds.gallery.models import Gallery, UserGallery, GALLERY_WRITE, Image, GALLERY_READ from zds.member.tests.factories import ProfileFactory -from zds.member.api.tests import create_oauth2_client, authenticate_client from zds.tutorialv2.tests.factories import PublishableContentFactory from zds.tutorialv2.tests import TutorialTestMixin, override_for_contents @@ -21,8 +21,7 @@ class GalleryListAPITest(APITestCase): def setUp(self): self.profile = ProfileFactory() self.client = APIClient() - client_oauth2 = create_oauth2_client(self.profile.user) - authenticate_client(self.client, client_oauth2, self.profile.user.username, "hostel77") + authenticate_oauth2_client(self.client, self.profile.user, "hostel77") caches[extensions_api_settings.DEFAULT_USE_CACHE].clear() @@ -92,8 +91,7 @@ def setUp(self): self.profile = ProfileFactory() self.other = ProfileFactory() self.client = APIClient() - client_oauth2 = create_oauth2_client(self.profile.user) - authenticate_client(self.client, client_oauth2, self.profile.user.username, "hostel77") + authenticate_oauth2_client(self.client, self.profile.user, "hostel77") self.gallery = GalleryFactory() @@ -222,8 +220,7 @@ def setUp(self): self.profile = ProfileFactory() self.other = ProfileFactory() self.client = APIClient() - client_oauth2 = create_oauth2_client(self.profile.user) - authenticate_client(self.client, client_oauth2, self.profile.user.username, "hostel77") + authenticate_oauth2_client(self.client, self.profile.user, "hostel77") self.gallery = GalleryFactory() UserGalleryFactory(user=self.profile.user, gallery=self.gallery) @@ -358,8 +355,7 @@ def setUp(self): self.profile = ProfileFactory() self.other = ProfileFactory() self.client = APIClient() - client_oauth2 = create_oauth2_client(self.profile.user) - authenticate_client(self.client, client_oauth2, self.profile.user.username, "hostel77") + authenticate_oauth2_client(self.client, self.profile.user, "hostel77") self.gallery = GalleryFactory() UserGalleryFactory(user=self.profile.user, gallery=self.gallery) @@ -506,8 +502,7 @@ def setUp(self): self.other = ProfileFactory() self.client = APIClient() self.new_participant = ProfileFactory() - client_oauth2 = create_oauth2_client(self.profile.user) - authenticate_client(self.client, client_oauth2, self.profile.user.username, "hostel77") + authenticate_oauth2_client(self.client, self.profile.user, "hostel77") self.gallery = GalleryFactory() UserGalleryFactory(user=self.profile.user, gallery=self.gallery) @@ -620,8 +615,7 @@ def setUp(self): self.other = ProfileFactory() self.new_participant = ProfileFactory() self.client = APIClient() - client_oauth2 = create_oauth2_client(self.profile.user) - authenticate_client(self.client, client_oauth2, self.profile.user.username, "hostel77") + authenticate_oauth2_client(self.client, self.profile.user, "hostel77") self.gallery = GalleryFactory() UserGalleryFactory(user=self.profile.user, gallery=self.gallery) diff --git a/zds/member/api/tests.py b/zds/member/api/tests.py index 425b08c537..e86f100626 100644 --- a/zds/member/api/tests.py +++ b/zds/member/api/tests.py @@ -2,12 +2,12 @@ from django.contrib.auth.models import User, Group from django.core import mail from django.urls import reverse -from oauth2_provider.models import Application, AccessToken from rest_framework import status from rest_framework.test import APITestCase from rest_framework.test import APIClient from zds.api.pagination import REST_PAGE_SIZE, REST_MAX_PAGE_SIZE, REST_PAGE_SIZE_QUERY_PARAM +from zds.api.utils import authenticate_oauth2_client from zds.member.tests.factories import ProfileFactory, StaffProfileFactory, ProfileNotSyncFactory from zds.member.models import TokenRegister, BannedEmailProvider from rest_framework_extensions.settings import extensions_api_settings @@ -366,9 +366,8 @@ def test_detail_of_the_member(self): Gets all information about the user. """ profile = ProfileFactory() - client_oauth2 = create_oauth2_client(profile.user) client_authenticated = APIClient() - authenticate_client(client_authenticated, client_oauth2, profile.user.username, "hostel77") + authenticate_oauth2_client(client_authenticated, profile.user, "hostel77") response = client_authenticated.get(reverse("api:member:profile")) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -403,9 +402,8 @@ def setUp(self): self.client = APIClient() self.profile = ProfileFactory() - client_oauth2 = create_oauth2_client(self.profile.user) self.client_authenticated = APIClient() - authenticate_client(self.client_authenticated, client_oauth2, self.profile.user.username, "hostel77") + authenticate_oauth2_client(self.client_authenticated, self.profile.user, "hostel77") caches[extensions_api_settings.DEFAULT_USE_CACHE].clear() @@ -497,9 +495,8 @@ def test_update_member_details_with_user_not_synchronized(self): """ decal = ProfileNotSyncFactory() - client_oauth2 = create_oauth2_client(decal.user) client_authenticated = APIClient() - authenticate_client(client_authenticated, client_oauth2, decal.user.username, "hostel77") + authenticate_oauth2_client(client_authenticated, decal.user, "hostel77") response = client_authenticated.put(reverse("api:member:detail", args=[decal.user.id])) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -681,9 +678,8 @@ def setUp(self): self.profile = ProfileFactory() self.staff = StaffProfileFactory() - client_oauth2 = create_oauth2_client(self.staff.user) self.client_authenticated = APIClient() - authenticate_client(self.client_authenticated, client_oauth2, self.staff.user.username, "hostel77") + authenticate_oauth2_client(self.client_authenticated, self.staff.user, "hostel77") caches[extensions_api_settings.DEFAULT_USE_CACHE].clear() @@ -749,9 +745,8 @@ def test_apply_read_only_at_a_member_without_permissions(self): """ Tries to apply a read only sanction at a member with a user isn't authenticated. """ - client_oauth2 = create_oauth2_client(self.profile.user) client_authenticated = APIClient() - authenticate_client(client_authenticated, client_oauth2, self.profile.user.username, "hostel77") + authenticate_oauth2_client(client_authenticated, self.profile.user, "hostel77") response = client_authenticated.post(reverse("api:member:read-only", args=[self.profile.user.id])) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @@ -817,9 +812,8 @@ def test_remove_read_only_at_a_member_without_permissions(self): """ Tries to remove a read only sanction at a member with a user isn't authenticated. """ - client_oauth2 = create_oauth2_client(self.profile.user) client_authenticated = APIClient() - authenticate_client(client_authenticated, client_oauth2, self.profile.user.username, "hostel77") + authenticate_oauth2_client(client_authenticated, self.profile.user, "hostel77") response = client_authenticated.delete(reverse("api:member:read-only", args=[self.profile.user.id])) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @@ -840,9 +834,8 @@ def setUp(self): self.profile = ProfileFactory() self.staff = StaffProfileFactory() - client_oauth2 = create_oauth2_client(self.staff.user) self.client_authenticated = APIClient() - authenticate_client(self.client_authenticated, client_oauth2, self.staff.user.username, "hostel77") + authenticate_oauth2_client(self.client_authenticated, self.staff.user, "hostel77") caches[extensions_api_settings.DEFAULT_USE_CACHE].clear() @@ -908,9 +901,8 @@ def test_apply_ban_at_a_member_without_permissions(self): """ Tries to apply a ban sanction at a member with a user isn't authenticated. """ - client_oauth2 = create_oauth2_client(self.profile.user) client_authenticated = APIClient() - authenticate_client(client_authenticated, client_oauth2, self.profile.user.username, "hostel77") + authenticate_oauth2_client(client_authenticated, self.profile.user, "hostel77") response = client_authenticated.post(reverse("api:member:ban", args=[self.profile.user.id])) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @@ -976,9 +968,8 @@ def test_remove_ban_at_a_member_without_permissions(self): """ Tries to remove a ban sanction at a member with a user isn't authenticated. """ - client_oauth2 = create_oauth2_client(self.profile.user) client_authenticated = APIClient() - authenticate_client(client_authenticated, client_oauth2, self.profile.user.username, "hostel77") + authenticate_oauth2_client(client_authenticated, self.profile.user, "hostel77") response = client_authenticated.delete(reverse("api:member:ban", args=[self.profile.user.id])) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @@ -1011,9 +1002,7 @@ def test_has_read_permission_for_authenticated_users(self): """ Authenticated users have the permission to read any member. """ - authenticate_client( - self.client, create_oauth2_client(self.profile.user), self.profile.user.username, "hostel77" - ) + authenticate_oauth2_client(self.client, self.profile.user, "hostel77") response = self.client.get(reverse("api:member:detail", args=[self.profile.user.id])) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -1023,7 +1012,7 @@ def test_has_read_permission_for_staff_users(self): """ Staff users have the permission to read any member. """ - authenticate_client(self.client, create_oauth2_client(self.staff.user), self.staff.user.username, "hostel77") + authenticate_oauth2_client(self.client, self.staff.user, "hostel77") response = self.client.get(reverse("api:member:detail", args=[self.profile.user.id])) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -1041,9 +1030,7 @@ def test_has_write_permission_for_authenticated_user(self): """ A user authenticated have write permissions. """ - authenticate_client( - self.client, create_oauth2_client(self.profile.user), self.profile.user.username, "hostel77" - ) + authenticate_oauth2_client(self.client, self.profile.user, "hostel77") response = self.client.get(reverse("api:member:detail", args=[self.profile.user.id])) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -1053,7 +1040,7 @@ def test_has_write_permission_for_staff(self): """ A staff user have write permissions. """ - authenticate_client(self.client, create_oauth2_client(self.staff.user), self.staff.user.username, "hostel77") + authenticate_oauth2_client(self.client, self.staff.user, "hostel77") response = self.client.get(reverse("api:member:detail", args=[self.profile.user.id])) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -1071,9 +1058,7 @@ def test_has_update_permission_for_authenticated_user(self): """ Only the user authenticated have update permissions. """ - authenticate_client( - self.client, create_oauth2_client(self.profile.user), self.profile.user.username, "hostel77" - ) + authenticate_oauth2_client(self.client, self.profile.user, "hostel77") response = self.client.get(reverse("api:member:detail", args=[self.profile.user.id])) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -1083,7 +1068,7 @@ def test_has_not_update_permission_for_staff(self): """ Only the user authenticated have update permissions. """ - authenticate_client(self.client, create_oauth2_client(self.staff.user), self.staff.user.username, "hostel77") + authenticate_oauth2_client(self.client, self.staff.user, "hostel77") response = self.client.get(reverse("api:member:detail", args=[self.profile.user.id])) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -1101,9 +1086,7 @@ def test_has_not_ban_permission_for_authenticated_user(self): """ Only staff have ban permission. """ - authenticate_client( - self.client, create_oauth2_client(self.profile.user), self.profile.user.username, "hostel77" - ) + authenticate_oauth2_client(self.client, self.profile.user, "hostel77") response = self.client.get(reverse("api:member:detail", args=[self.profile.user.id])) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -1113,7 +1096,7 @@ def test_has_ban_permission_for_staff(self): """ Only staff have ban permission. """ - authenticate_client(self.client, create_oauth2_client(self.staff.user), self.staff.user.username, "hostel77") + authenticate_oauth2_client(self.client, self.staff.user, "hostel77") response = self.client.get(reverse("api:member:detail", args=[self.profile.user.id])) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -1132,14 +1115,12 @@ def test_cache_of_user_authenticated_for_member_profile(self): profile = ProfileFactory() another_profile = ProfileFactory() - authenticate_client(self.client, create_oauth2_client(profile.user), profile.user.username, "hostel77") + authenticate_oauth2_client(self.client, profile.user, "hostel77") response = self.client.get(reverse("api:member:profile")) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(profile.user.username, response.data.get("username")) - authenticate_client( - self.client, create_oauth2_client(another_profile.user), another_profile.user.username, "hostel77" - ) + authenticate_oauth2_client(self.client, another_profile.user, "hostel77") response = self.client.get(reverse("api:member:profile")) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(another_profile.user.username, response.data.get("username")) @@ -1167,26 +1148,3 @@ def test_cache_invalidated_when_new_member(self): response = self.client.get(reverse("api:member:list")) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data.get("count"), count + 1) - - -def create_oauth2_client(user): - client = Application.objects.create( - user=user, client_type=Application.CLIENT_CONFIDENTIAL, authorization_grant_type=Application.GRANT_PASSWORD - ) - client.save() - return client - - -def authenticate_client(client, client_auth, username, password): - client.post( - "/oauth2/token/", - { - "client_id": client_auth.client_id, - "client_secret": client_auth.client_secret, - "username": username, - "password": password, - "grant_type": "password", - }, - ) - access_token = AccessToken.objects.get(user__username=username) - client.credentials(HTTP_AUTHORIZATION=f"Bearer {access_token}") diff --git a/zds/mp/api/tests.py b/zds/mp/api/tests.py index 159e36cb97..ad6503dcb5 100644 --- a/zds/mp/api/tests.py +++ b/zds/mp/api/tests.py @@ -10,7 +10,7 @@ from rest_framework_extensions.settings import extensions_api_settings from zds.api.pagination import REST_PAGE_SIZE, REST_MAX_PAGE_SIZE, REST_PAGE_SIZE_QUERY_PARAM -from zds.member.api.tests import create_oauth2_client, authenticate_client +from zds.api.utils import authenticate_oauth2_client from zds.member.tests.factories import ProfileFactory, UserFactory from zds.mp.tests.factories import PrivateTopicFactory, PrivatePostFactory from zds.mp.models import PrivateTopic, PrivatePostVote @@ -20,8 +20,7 @@ class PrivateTopicListAPITest(APITestCase): def setUp(self): self.profile = ProfileFactory() self.client = APIClient() - client_oauth2 = create_oauth2_client(self.profile.user) - authenticate_client(self.client, client_oauth2, self.profile.user.username, "hostel77") + authenticate_oauth2_client(self.client, self.profile.user, "hostel77") self.bot_group = Group() self.bot_group.name = settings.ZDS_APP["member"]["bot_group"] @@ -395,8 +394,7 @@ def setUp(self): author=self.profile.user, privatetopic=self.private_topic, position_in_topic=1 ) self.client = APIClient() - client_oauth2 = create_oauth2_client(self.profile.user) - authenticate_client(self.client, client_oauth2, self.profile.user.username, "hostel77") + authenticate_oauth2_client(self.client, self.profile.user, "hostel77") self.bot_group = Group() self.bot_group.name = settings.ZDS_APP["member"]["bot_group"] @@ -531,8 +529,7 @@ def test_update_private_topic_with_user_not_author(self): self.private_topic.participants.add(another_profile.user) self.client = APIClient() - client_oauth2 = create_oauth2_client(another_profile.user) - authenticate_client(self.client, client_oauth2, another_profile.user.username, "hostel77") + authenticate_oauth2_client(self.client, another_profile.user, "hostel77") data = { "title": "Good title", @@ -550,8 +547,7 @@ def test_add_participant_with_an_user_not_author_of_private_topic(self): self.private_topic.participants.add(another_profile.user) self.client = APIClient() - client_oauth2 = create_oauth2_client(another_profile.user) - authenticate_client(self.client, client_oauth2, another_profile.user.username, "hostel77") + authenticate_oauth2_client(self.client, another_profile.user, "hostel77") data = {"participants": [third_profile.user.id]} response = self.client.put(reverse("api:mp:detail", args=[self.private_topic.id]), data) @@ -628,8 +624,7 @@ class PrivatePostListAPI(APITestCase): def setUp(self): self.profile = ProfileFactory() self.client = APIClient() - client_oauth2 = create_oauth2_client(self.profile.user) - authenticate_client(self.client, client_oauth2, self.profile.user.username, "hostel77") + authenticate_oauth2_client(self.client, self.profile.user, "hostel77") self.private_topic = PrivateTopicFactory(author=self.profile.user) self.private_topic.participants.add(ProfileFactory().user) @@ -869,8 +864,7 @@ def setUp(self): author=self.profile.user, privatetopic=self.private_topic, position_in_topic=1 ) self.client = APIClient() - client_oauth2 = create_oauth2_client(self.profile.user) - authenticate_client(self.client, client_oauth2, self.profile.user.username, "hostel77") + authenticate_oauth2_client(self.client, self.profile.user, "hostel77") caches[extensions_api_settings.DEFAULT_USE_CACHE].clear() @@ -1013,13 +1007,11 @@ class PrivateTopicUnreadListAPITest(APITestCase): def setUp(self): self.profile = ProfileFactory() self.client = APIClient() - client_oauth2 = create_oauth2_client(self.profile.user) - authenticate_client(self.client, client_oauth2, self.profile.user.username, "hostel77") + authenticate_oauth2_client(self.client, self.profile.user, "hostel77") self.another_profile = ProfileFactory() self.another_client = APIClient() - another_client_oauth2 = create_oauth2_client(self.another_profile.user) - authenticate_client(self.another_client, another_client_oauth2, self.another_profile.user.username, "hostel77") + authenticate_oauth2_client(self.another_client, self.another_profile.user, "hostel77") self.bot_group = Group() self.bot_group.name = settings.ZDS_APP["member"]["bot_group"] @@ -1071,9 +1063,7 @@ def setUp(self): author=self.profile.user, privatetopic=self.private_topic, position_in_topic=1 ) - authenticate_client( - self.client, create_oauth2_client(self.profile.user), self.profile.user.username, "hostel77" - ) + authenticate_oauth2_client(self.client, self.profile.user, "hostel77") def test_has_read_permission_for_authenticated_users(self): """ @@ -1106,9 +1096,7 @@ def test_has_not_update_permission_for_authenticated_users_and_but_not_author(se another_profile = ProfileFactory() self.private_topic.participants.add(another_profile.user) - authenticate_client( - self.client, create_oauth2_client(another_profile.user), another_profile.user.username, "hostel77" - ) + authenticate_oauth2_client(self.client, another_profile.user, "hostel77") response = self.client.get(reverse("api:mp:detail", args=[self.private_topic.id])) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertFalse(response.data.get("permissions").get("update")) @@ -1144,9 +1132,7 @@ def test_has_not_update_permission_for_authenticated_users_and_but_not_author_fo another_profile = ProfileFactory() self.private_topic.participants.add(another_profile.user) - authenticate_client( - self.client, create_oauth2_client(another_profile.user), another_profile.user.username, "hostel77" - ) + authenticate_oauth2_client(self.client, another_profile.user, "hostel77") response = self.client.get(reverse("api:mp:message-detail", args=[self.private_topic.id, self.private_post.id])) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertFalse(response.data.get("permissions").get("update")) @@ -1160,8 +1146,7 @@ def setUp(self): author=self.profile.user, privatetopic=self.private_topic, position_in_topic=1 ) self.client = APIClient() - client_oauth2 = create_oauth2_client(self.profile.user) - authenticate_client(self.client, client_oauth2, self.profile.user.username, "hostel77") + authenticate_oauth2_client(self.client, self.profile.user, "hostel77") caches[extensions_api_settings.DEFAULT_USE_CACHE].clear() diff --git a/zds/notification/api/tests.py b/zds/notification/api/tests.py index 1e155c6f69..0bc3f129ac 100644 --- a/zds/notification/api/tests.py +++ b/zds/notification/api/tests.py @@ -5,7 +5,7 @@ from rest_framework.test import APITestCase from rest_framework_extensions.settings import extensions_api_settings -from zds.member.api.tests import create_oauth2_client, authenticate_client +from zds.api.utils import authenticate_oauth2_client from zds.member.tests.factories import ProfileFactory from zds.mp.tests.factories import PrivateTopicFactory from zds.notification.models import Notification @@ -16,8 +16,7 @@ class NotificationListAPITest(APITestCase): def setUp(self): self.profile = ProfileFactory() self.client = APIClient() - client_oauth2 = create_oauth2_client(self.profile.user) - authenticate_client(self.client, client_oauth2, self.profile.user.username, "hostel77") + authenticate_oauth2_client(self.client, self.profile.user, "hostel77") caches[extensions_api_settings.DEFAULT_USE_CACHE].clear() def test_list_of_notifications_empty(self):