Skip to content

Commit

Permalink
CSRF protection (#798)
Browse files Browse the repository at this point in the history
Closes #793.

* Rename RequestParameters to MultiParams, refs #799
* Allow tuples as well as lists in MultiParams, refs #799
* Use csrftokens when running tests, refs #799
* Use new csrftoken() function, refs simonw/asgi-csrf#7
* Check for Vary: Cookie hedaer, refs simonw/asgi-csrf#8
  • Loading branch information
simonw authored Jun 5, 2020
1 parent d96ac1d commit 84a9c4f
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 19 deletions.
10 changes: 9 additions & 1 deletion datasette/app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import asgi_csrf
import collections
import datetime
import hashlib
Expand Down Expand Up @@ -884,7 +885,14 @@ async def setup_db():
await database.table_counts(limit=60 * 60 * 1000)

asgi = AsgiLifespan(
AsgiTracer(DatasetteRouter(self, routes)), on_startup=setup_db
AsgiTracer(
asgi_csrf.asgi_csrf(
DatasetteRouter(self, routes),
signing_secret=self._secret,
cookie_name="ds_csrftoken",
)
),
on_startup=setup_db,
)
for wrapper in pm.hook.asgi_wrapper(datasette=self):
asgi = wrapper(asgi)
Expand Down
3 changes: 2 additions & 1 deletion datasette/templates/messages_debug.html
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ <h1>Debug messages</h1>

<p>Set a message:</p>

<form action="/-/messages" method="POST">
<form action="/-/messages" method="post">
<div>
<input type="text" name="message" style="width: 40%">
<div class="select-wrapper">
Expand All @@ -19,6 +19,7 @@ <h1>Debug messages</h1>
<option>all</option>
</select>
</div>
<input type="hidden" name="csrftoken" value="{{ csrftoken() }}">
<input type="submit" value="Add message">
</div>
</form>
Expand Down
1 change: 1 addition & 0 deletions datasette/templates/query.html
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ <h3>Query parameters</h3>
{% endif %}
<p>
<button id="sql-format" type="button" hidden>Format SQL</button>
{% if canned_query %}<input type="hidden" name="csrftoken" value="{{ csrftoken() }}">{% endif %}
<input type="submit" value="Run SQL">
</p>
</form>
Expand Down
3 changes: 3 additions & 0 deletions datasette/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,9 @@ def __init__(self, data):
new_data.setdefault(key, []).append(value)
self._data = new_data

def __repr__(self):
return "<MultiParams: {}>".format(self._data)

def __contains__(self, key):
return key in self._data

Expand Down
1 change: 1 addition & 0 deletions datasette/views/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ async def render(self, templates, request, context=None):
**context,
**{
"database_url": self.database_url,
"csrftoken": request.scope["csrftoken"],
"database_color": self.database_color,
"show_messages": lambda: self.ds._show_messages(request),
"select_templates": [
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def get_version():
"uvicorn~=0.11",
"aiofiles>=0.4,<0.6",
"janus>=0.4,<0.6",
"asgi-csrf>=0.4",
"PyYAML~=5.3",
"mergedeep>=1.1.1,<1.4.0",
"itsdangerous~=1.1",
Expand Down
38 changes: 33 additions & 5 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datasette.app import Datasette
from datasette.utils import sqlite3
from datasette.utils import sqlite3, MultiParams
from asgiref.testing import ApplicationCommunicator
from asgiref.sync import async_to_sync
from http.cookies import SimpleCookie
Expand Down Expand Up @@ -60,10 +60,35 @@ async def get(

@async_to_sync
async def post(
self, path, post_data=None, allow_redirects=True, redirect_count=0, cookies=None
self,
path,
post_data=None,
allow_redirects=True,
redirect_count=0,
content_type="application/x-www-form-urlencoded",
cookies=None,
csrftoken_from=None,
):
cookies = cookies or {}
post_data = post_data or {}
# Maybe fetch a csrftoken first
if csrftoken_from is not None:
if csrftoken_from is True:
csrftoken_from = path
token_response = await self._request(csrftoken_from)
# Check this had a Vary: Cookie header
assert "Cookie" == token_response.headers["vary"]
csrftoken = token_response.cookies["ds_csrftoken"]
cookies["ds_csrftoken"] = csrftoken
post_data["csrftoken"] = csrftoken
return await self._request(
path, allow_redirects, redirect_count, "POST", cookies, post_data
path,
allow_redirects,
redirect_count,
"POST",
cookies,
post_data,
content_type,
)

async def _request(
Expand All @@ -74,6 +99,7 @@ async def _request(
method="GET",
cookies=None,
post_data=None,
content_type=None,
):
query_string = b""
if "?" in path:
Expand All @@ -84,6 +110,8 @@ async def _request(
else:
raw_path = quote(path, safe="/:,").encode("latin-1")
headers = [[b"host", b"localhost"]]
if content_type:
headers.append((b"content-type", content_type.encode("utf-8")))
if cookies:
sc = SimpleCookie()
for key, value in cookies.items():
Expand Down Expand Up @@ -111,7 +139,7 @@ async def _request(
start = await instance.receive_output(2)
messages.append(start)
assert start["type"] == "http.response.start"
headers = dict(
response_headers = MultiParams(
[(k.decode("utf8"), v.decode("utf8")) for k, v in start["headers"]]
)
status = start["status"]
Expand All @@ -124,7 +152,7 @@ async def _request(
body += message["body"]
if not message.get("more_body"):
break
response = TestResponse(status, headers, body)
response = TestResponse(status, response_headers, body)
if allow_redirects and response.status in (301, 302):
assert (
redirect_count < self.max_redirects
Expand Down
8 changes: 5 additions & 3 deletions tests/test_canned_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def canned_write_client():

def test_insert(canned_write_client):
response = canned_write_client.post(
"/data/add_name", {"name": "Hello"}, allow_redirects=False
"/data/add_name", {"name": "Hello"}, allow_redirects=False, csrftoken_from=True,
)
assert 302 == response.status
assert "/data/add_name?success" == response.headers["Location"]
Expand All @@ -52,7 +52,7 @@ def test_insert(canned_write_client):

def test_custom_success_message(canned_write_client):
response = canned_write_client.post(
"/data/delete_name", {"rowid": 1}, allow_redirects=False
"/data/delete_name", {"rowid": 1}, allow_redirects=False, csrftoken_from=True
)
assert 302 == response.status
messages = canned_write_client.ds.unsign(
Expand All @@ -62,11 +62,12 @@ def test_custom_success_message(canned_write_client):


def test_insert_error(canned_write_client):
canned_write_client.post("/data/add_name", {"name": "Hello"})
canned_write_client.post("/data/add_name", {"name": "Hello"}, csrftoken_from=True)
response = canned_write_client.post(
"/data/add_name_specify_id",
{"rowid": 1, "name": "Should fail"},
allow_redirects=False,
csrftoken_from=True,
)
assert 302 == response.status
assert "/data/add_name_specify_id?error" == response.headers["Location"]
Expand All @@ -82,6 +83,7 @@ def test_insert_error(canned_write_client):
"/data/add_name_specify_id",
{"rowid": 1, "name": "Should fail"},
allow_redirects=False,
csrftoken_from=True,
)
assert [["ERROR", 3]] == canned_write_client.ds.unsign(
response.cookies["ds_messages"], "messages"
Expand Down
21 changes: 12 additions & 9 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,15 +439,18 @@ def foo(a, b):
utils.call_with_supported_arguments(foo, a=1)


@pytest.mark.parametrize("data,should_raise", [
([["foo", "bar"], ["foo", "baz"]], False),
([("foo", "bar"), ("foo", "baz")], False),
((["foo", "bar"], ["foo", "baz"]), False),
([["foo", "bar"], ["foo", "baz", "bax"]], True),
({"foo": ["bar", "baz"]}, False),
({"foo": ("bar", "baz")}, False),
({"foo": "bar"}, True),
])
@pytest.mark.parametrize(
"data,should_raise",
[
([["foo", "bar"], ["foo", "baz"]], False),
([("foo", "bar"), ("foo", "baz")], False),
((["foo", "bar"], ["foo", "baz"]), False),
([["foo", "bar"], ["foo", "baz", "bax"]], True),
({"foo": ["bar", "baz"]}, False),
({"foo": ("bar", "baz")}, False),
({"foo": "bar"}, True),
]
)
def test_multi_params(data, should_raise):
if should_raise:
with pytest.raises(AssertionError):
Expand Down

0 comments on commit 84a9c4f

Please sign in to comment.