Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement RFC-7797 / JWS (Detached Payload) #166 #272

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 72 additions & 10 deletions jose/jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,68 @@ def sign(payload, key, headers=None, algorithm=ALGORITHMS.HS256):
return signed_output


def verify(token, key, algorithms, verify=True):
def sign_detached(payload, key, headers=None, algorithm=ALGORITHMS.HS256):
"""Signs a claims set and returns a JWS as a detached payload string, as per RFC7797

Args:
payload (str or dict): A string to sign
key (str or dict): The key to use for signing the claim set. Can be
individual JWK or JWK set.
headers (dict, optional): A set of headers that will be added to
the default headers. Any headers that are added as additional
headers will override the default headers.
if the signature needs to be generated on encoded payload, then
header has to contain {"b64":True}
algorithm (str, optional): The algorithm to use for signing the
the claims. Defaults to HS256.

Returns:
str: The string representation of the header, and signature in detached jws format
payload: the payload as received in the request or encoed if {"b4":True} header is passed in the call

Raises:
JWSError: If there is an error signing the token.

Examples:

>>> jws.sign_detached({'a': 'b'}, 'secret', algorithm='HS256')
'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8', {'a': 'b'}


>>> jws.sign_detached({'a': 'b'}, 'secret', {"b64": True}, algorithm='HS256')
'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8', eyJhIjoiYiJ9

"""

if algorithm not in ALGORITHMS.SUPPORTED:
raise JWSError("Algorithm %s not supported." % algorithm)

if headers:
if "b64" in headers and headers["b64"] is True:
payload = _encode_payload(payload)
headers.update({"crit": ["b64"]})
else:
headers = {"b64": "false"}

encoded_header = _encode_header(algorithm, additional_headers=headers)
signed_output = _sign_header_and_claims(encoded_header, payload, algorithm, key, True)

return signed_output, payload


def verify(token, key, algorithms=None, verify=True, payload=None):
"""Verifies a JWS string's signature.

Args:
token (str): A signed JWS to be verified.
key (str or dict): A key to attempt to verify the payload with. Can be
individual JWK or JWK set.
algorithms (str or list): Valid algorithms that should be used to verify the JWS.
payload (str or dict): Unencoded payload if the token is a detached jws

Returns:
str: The str representation of the payload, assuming the signature is valid.
If the token is a detached jws with "b64" true in the header, the return value will be encoded payload

Raises:
JWSError: If there is an exception verifying a token.
Expand All @@ -65,9 +116,12 @@ def verify(token, key, algorithms, verify=True):
>>> token = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhIjoiYiJ9.jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8'
>>> jws.verify(token, 'secret', algorithms='HS256')

>>> token = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8'
>>> jws.verify(token, 'secret', algorithms='HS256', payload={"a":"b"})

"""

header, payload, signing_input, signature = _load(token)
header, payload, signing_input, signature = _load(token, payload)

if verify:
_verify_signature(signing_input, header, signature, key, algorithms)
Expand Down Expand Up @@ -126,7 +180,7 @@ def get_unverified_claims(token):


def _encode_header(algorithm, additional_headers=None):
header = {"typ": "JWT", "alg": algorithm}
header = {"typ": "JOSE", "alg": algorithm}

if additional_headers:
header.update(additional_headers)
Expand All @@ -153,7 +207,7 @@ def _encode_payload(payload):
return base64url_encode(payload)


def _sign_header_and_claims(encoded_header, encoded_claims, algorithm, key):
def _sign_header_and_claims(encoded_header, encoded_claims, algorithm, key, is_detached=False):
signing_input = b".".join([encoded_header, encoded_claims])
try:
if not isinstance(key, Key):
Expand All @@ -164,12 +218,15 @@ def _sign_header_and_claims(encoded_header, encoded_claims, algorithm, key):

encoded_signature = base64url_encode(signature)

encoded_string = b".".join([encoded_header, encoded_claims, encoded_signature])
if is_detached:
encoded_string = b"..".join([encoded_header, encoded_signature])
else:
encoded_string = b".".join([encoded_header, encoded_claims, encoded_signature])

return encoded_string.decode("utf-8")


def _load(jwt):
def _load(jwt, payload=None):
if isinstance(jwt, str):
jwt = jwt.encode("utf-8")
try:
Expand All @@ -189,10 +246,15 @@ def _load(jwt):
if not isinstance(header, Mapping):
raise JWSError("Invalid header string: must be a json object")

try:
payload = base64url_decode(claims_segment)
except (TypeError, binascii.Error):
raise JWSError("Invalid payload padding")
if not payload:
try:
payload = base64url_decode(claims_segment)
except (TypeError, binascii.Error):
raise JWSError("Invalid payload padding")
else:
if "b64" in header and header["b64"] is True:
payload = _encode_payload(payload)
signing_input = b"".join([signing_input, payload])

try:
signature = base64url_decode(crypto_segment)
Expand Down
11 changes: 10 additions & 1 deletion tests/test_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from jose.backends import RSAKey
from jose.constants import ALGORITHMS
from jose.exceptions import JWSError
from jose.utils import base64url_decode, base64url_encode

try:
from jose.backends.cryptography_backend import CryptographyRSAKey
Expand Down Expand Up @@ -132,7 +133,7 @@ def test_add_headers(self, payload):
expected_headers = {
"test": "header",
"alg": "HS256",
"typ": "JWT",
"typ": "JOSE",
}

token = jws.sign(payload, "secret", headers=additional_headers)
Expand Down Expand Up @@ -307,6 +308,14 @@ def test_jwk_set_failure(self, jwk_set):
with pytest.raises(JWSError):
payload = jws.verify(google_id_token, jwk_set, ALGORITHMS.RS256) # noqa: F841

def test_RSA256_detached(self, payload):
token, payload = jws.sign_detached(payload, rsa_private_key, algorithm=ALGORITHMS.RS256)
assert jws.verify(token, rsa_public_key, payload=payload) == payload

def test_RSA256_detached_encoded(self, payload):
token, encoded_payload = jws.sign_detached(payload, rsa_private_key, {"b64": True}, algorithm=ALGORITHMS.RS256)
assert jws.verify(token, rsa_public_key, payload=payload) == encoded_payload

def test_RSA256(self, payload):
token = jws.sign(payload, rsa_private_key, algorithm=ALGORITHMS.RS256)
assert jws.verify(token, rsa_public_key, ALGORITHMS.RS256) == payload
Expand Down