diff --git a/src/keycloak/keycloak_openid.py b/src/keycloak/keycloak_openid.py index 1a2954f..d2c3b3d 100644 --- a/src/keycloak/keycloak_openid.py +++ b/src/keycloak/keycloak_openid.py @@ -28,7 +28,7 @@ class to handle authentication and token manipulation. """ import json -from typing import Optional +from typing import Optional, Union from jwcrypto import jwk, jwt @@ -581,6 +581,33 @@ def introspect(self, token, rpt=None, token_type_hint=None): ) return raise_error_from_response(data_raw, KeycloakPostError) + @staticmethod + def _verify_token(token, key: Union[jwk.JWK, jwk.JWKSet, None], **kwargs): + """Decode and optionally validate a token. + + :param token: The token to verify + :type token: str + :param key: Which key should be used for validation. + If not provided, the validation is not performed and the token is implicitly valid. + :type key: Union[jwk.JWK, jwk.JWKSet, None] + :param kwargs: Additional keyword arguments for jwcrypto's JWT object + :type kwargs: dict + :returns: Decoded token + """ + # keep the function free of IO + # this way it can be used by `decode_token` and `a_decode_token` + + if key is not None: + leeway = kwargs.pop("leeway", 60) + full_jwt = jwt.JWT(jwt=token, **kwargs) + full_jwt.leeway = leeway + full_jwt.validate(key) + return jwt.json_decode(full_jwt.claims) + else: + full_jwt = jwt.JWT(jwt=token, **kwargs) + full_jwt.token.objects["valid"] = True + return json.loads(full_jwt.token.payload.decode("utf-8")) + def decode_token(self, token, validate: bool = True, **kwargs): """Decode user token. @@ -603,26 +630,19 @@ def decode_token(self, token, validate: bool = True, **kwargs): :returns: Decoded token :rtype: dict """ + key = kwargs.pop("key", None) if validate: - if "key" not in kwargs: + if key is None: key = ( "-----BEGIN PUBLIC KEY-----\n" + self.public_key() + "\n-----END PUBLIC KEY-----" ) key = jwk.JWK.from_pem(key.encode("utf-8")) - kwargs["key"] = key - - key = kwargs.pop("key") - leeway = kwargs.pop("leeway", 60) - full_jwt = jwt.JWT(jwt=token, **kwargs) - full_jwt.leeway = leeway - full_jwt.validate(key) - return jwt.json_decode(full_jwt.claims) else: - full_jwt = jwt.JWT(jwt=token, **kwargs) - full_jwt.token.objects["valid"] = True - return json.loads(full_jwt.token.payload.decode("utf-8")) + key = None + + return self._verify_token(token, key, **kwargs) def load_authorization_config(self, path): """Load Keycloak settings (authorization). @@ -1273,22 +1293,19 @@ async def a_decode_token(self, token, validate: bool = True, **kwargs): :returns: Decoded token :rtype: dict """ + key = kwargs.pop("key", None) if validate: - if "key" not in kwargs: + if key is None: key = ( "-----BEGIN PUBLIC KEY-----\n" + await self.a_public_key() + "\n-----END PUBLIC KEY-----" ) key = jwk.JWK.from_pem(key.encode("utf-8")) - kwargs["key"] = key - - full_jwt = jwt.JWT(jwt=token, **kwargs) - return jwt.json_decode(full_jwt.claims) else: - full_jwt = jwt.JWT(jwt=token, **kwargs) - full_jwt.token.objects["valid"] = True - return json.loads(full_jwt.token.payload.decode("utf-8")) + key = None + + return self._verify_token(token, key, **kwargs) async def a_load_authorization_config(self, path): """Load Keycloak settings (authorization) asynchronously. diff --git a/tests/test_keycloak_openid.py b/tests/test_keycloak_openid.py index f26c9b8..d28ba97 100644 --- a/tests/test_keycloak_openid.py +++ b/tests/test_keycloak_openid.py @@ -4,6 +4,8 @@ from typing import Tuple from unittest import mock +import jwcrypto.jwk +import jwcrypto.jws import pytest from keycloak import KeycloakAdmin, KeycloakOpenID @@ -317,6 +319,39 @@ def test_decode_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]): assert decoded_refresh_token["typ"] == "Refresh", decoded_refresh_token +def test_decode_token_invalid_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]): + """Test decode token with an invalid token. + + :param oid_with_credentials: Keycloak OpenID client with pre-configured user credentials + :type oid_with_credentials: Tuple[KeycloakOpenID, str, str] + """ + oid, username, password = oid_with_credentials + token = oid.token(username=username, password=password) + access_token = token["access_token"] + decoded_access_token = oid.decode_token(token=access_token) + + key = oid.public_key() + key = "-----BEGIN PUBLIC KEY-----\n" + key + "\n-----END PUBLIC KEY-----" + key = jwcrypto.jwk.JWK.from_pem(key.encode("utf-8")) + + invalid_access_token = access_token + "a" + with pytest.raises(jwcrypto.jws.InvalidJWSSignature): + decoded_invalid_access_token = oid.decode_token(token=invalid_access_token, validate=True) + + with pytest.raises(jwcrypto.jws.InvalidJWSSignature): + decoded_invalid_access_token = oid.decode_token( + token=invalid_access_token, validate=True, key=key + ) + + decoded_invalid_access_token = oid.decode_token(token=invalid_access_token, validate=False) + assert decoded_access_token == decoded_invalid_access_token + + decoded_invalid_access_token = oid.decode_token( + token=invalid_access_token, validate=False, key=key + ) + assert decoded_access_token == decoded_invalid_access_token + + def test_load_authorization_config(oid_with_credentials_authz: Tuple[KeycloakOpenID, str, str]): """Test load authorization config. @@ -765,7 +800,7 @@ async def test_a_introspect(oid_with_credentials: Tuple[KeycloakOpenID, str, str @pytest.mark.asyncio async def test_a_decode_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]): - """Test decode token. + """Test decode token asynchronously. :param oid_with_credentials: Keycloak OpenID client with pre-configured user credentials :type oid_with_credentials: Tuple[KeycloakOpenID, str, str] @@ -781,6 +816,44 @@ async def test_a_decode_token(oid_with_credentials: Tuple[KeycloakOpenID, str, s assert decoded_refresh_token["typ"] == "Refresh", decoded_refresh_token +@pytest.mark.asyncio +async def test_a_decode_token_invalid_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]): + """Test decode token asynchronously an invalid token. + + :param oid_with_credentials: Keycloak OpenID client with pre-configured user credentials + :type oid_with_credentials: Tuple[KeycloakOpenID, str, str] + """ + oid, username, password = oid_with_credentials + token = await oid.a_token(username=username, password=password) + access_token = token["access_token"] + decoded_access_token = await oid.a_decode_token(token=access_token) + + key = await oid.a_public_key() + key = "-----BEGIN PUBLIC KEY-----\n" + key + "\n-----END PUBLIC KEY-----" + key = jwcrypto.jwk.JWK.from_pem(key.encode("utf-8")) + + invalid_access_token = access_token + "a" + with pytest.raises(jwcrypto.jws.InvalidJWSSignature): + decoded_invalid_access_token = await oid.a_decode_token( + token=invalid_access_token, validate=True + ) + + with pytest.raises(jwcrypto.jws.InvalidJWSSignature): + decoded_invalid_access_token = await oid.a_decode_token( + token=invalid_access_token, validate=True, key=key + ) + + decoded_invalid_access_token = await oid.a_decode_token( + token=invalid_access_token, validate=False + ) + assert decoded_access_token == decoded_invalid_access_token + + decoded_invalid_access_token = await oid.a_decode_token( + token=invalid_access_token, validate=False, key=key + ) + assert decoded_access_token == decoded_invalid_access_token + + @pytest.mark.asyncio async def test_a_load_authorization_config( oid_with_credentials_authz: Tuple[KeycloakOpenID, str, str]