Skip to content

Commit

Permalink
Merge pull request #3393 from bdarnell/typing
Browse files Browse the repository at this point in the history
Update mypy and various typing improvements
  • Loading branch information
bdarnell authored Jun 7, 2024
2 parents 354869d + 385af83 commit 74e7c98
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 18 deletions.
8 changes: 4 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ markupsafe==2.1.2
# via jinja2
mccabe==0.7.0
# via flake8
mypy==1.0.1
mypy==1.10.0
# via -r requirements.in
mypy-extensions==0.4.3
mypy-extensions==1.0.0
# via
# black
# mypy
Expand Down Expand Up @@ -111,9 +111,9 @@ sphinxcontrib-serializinghtml==1.1.5
# via sphinx
tox==4.6.0
# via -r requirements.in
types-pycurl==7.45.2.0
types-pycurl==7.45.3.20240421
# via -r requirements.in
typing-extensions==4.4.0
typing-extensions==4.12.1
# via mypy
urllib3==1.26.18
# via requests
Expand Down
5 changes: 4 additions & 1 deletion tornado/concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,10 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> Future:
_NO_RESULT = object()


def chain_future(a: "Future[_T]", b: "Future[_T]") -> None:
def chain_future(
a: Union["Future[_T]", "futures.Future[_T]"],
b: Union["Future[_T]", "futures.Future[_T]"],
) -> None:
"""Chain two futures together so that when one completes, so does the other.
The result (success or failure) of ``a`` will be copied to ``b``, unless
Expand Down
8 changes: 7 additions & 1 deletion tornado/httputil.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@
from asyncio import Future # noqa: F401
import unittest # noqa: F401

# This can be done unconditionally in the base class of HTTPHeaders
# after we drop support for Python 3.8.
StrMutableMapping = collections.abc.MutableMapping[str, str]
else:
StrMutableMapping = collections.abc.MutableMapping

# To be used with str.strip() and related methods.
HTTP_WHITESPACE = " \t"

Expand All @@ -76,7 +82,7 @@ def _normalize_header(name: str) -> str:
return "-".join([w.capitalize() for w in name.split("-")])


class HTTPHeaders(collections.abc.MutableMapping):
class HTTPHeaders(StrMutableMapping):
"""A dictionary that maintains ``Http-Header-Case`` for all keys.
Supports multiple values per key via a pair of new methods,
Expand Down
15 changes: 13 additions & 2 deletions tornado/platform/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
Union,
)

if typing.TYPE_CHECKING:
from typing_extensions import TypeVarTuple, Unpack


class _HasFileno(Protocol):
def fileno(self) -> int:
Expand All @@ -59,6 +62,8 @@ def fileno(self) -> int:

_T = TypeVar("_T")

if typing.TYPE_CHECKING:
_Ts = TypeVarTuple("_Ts")

# Collection of selector thread event loops to shut down on exit.
_selector_loops: Set["SelectorThread"] = set()
Expand Down Expand Up @@ -702,12 +707,18 @@ def close(self) -> None:
self._real_loop.close()

def add_reader(
self, fd: "_FileDescriptorLike", callback: Callable[..., None], *args: Any
self,
fd: "_FileDescriptorLike",
callback: Callable[..., None],
*args: "Unpack[_Ts]",
) -> None:
return self._selector.add_reader(fd, callback, *args)

def add_writer(
self, fd: "_FileDescriptorLike", callback: Callable[..., None], *args: Any
self,
fd: "_FileDescriptorLike",
callback: Callable[..., None],
*args: "Unpack[_Ts]",
) -> None:
return self._selector.add_writer(fd, callback, *args)

Expand Down
26 changes: 26 additions & 0 deletions tornado/test/concurrent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from tornado.concurrent import (
Future,
chain_future,
run_on_executor,
future_set_result_unless_cancelled,
)
Expand All @@ -47,6 +48,31 @@ def test_future_set_result_unless_cancelled(self):
self.assertEqual(fut.result(), 42)


class ChainFutureTest(AsyncTestCase):
@gen_test
async def test_asyncio_futures(self):
fut: Future[int] = Future()
fut2: Future[int] = Future()
chain_future(fut, fut2)
fut.set_result(42)
result = await fut2
self.assertEqual(result, 42)

@gen_test
async def test_concurrent_futures(self):
# A three-step chain: two concurrent futures (showing that both arguments to chain_future
# can be concurrent futures), and then one from a concurrent future to an asyncio future so
# we can use it in await.
fut: futures.Future[int] = futures.Future()
fut2: futures.Future[int] = futures.Future()
fut3: Future[int] = Future()
chain_future(fut, fut2)
chain_future(fut2, fut3)
fut.set_result(42)
result = await fut3
self.assertEqual(result, 42)


# The following series of classes demonstrate and test various styles
# of use, with and without generators and futures.

Expand Down
26 changes: 17 additions & 9 deletions tornado/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,15 @@ class RequestHandler(object):
"""

SUPPORTED_METHODS = ("GET", "HEAD", "POST", "DELETE", "PATCH", "PUT", "OPTIONS")
SUPPORTED_METHODS: Tuple[str, ...] = (
"GET",
"HEAD",
"POST",
"DELETE",
"PATCH",
"PUT",
"OPTIONS",
)

_template_loaders = {} # type: Dict[str, template.BaseLoader]
_template_loader_lock = threading.Lock()
Expand Down Expand Up @@ -1628,14 +1636,14 @@ def check_xsrf_cookie(self) -> None:
# information please see
# http://www.djangoproject.com/weblog/2011/feb/08/security/
# http://weblog.rubyonrails.org/2011/2/8/csrf-protection-bypass-in-ruby-on-rails
token = (
input_token = (
self.get_argument("_xsrf", None)
or self.request.headers.get("X-Xsrftoken")
or self.request.headers.get("X-Csrftoken")
)
if not token:
if not input_token:
raise HTTPError(403, "'_xsrf' argument missing from POST")
_, token, _ = self._decode_xsrf_token(token)
_, token, _ = self._decode_xsrf_token(input_token)
_, expected_token, _ = self._get_raw_xsrf_token()
if not token:
raise HTTPError(403, "'_xsrf' argument has invalid format")
Expand Down Expand Up @@ -1918,7 +1926,7 @@ def render(*args, **kwargs) -> str: # type: ignore
if name not in self._active_modules:
self._active_modules[name] = module(self)
rendered = self._active_modules[name].render(*args, **kwargs)
return rendered
return _unicode(rendered)

return render

Expand Down Expand Up @@ -3355,7 +3363,7 @@ def __init__(self, handler: RequestHandler) -> None:
def current_user(self) -> Any:
return self.handler.current_user

def render(self, *args: Any, **kwargs: Any) -> str:
def render(self, *args: Any, **kwargs: Any) -> Union[str, bytes]:
"""Override in subclasses to return this module's output."""
raise NotImplementedError()

Expand Down Expand Up @@ -3403,12 +3411,12 @@ def render_string(self, path: str, **kwargs: Any) -> bytes:


class _linkify(UIModule):
def render(self, text: str, **kwargs: Any) -> str: # type: ignore
def render(self, text: str, **kwargs: Any) -> str:
return escape.linkify(text, **kwargs)


class _xsrf_form_html(UIModule):
def render(self) -> str: # type: ignore
def render(self) -> str:
return self.handler.xsrf_form_html()


Expand All @@ -3434,7 +3442,7 @@ def __init__(self, handler: RequestHandler) -> None:
self._resource_list = [] # type: List[Dict[str, Any]]
self._resource_dict = {} # type: Dict[str, Dict[str, Any]]

def render(self, path: str, **kwargs: Any) -> bytes: # type: ignore
def render(self, path: str, **kwargs: Any) -> bytes:
def set_resources(**kwargs) -> str: # type: ignore
if path not in self._resource_dict:
self._resource_list.append(kwargs)
Expand Down
2 changes: 1 addition & 1 deletion tornado/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -1380,7 +1380,7 @@ def __init__(
{
"Upgrade": "websocket",
"Connection": "Upgrade",
"Sec-WebSocket-Key": self.key,
"Sec-WebSocket-Key": to_unicode(self.key),
"Sec-WebSocket-Version": "13",
}
)
Expand Down

0 comments on commit 74e7c98

Please sign in to comment.