Skip to content

Commit

Permalink
Refactor _static_request_handler (#2533)
Browse files Browse the repository at this point in the history
Co-authored-by: Adam Hopkins <[email protected]>
  • Loading branch information
ChihweiLHBird and ahopkins authored Sep 20, 2022
1 parent 1650331 commit 43ba381
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 33 deletions.
61 changes: 33 additions & 28 deletions sanic/mixins/routes.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from ast import NodeVisitor, Return, parse
from contextlib import suppress
from email.utils import formatdate
from functools import partial, wraps
from inspect import getsource, signature
from mimetypes import guess_type
from os import path
from pathlib import Path, PurePath
from textwrap import dedent
from time import gmtime, strftime
from typing import (
Any,
Callable,
Expand All @@ -31,7 +31,7 @@
from sanic.log import error_logger
from sanic.models.futures import FutureRoute, FutureStatic
from sanic.models.handler_types import RouteHandler
from sanic.response import HTTPResponse, file, file_stream
from sanic.response import HTTPResponse, file, file_stream, validate_file
from sanic.types import HashableDict


Expand Down Expand Up @@ -790,24 +790,9 @@ def _generate_name(self, *objects) -> str:

return name

async def _static_request_handler(
self,
file_or_directory,
use_modified_since,
use_content_range,
stream_large_files,
request,
content_type=None,
__file_uri__=None,
):
# Merge served directory and requested file if provided
async def _get_file_path(self, file_or_directory, __file_uri__, not_found):
file_path_raw = Path(unquote(file_or_directory))
root_path = file_path = file_path_raw.resolve()
not_found = FileNotFound(
"File not found",
path=file_or_directory,
relative_url=__file_uri__,
)

if __file_uri__:
# Strip all / that in the beginning of the URL to help prevent
Expand All @@ -834,22 +819,43 @@ async def _static_request_handler(
f"relative_url={__file_uri__}"
)
raise not_found
return file_path

async def _static_request_handler(
self,
file_or_directory,
use_modified_since,
use_content_range,
stream_large_files,
request,
content_type=None,
__file_uri__=None,
):
not_found = FileNotFound(
"File not found",
path=file_or_directory,
relative_url=__file_uri__,
)

# Merge served directory and requested file if provided
file_path = await self._get_file_path(
file_or_directory, __file_uri__, not_found
)

try:
headers = {}
# Check if the client has been sent this file before
# and it has not been modified since
stats = None
if use_modified_since:
stats = await stat_async(file_path)
modified_since = strftime(
"%a, %d %b %Y %H:%M:%S GMT", gmtime(stats.st_mtime)
modified_since = stats.st_mtime
response = await validate_file(request.headers, modified_since)
if response:
return response
headers["Last-Modified"] = formatdate(
modified_since, usegmt=True
)
if (
request.headers.getone("if-modified-since", None)
== modified_since
):
return HTTPResponse(status=304)
headers["Last-Modified"] = modified_since
_range = None
if use_content_range:
_range = None
Expand All @@ -864,8 +870,7 @@ async def _static_request_handler(
pass
else:
del headers["Content-Length"]
for key, value in _range.headers.items():
headers[key] = value
headers.update(_range.headers)

if "content-type" not in headers:
content_type = (
Expand Down
33 changes: 28 additions & 5 deletions tests/test_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import time

from collections import namedtuple
from datetime import datetime
from email.utils import formatdate
from datetime import datetime, timedelta
from email.utils import formatdate, parsedate_to_datetime
from logging import ERROR, LogRecord
from mimetypes import guess_type
from pathlib import Path
Expand Down Expand Up @@ -665,13 +665,11 @@ async def handler6(request: Request):

with caplog.at_level(ERROR):
_, response = app.test_client.get("/4")
print(response.json)
assert response.status == 200
assert "foo" not in response.text
assert "one" in response.headers
assert response.headers["one"] == "one"

print(response.headers)
assert message_in_records(caplog.records, error_msg2)

with caplog.at_level(ERROR):
Expand Down Expand Up @@ -841,10 +839,10 @@ def file_route_cache(request: Request):
time.sleep(1)
with open(file_path, "a") as f:
f.write("bar\n")

_, response = app.test_client.get(
"/validate", headers={"If-Modified-Since": last_modified}
)

assert response.status == 200
assert response.body == b"foo\nbar\n"

Expand Down Expand Up @@ -921,3 +919,28 @@ def file_route(request: Request, filename: str):
)
assert response.status == 304
assert response.body == b""


@pytest.mark.parametrize(
"file_name", ["test.file", "decode me.txt", "python.png"]
)
def test_file_validating_304_response(
app: Sanic, file_name: str, static_file_directory: str
):
app.static("static", Path(static_file_directory) / file_name)

_, response = app.test_client.get("/static")
assert response.status == 200
assert response.body == get_file_content(static_file_directory, file_name)
last_modified = parsedate_to_datetime(response.headers["Last-Modified"])
last_modified += timedelta(seconds=1)
_, response = app.test_client.get(
"/static",
headers={
"if-modified-since": formatdate(
last_modified.timestamp(), usegmt=True
)
},
)
assert response.status == 304
assert response.body == b""

0 comments on commit 43ba381

Please sign in to comment.