Skip to content

Commit

Permalink
Allow customizing token JSON encoding (#568)
Browse files Browse the repository at this point in the history
* Allow specifying custom JSONEncoder for TokenBackend

* Make TokenBackend JSONEncoder configurable
  • Loading branch information
vainu-arto authored May 4, 2022
1 parent 70b8f84 commit 58b1874
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 2 deletions.
12 changes: 10 additions & 2 deletions rest_framework_simplejwt/backends.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from datetime import timedelta
from typing import Union
from typing import Optional, Type, Union

import jwt
from django.utils.translation import gettext_lazy as _
Expand Down Expand Up @@ -38,6 +39,7 @@ def __init__(
issuer=None,
jwk_url: str = None,
leeway: Union[float, int, timedelta] = None,
json_encoder: Optional[Type[json.JSONEncoder]] = None,
):
self._validate_algorithm(algorithm)

Expand All @@ -53,6 +55,7 @@ def __init__(
self.jwks_client = None

self.leeway = leeway
self.json_encoder = json_encoder

def _validate_algorithm(self, algorithm):
"""
Expand Down Expand Up @@ -107,7 +110,12 @@ def encode(self, payload):
if self.issuer is not None:
jwt_payload["iss"] = self.issuer

token = jwt.encode(jwt_payload, self.signing_key, algorithm=self.algorithm)
token = jwt.encode(
jwt_payload,
self.signing_key,
algorithm=self.algorithm,
json_encoder=self.json_encoder,
)
if isinstance(token, bytes):
# For PyJWT <= 1.7.1
return token.decode("utf-8")
Expand Down
2 changes: 2 additions & 0 deletions rest_framework_simplejwt/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"VERIFYING_KEY": "",
"AUDIENCE": None,
"ISSUER": None,
"JSON_ENCODER": None,
"JWK_URL": None,
"LEEWAY": 0,
"AUTH_HEADER_TYPES": ("Bearer",),
Expand All @@ -44,6 +45,7 @@

IMPORT_STRINGS = (
"AUTH_TOKEN_CLASSES",
"JSON_ENCODER",
"TOKEN_USER_CLASS",
"USER_AUTHENTICATION_RULE",
)
Expand Down
1 change: 1 addition & 0 deletions rest_framework_simplejwt/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@
api_settings.ISSUER,
api_settings.JWK_URL,
api_settings.LEEWAY,
api_settings.JSON_ENCODER,
)
17 changes: 17 additions & 0 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import uuid
from datetime import datetime, timedelta
from json import JSONEncoder
from unittest import mock
from unittest.mock import patch

Expand Down Expand Up @@ -34,6 +36,13 @@
IS_OLD_JWT = jwt_version == "1.7.1"


class UUIDJSONEncoder(JSONEncoder):
def default(self, obj):
if isinstance(obj, uuid.UUID):
return str(obj)
return super().default(obj)


class TestTokenBackend(TestCase):
def setUp(self):
self.hmac_token_backend = TokenBackend("HS256", SECRET)
Expand Down Expand Up @@ -329,3 +338,11 @@ def test_decode_leeway_hmac_success(self):
self.hmac_leeway_token_backend.decode(expired_token),
self.payload,
)

def test_custom_JSONEncoder(self):
backend = TokenBackend("HS256", SECRET, json_encoder=UUIDJSONEncoder)
unique = uuid.uuid4()
self.payload["uuid"] = unique
token = backend.encode(self.payload)
decoded = backend.decode(token)
self.assertEqual(decoded["uuid"], str(unique))

0 comments on commit 58b1874

Please sign in to comment.