Skip to content

Commit

Permalink
Merge pull request #40 from alexykot/feature/asymmetric-support
Browse files Browse the repository at this point in the history
Asymmetric crypto support
  • Loading branch information
vimalloc authored May 5, 2017
2 parents fb3e932 + 2216875 commit ee0f562
Show file tree
Hide file tree
Showing 9 changed files with 177 additions and 12 deletions.
6 changes: 4 additions & 2 deletions docs/options.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ General Options:
``JWT_REFRESH_TOKEN_EXPIRES`` How long a refresh token should live before it expires. This
takes a ``datetime.timedelta``, and defaults to 30 days
``JWT_ALGORITHM`` Which algorithm to sign the JWT with. `See here <https://pyjwt.readthedocs.io/en/latest/algorithms.html>`_
for the options. Defaults to ``'HS256'``. Note that Asymmetric
(Public-key) algorithms are not currently supported.
for the options. Defaults to ``'HS256'``.
``JWT_PUBLIC_KEY`` The public key needed for RSA and ECDSA based signing algorithms.
Has to be provided if any of ``RS*`` or ``ES*`` algorithms is used.
PEM format expected.
================================= =========================================


Expand Down
26 changes: 26 additions & 0 deletions flask_jwt_extended/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import simplekv
from flask import current_app
from jwt.algorithms import requires_cryptography


class _Config(object):
Expand All @@ -15,6 +16,18 @@ class _Config(object):
object. All of these values are read only.
"""

@property
def is_asymmetric(self):
return self.algorithm in requires_cryptography

@property
def encode_key(self):
return self.secret_key

@property
def decode_key(self):
return self.public_key if self.is_asymmetric else self.secret_key

@property
def token_location(self):
locations = current_app.config['JWT_TOKEN_LOCATION']
Expand Down Expand Up @@ -172,6 +185,17 @@ def secret_key(self):
raise RuntimeError('flask SECRET_KEY must be set')
return key

@property
def public_key(self):
key = None
if self.algorithm in requires_cryptography:
key = current_app.config.get('JWT_PUBLIC_KEY', None)
if not key:
raise RuntimeError('JWT_PUBLIC_KEY must be set to use '
'asymmetric cryptography algorith '
'"{crypto_algorithm}"'.format(crypto_algorithm=self.algorithm))
return key

@property
def cookie_max_age(self):
# Returns the appropiate value for max_age for flask set_cookies. If
Expand All @@ -180,3 +204,5 @@ def cookie_max_age(self):
return None if self.session_cookie else 2147483647 # 2^31

config = _Config()


14 changes: 8 additions & 6 deletions flask_jwt_extended/jwt_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,12 @@ def _set_default_configuration_options(app):
app.config.setdefault('JWT_REFRESH_TOKEN_EXPIRES', datetime.timedelta(days=30))

# What algorithm to use to sign the token. See here for a list of options:
# https://github.com/jpadilla/pyjwt/blob/master/jwt/api_jwt.py (note
# that public private key is not yet supported in this extension)
# https://github.com/jpadilla/pyjwt/blob/master/jwt/api_jwt.py
app.config.setdefault('JWT_ALGORITHM', 'HS256')

# must be set if using asymmetric cryptography algorithm (RS* or EC*)
app.config.setdefault('JWT_PUBLIC_KEY', None)

# Options for blacklisting/revoking tokens
app.config.setdefault('JWT_BLACKLIST_ENABLED', False)
app.config.setdefault('JWT_BLACKLIST_STORE', None)
Expand Down Expand Up @@ -251,15 +253,15 @@ def create_refresh_token(self, identity):
"""
refresh_token = encode_refresh_token(
identity=self._user_identity_callback(identity),
secret=config.secret_key,
secret=config.encode_key,
algorithm=config.algorithm,
expires_delta=config.refresh_expires,
csrf=config.csrf_protect
)

# If blacklisting is enabled, store this token in our key-value store
if config.blacklist_enabled:
decoded_token = decode_jwt(refresh_token, config.secret_key,
decoded_token = decode_jwt(refresh_token, config.decode_key,
config.algorithm, csrf=config.csrf_protect)
store_token(decoded_token, revoked=False)
return refresh_token
Expand All @@ -282,15 +284,15 @@ def create_access_token(self, identity, fresh=False):
"""
access_token = encode_access_token(
identity=self._user_identity_callback(identity),
secret=config.secret_key,
secret=config.encode_key,
algorithm=config.algorithm,
expires_delta=config.access_expires,
fresh=fresh,
user_claims=self._user_claims_callback(identity),
csrf=config.csrf_protect
)
if config.blacklist_enabled and config.blacklist_access_tokens:
decoded_token = decode_jwt(access_token, config.secret_key,
decoded_token = decode_jwt(access_token, config.decode_key,
config.algorithm, csrf=config.csrf_protect)
store_token(decoded_token, revoked=False)
return access_token
2 changes: 1 addition & 1 deletion flask_jwt_extended/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def create_refresh_token(*args, **kwargs):


def get_csrf_token(encoded_token):
token = decode_jwt(encoded_token, config.secret_key, config.algorithm, csrf=True)
token = decode_jwt(encoded_token, config.decode_key, config.algorithm, csrf=True)
return token['csrf']


Expand Down
4 changes: 2 additions & 2 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _decode_jwt_from_headers():
raise InvalidHeaderError(msg)
token = parts[1]

return decode_jwt(token, config.secret_key, config.algorithm, csrf=False)
return decode_jwt(token, config.decode_key, config.algorithm, csrf=False)


def _decode_jwt_from_cookies(request_type):
Expand All @@ -115,7 +115,7 @@ def _decode_jwt_from_cookies(request_type):

decoded_token = decode_jwt(
encoded_token=encoded_token,
secret=config.secret_key,
secret=config.decode_key,
algorithm=config.algorithm,
csrf=config.csrf_protect
)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ alabaster==0.7.9
Babel==2.3.4
click==6.6
coverage==4.2
cryptography==1.8.1
docutils==0.12
Flask==0.11.1
imagesize==0.7.1
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
packages=['flask_jwt_extended'],
zip_safe=False,
platforms='any',
install_requires=['Flask', 'PyJWT', 'simplekv'],
install_requires=['Flask', 'PyJWT', 'simplekv', 'cryptography'],
classifiers=[
'Development Status :: 4 - Beta',
'Environment :: Web Environment',
Expand Down
24 changes: 24 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,15 @@ def test_default_configs(self):
self.assertEqual(config.access_expires, timedelta(minutes=15))
self.assertEqual(config.refresh_expires, timedelta(days=30))
self.assertEqual(config.algorithm, 'HS256')
self.assertEqual(config.is_asymmetric, False)
self.assertEqual(config.blacklist_enabled, False)
self.assertEqual(config.blacklist_checks, 'refresh')
self.assertEqual(config.blacklist_access_tokens, False)

self.assertEqual(config.secret_key, self.app.secret_key)
self.assertEqual(config.public_key, None)
self.assertEqual(config.encode_key, self.app.secret_key)
self.assertEqual(config.decode_key, self.app.secret_key)
self.assertEqual(config.cookie_max_age, None)

with self.assertRaises(RuntimeError):
Expand Down Expand Up @@ -166,6 +170,15 @@ def test_invalid_config_options(self):
with self.assertRaises(RuntimeError):
config.secret_key

self.app.secret_key = None
with self.assertRaises(RuntimeError):
config.encode_key

self.app.config['JWT_ALGORITHM'] = 'RS256'
self.app.config['JWT_PUBLIC_KEY'] = None
with self.assertRaises(RuntimeError):
config.decode_key

def test_depreciated_options(self):
self.app.config['JWT_CSRF_HEADER_NAME'] = 'Auth'

Expand Down Expand Up @@ -205,3 +218,14 @@ def test_special_config_options(self):
self.app.config['JWT_TOKEN_LOCATION'] = ['cookies']
self.app.config['JWT_COOKIE_CSRF_PROTECT'] = False
self.assertEqual(config.csrf_protect, False)

def test_asymmetric_encryption_key_handling(self):
self.app.secret_key = 'MOCK_RSA_PRIVATE_KEY'
self.app.config['JWT_PUBLIC_KEY'] = 'MOCK_RSA_PUBLIC_KEY'
self.app.config['JWT_ALGORITHM'] = 'RS256'

with self.app.test_request_context():
self.assertEqual(config.is_asymmetric, True)
self.assertEqual(config.secret_key, 'MOCK_RSA_PRIVATE_KEY')
self.assertEqual(config.encode_key, 'MOCK_RSA_PRIVATE_KEY')
self.assertEqual(config.decode_key, 'MOCK_RSA_PUBLIC_KEY')
110 changes: 110 additions & 0 deletions tests/test_protected_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,3 +814,113 @@ def test_accessing_endpoint_without_jwt(self):
data = json.loads(response.get_data(as_text=True))
self.assertEqual(status_code, 401)
self.assertIn('msg', data)


# random 1024bit RSA keypair
RSA_PRIVATE = """
-----BEGIN RSA PRIVATE KEY-----
MIICXgIBAAKBgQDN+p9a9oMyqRzkae8yLdJcEK0O0WesH6JiMz+KDrpUwAoAM/KP
DnxFnROJDSBHyHEmPVn5x8GqV5lQ9+6l97jdEEcPo6wkshycM82fgcxOmvtAy4Uo
xq/AeplYqplhcUTGVuo4ZldOLmN8ksGmzhWpsOdT0bkYipHCn5sWZxd21QIDAQAB
AoGBAMJ0++KVXXEDZMpjFDWsOq898xNNMHG3/8ZzmWXN161RC1/7qt/RjhLuYtX9
NV9vZRrzyrDcHAKj5pMhLgUzpColKzvdG2vKCldUs2b0c8HEGmjsmpmgoI1Tdf9D
G1QK+q9pKHlbj/MLr4vZPX6xEwAFeqRKlzL30JPD+O6mOXs1AkEA8UDzfadH1Y+H
bcNN2COvCqzqJMwLNRMXHDmUsjHfR2gtzk6D5dDyEaL+O4FLiQCaNXGWWoDTy/HJ
Clh1Z0+KYwJBANqRtJ+RvdgHMq0Yd45MMyy0ODGr1B3PoRbUK8EdXpyUNMi1g3iJ
tXMbLywNkTfcEXZTlbbkVYwrEl6P2N1r42cCQQDb9UQLBEFSTRJE2RRYQ/CL4yt3
cTGmqkkfyr/v19ii2jEpMBzBo8eQnPL+fdvIhWwT3gQfb+WqxD9v10bzcmnRAkEA
mzTgeHd7wg3KdJRtQYTmyhXn2Y3VAJ5SG+3qbCW466NqoCQVCeFwEh75rmSr/Giv
lcDhDZCzFuf3EWNAcmuMfQJARsWfM6q7v2p6vkYLLJ7+VvIwookkr6wymF5Zgb9d
E6oTM2EeUPSyyrj5IdsU2JCNBH1m3JnUflz8p8/NYCoOZg==
-----END RSA PRIVATE KEY-----
"""
RSA_PUBLIC = """
-----BEGIN RSA PUBLIC KEY-----
MIGJAoGBAM36n1r2gzKpHORp7zIt0lwQrQ7RZ6wfomIzP4oOulTACgAz8o8OfEWd
E4kNIEfIcSY9WfnHwapXmVD37qX3uN0QRw+jrCSyHJwzzZ+BzE6a+0DLhSjGr8B6
mViqmWFxRMZW6jhmV04uY3ySwabOFamw51PRuRiKkcKfmxZnF3bVAgMBAAE=
-----END RSA PUBLIC KEY-----
"""

class TestEndpointsWithAssymmetricCrypto(unittest.TestCase):

def setUp(self):
self.app = Flask(__name__)
self.app.secret_key = RSA_PRIVATE
self.app.config['JWT_PUBLIC_KEY'] = RSA_PUBLIC
self.app.config['JWT_ALGORITHM'] = 'RS256'
self.app.config['JWT_ACCESS_TOKEN_EXPIRES'] = timedelta(seconds=1)
self.app.config['JWT_REFRESH_TOKEN_EXPIRES'] = timedelta(seconds=1)
self.jwt_manager = JWTManager(self.app)
self.client = self.app.test_client()

@self.app.route('/auth/login', methods=['POST'])
def login():
ret = {
'access_token': create_access_token('test', fresh=True),
'refresh_token': create_refresh_token('test')
}
return jsonify(ret), 200

@self.app.route('/auth/refresh', methods=['POST'])
@jwt_refresh_token_required
def refresh():
username = get_jwt_identity()
ret = {'access_token': create_access_token(username, fresh=False)}
return jsonify(ret), 200

@self.app.route('/auth/fresh-login', methods=['POST'])
def fresh_login():
ret = {'access_token': create_access_token('test', fresh=True)}
return jsonify(ret), 200

@self.app.route('/protected')
@jwt_required
def protected():
return jsonify({'msg': "hello world"})

@self.app.route('/fresh-protected')
@fresh_jwt_required
def fresh_protected():
return jsonify({'msg': "fresh hello world"})

def _jwt_post(self, url, jwt):
response = self.client.post(url, content_type='application/json',
headers={'Authorization': 'Bearer {}'.format(jwt)})
status_code = response.status_code
data = json.loads(response.get_data(as_text=True))
return status_code, data

def _jwt_get(self, url, jwt, header_name='Authorization', header_type='Bearer'):
header_type = '{} {}'.format(header_type, jwt).strip()
response = self.client.get(url, headers={header_name: header_type})
status_code = response.status_code
data = json.loads(response.get_data(as_text=True))
return status_code, data

def test_login(self):
response = self.client.post('/auth/login')
status_code = response.status_code
data = json.loads(response.get_data(as_text=True))
self.assertEqual(status_code, 200)
self.assertIn('access_token', data)
self.assertIn('refresh_token', data)

def test_fresh_login(self):
response = self.client.post('/auth/fresh-login')
status_code = response.status_code
data = json.loads(response.get_data(as_text=True))
self.assertEqual(status_code, 200)
self.assertIn('access_token', data)
self.assertNotIn('refresh_token', data)

def test_refresh(self):
response = self.client.post('/auth/login')
data = json.loads(response.get_data(as_text=True))
access_token = data['access_token']
refresh_token = data['refresh_token']

status_code, data = self._jwt_post('/auth/refresh', refresh_token)
self.assertEqual(status_code, 200)
self.assertIn('access_token', data)
self.assertNotIn('refresh_token', data)

0 comments on commit ee0f562

Please sign in to comment.