Skip to content

Commit

Permalink
Support for Flask 2 async views
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Jun 24, 2021
1 parent f6c6f96 commit dc6de2d
Show file tree
Hide file tree
Showing 6 changed files with 547 additions and 13 deletions.
34 changes: 21 additions & 13 deletions src/flask_httpauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from functools import wraps
from hashlib import md5
from random import Random, SystemRandom
from flask import request, make_response, session, g, Response
from flask import request, make_response, session, g, Response, current_app
from werkzeug.datastructures import Authorization


Expand Down Expand Up @@ -57,7 +57,7 @@ def get_user_roles(self, f):
def error_handler(self, f):
@wraps(f)
def decorated(*args, **kwargs):
res = f(*args, **kwargs)
res = self.ensure_sync(f)(*args, **kwargs)
check_status_code = not isinstance(res, (tuple, Response))
res = make_response(res)
if check_status_code and res.status_code == 200:
Expand Down Expand Up @@ -105,7 +105,8 @@ def get_auth_password(self, auth):
password = None

if auth and auth.username:
password = self.get_password_callback(auth.username)
password = self.ensure_sync(self.get_password_callback)(
auth.username)

return password

Expand All @@ -120,7 +121,7 @@ def authorize(self, role, user, auth):
user = auth
if self.get_user_roles_callback is None: # pragma: no cover
raise ValueError('get_user_roles callback is not defined')
user_roles = self.get_user_roles_callback(user)
user_roles = self.ensure_sync(self.get_user_roles_callback)(user)
if user_roles is None:
user_roles = {}
elif not isinstance(user_roles, (list, tuple)):
Expand Down Expand Up @@ -170,7 +171,7 @@ def decorated(*args, **kwargs):

g.flask_httpauth_user = user if user is not True \
else auth.username if auth else None
return f(*args, **kwargs)
return self.ensure_sync(f)(*args, **kwargs)
return decorated

if f:
Expand All @@ -187,6 +188,12 @@ def current_user(self):
if hasattr(g, 'flask_httpauth_user'):
return g.flask_httpauth_user

def ensure_sync(self, f):
try:
return current_app.ensure_sync(f)
except AttributeError: # pragma: no cover
return f


class HTTPBasicAuth(HTTPAuth):
def __init__(self, scheme=None, realm=None):
Expand Down Expand Up @@ -232,15 +239,17 @@ def authenticate(self, auth, stored_password):
username = ""
client_password = ""
if self.verify_password_callback:
return self.verify_password_callback(username, client_password)
return self.ensure_sync(self.verify_password_callback)(
username, client_password)
if not auth:
return
if self.hash_password_callback:
try:
client_password = self.hash_password_callback(client_password)
client_password = self.ensure_sync(
self.hash_password_callback)(client_password)
except TypeError:
client_password = self.hash_password_callback(username,
client_password)
client_password = self.ensure_sync(
self.hash_password_callback)(username, client_password)
return auth.username if client_password is not None and \
stored_password is not None and \
hmac.compare_digest(client_password, stored_password) else None
Expand Down Expand Up @@ -360,7 +369,7 @@ def authenticate(self, auth, stored_password):
else:
token = ""
if self.verify_token_callback:
return self.verify_token_callback(token)
return self.ensure_sync(self.verify_token_callback)(token)


class MultiAuth(object):
Expand All @@ -383,9 +392,8 @@ def decorated(*args, **kwargs):
if auth.is_compatible_auth(request.headers):
selected_auth = auth
break
return selected_auth.login_required(role=role,
optional=optional
)(f)(*args, **kwargs)
return selected_auth.login_required(
role=role, optional=optional)(f)(*args, **kwargs)
return decorated

if f:
Expand Down
90 changes: 90 additions & 0 deletions tests/test_basic_verify_password_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import sys
import unittest
import base64
from flask import Flask, g
from flask_httpauth import HTTPBasicAuth
import pytest


@pytest.mark.skipif(sys.version_info < (3, 7), reason='requires python3.7')
class HTTPAuthTestCase(unittest.TestCase):
use_old_style_callback = False

def setUp(self):
app = Flask(__name__)
app.config['SECRET_KEY'] = 'my secret'

basic_verify_auth = HTTPBasicAuth()

@basic_verify_auth.verify_password
async def basic_verify_auth_verify_password(username, password):
if self.use_old_style_callback:
g.anon = False
if username == 'john':
return password == 'hello'
elif username == 'susan':
return password == 'bye'
elif username == '':
g.anon = True
return True
return False
else:
g.anon = False
if username == 'john' and password == 'hello':
return 'john'
elif username == 'susan' and password == 'bye':
return 'susan'
elif username == '':
g.anon = True
return ''

@basic_verify_auth.error_handler
async def error_handler():
self.assertIsNone(basic_verify_auth.current_user())
return 'error', 403 # use a custom error status

@app.route('/')
async def index():
return 'index'

@app.route('/basic-verify')
@basic_verify_auth.login_required
async def basic_verify_auth_route():
if self.use_old_style_callback:
return 'basic_verify_auth:' + basic_verify_auth.username() + \
' anon:' + str(g.anon)
else:
return 'basic_verify_auth:' + \
basic_verify_auth.current_user() + ' anon:' + str(g.anon)

self.app = app
self.basic_verify_auth = basic_verify_auth
self.client = app.test_client()

def test_verify_auth_login_valid(self):
creds = base64.b64encode(b'susan:bye').decode('utf-8')
response = self.client.get(
'/basic-verify', headers={'Authorization': 'Basic ' + creds})
self.assertEqual(response.data, b'basic_verify_auth:susan anon:False')

def test_verify_auth_login_empty(self):
response = self.client.get('/basic-verify')
self.assertEqual(response.data, b'basic_verify_auth: anon:True')

def test_verify_auth_login_invalid(self):
creds = base64.b64encode(b'john:bye').decode('utf-8')
response = self.client.get(
'/basic-verify', headers={'Authorization': 'Basic ' + creds})
self.assertEqual(response.status_code, 403)
self.assertTrue('WWW-Authenticate' in response.headers)

def test_verify_auth_login_malformed_password(self):
creds = 'eyJhbGciOieyJp=='
response = self.client.get('/basic-verify',
headers={'Authorization': 'Basic ' + creds})
self.assertEqual(response.status_code, 403)
self.assertTrue('WWW-Authenticate' in response.headers)


class HTTPAuthTestCaseOldStyle(HTTPAuthTestCase):
use_old_style_callback = True
151 changes: 151 additions & 0 deletions tests/test_multi_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import base64
import sys
import unittest
from flask import Flask
from flask_httpauth import HTTPBasicAuth, HTTPTokenAuth, MultiAuth
import pytest


@pytest.mark.skipif(sys.version_info < (3, 7), reason='requires python3.7')
class HTTPAuthTestCase(unittest.TestCase):
def setUp(self):
app = Flask(__name__)
app.config['SECRET_KEY'] = 'my secret'

basic_auth = HTTPBasicAuth()
token_auth = HTTPTokenAuth('MyToken')
custom_token_auth = HTTPTokenAuth(header='X-Token')
multi_auth = MultiAuth(basic_auth, token_auth, custom_token_auth)

@basic_auth.verify_password
async def verify_password(username, password):
if username == 'john' and password == 'hello':
return 'john'

@basic_auth.get_user_roles
async def get_basic_role(username):
if username == 'john':
return ['foo', 'bar']

@token_auth.verify_token
async def verify_token(token):
return token == 'this-is-the-token!'

@token_auth.get_user_roles
async def get_token_role(auth):
if auth['token'] == 'this-is-the-token!':
return 'foo'
return

@token_auth.error_handler
async def error_handler():
return 'error', 401, {'WWW-Authenticate': 'MyToken realm="Foo"'}

@custom_token_auth.verify_token
async def verify_custom_token(token):
return token == 'this-is-the-custom-token!'

@custom_token_auth.get_user_roles
async def get_custom_token_role(auth):
if auth['token'] == 'this-is-the-custom-token!':
return 'foo'
return

@app.route('/')
async def index():
return 'index'

@app.route('/protected')
@multi_auth.login_required
async def auth_route():
return 'access granted:' + str(multi_auth.current_user())

@app.route('/protected-with-role')
@multi_auth.login_required(role='foo')
async def auth_role_route():
return 'role access granted'

self.app = app
self.client = app.test_client()

def test_multi_auth_prompt(self):
response = self.client.get('/protected')
self.assertEqual(response.status_code, 401)
self.assertTrue('WWW-Authenticate' in response.headers)
self.assertEqual(response.headers['WWW-Authenticate'],
'Basic realm="Authentication Required"')

def test_multi_auth_login_valid_basic(self):
creds = base64.b64encode(b'john:hello').decode('utf-8')
response = self.client.get(
'/protected', headers={'Authorization': 'Basic ' + creds})
self.assertEqual(response.data.decode('utf-8'), 'access granted:john')

def test_multi_auth_login_invalid_basic(self):
creds = base64.b64encode(b'john:bye').decode('utf-8')
response = self.client.get(
'/protected', headers={'Authorization': 'Basic ' + creds})
self.assertEqual(response.status_code, 401)
self.assertTrue('WWW-Authenticate' in response.headers)
self.assertEqual(response.headers['WWW-Authenticate'],
'Basic realm="Authentication Required"')

def test_multi_auth_login_valid_token(self):
response = self.client.get(
'/protected', headers={'Authorization':
'MyToken this-is-the-token!'})
self.assertEqual(response.data.decode('utf-8'), 'access granted:None')

def test_multi_auth_login_invalid_token(self):
response = self.client.get(
'/protected', headers={'Authorization':
'MyToken this-is-not-the-token!'})
self.assertEqual(response.status_code, 401)
self.assertTrue('WWW-Authenticate' in response.headers)
self.assertEqual(response.headers['WWW-Authenticate'],
'MyToken realm="Foo"')

def test_multi_auth_login_valid_custom_token(self):
response = self.client.get(
'/protected', headers={'X-Token': 'this-is-the-custom-token!'})
self.assertEqual(response.data.decode('utf-8'), 'access granted:None')

def test_multi_auth_login_invalid_custom_token(self):
response = self.client.get(
'/protected', headers={'X-Token': 'this-is-not-the-token!'})
self.assertEqual(response.status_code, 401)
self.assertTrue('WWW-Authenticate' in response.headers)
self.assertEqual(response.headers['WWW-Authenticate'],
'Bearer realm="Authentication Required"')

def test_multi_auth_login_invalid_scheme(self):
response = self.client.get(
'/protected', headers={'Authorization': 'Foo this-is-the-token!'})
self.assertEqual(response.status_code, 401)
self.assertTrue('WWW-Authenticate' in response.headers)
self.assertEqual(response.headers['WWW-Authenticate'],
'Basic realm="Authentication Required"')

def test_multi_malformed_header(self):
response = self.client.get(
'/protected', headers={'Authorization': 'token-without-scheme'})
self.assertEqual(response.status_code, 401)

def test_multi_auth_login_valid_basic_role(self):
creds = base64.b64encode(b'john:hello').decode('utf-8')
response = self.client.get(
'/protected-with-role', headers={'Authorization':
'Basic ' + creds})
self.assertEqual(response.data.decode('utf-8'), 'role access granted')

def test_multi_auth_login_valid_token_role(self):
response = self.client.get(
'/protected-with-role', headers={'Authorization':
'MyToken this-is-the-token!'})
self.assertEqual(response.data.decode('utf-8'), 'role access granted')

def test_multi_auth_login_valid_custom_token_role(self):
response = self.client.get(
'/protected-with-role', headers={'X-Token':
'this-is-the-custom-token!'})
self.assertEqual(response.data.decode('utf-8'), 'role access granted')
Loading

0 comments on commit dc6de2d

Please sign in to comment.