From 3186fb98cc6622f0784ecee088100925a5599e98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Wed, 22 Nov 2023 23:30:27 +0200 Subject: [PATCH] Implemented voluntary cancellation in worker threads (#629) --- docs/api.rst | 1 + docs/threads.rst | 20 ++++++++ docs/versionhistory.rst | 5 ++ src/anyio/_backends/_asyncio.py | 47 ++++++++++++++--- src/anyio/_backends/_trio.py | 13 +++-- src/anyio/_core/_fileio.py | 44 +++++++++------- src/anyio/_core/_sockets.py | 4 +- src/anyio/abc/_eventloop.py | 7 ++- src/anyio/from_thread.py | 31 +++++++++++- src/anyio/to_thread.py | 22 ++++++-- tests/test_from_thread.py | 90 ++++++++++++++++++++++++++++++++- tests/test_to_thread.py | 21 +++++--- 12 files changed, 260 insertions(+), 45 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 9af315ba..1bd57766 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -61,6 +61,7 @@ Running asynchronous code from other threads .. autofunction:: anyio.from_thread.run .. autofunction:: anyio.from_thread.run_sync +.. autofunction:: anyio.from_thread.check_cancelled .. autofunction:: anyio.from_thread.start_blocking_portal .. autoclass:: anyio.from_thread.BlockingPortal diff --git a/docs/threads.rst b/docs/threads.rst index afb3f183..73de1c06 100644 --- a/docs/threads.rst +++ b/docs/threads.rst @@ -205,3 +205,23 @@ maximum of 40 threads to be spawned. You can adjust this limit like this:: .. note:: AnyIO's default thread pool limiter does not affect the default thread pool executor on :mod:`asyncio`. + +Reacting to cancellation in worker threads +------------------------------------------ + +While there is no mechanism in Python to cancel code running in a thread, AnyIO provides a +mechanism that allows user code to voluntarily check if the host task's scope has been cancelled, +and if it has, raise a cancellation exception. This can be done by simply calling +:func:`from_thread.check_cancelled`:: + + from anyio import to_thread, from_thread + + def sync_function(): + while True: + from_thread.check_cancelled() + print("Not cancelled yet") + sleep(1) + + async def foo(): + with move_on_after(3): + await to_thread.run_sync(sync_function) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index b7432dff..f19c6d87 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -10,6 +10,11 @@ This library adheres to `Semantic Versioning 2.0 `_. - Call ``trio.to_thread.run_sync()`` using the ``abandon_on_cancel`` keyword argument instead of ``cancellable`` - Removed a checkpoint when exiting a task group + - Renamed the ``cancellable`` argument in ``anyio.to_thread.run_sync()`` to + ``abandon_on_cancel`` (and deprecated the old parameter name) + - Bumped minimum version of Trio to v0.23 +- Added support for voluntary thread cancellation via + ``anyio.from_thread.check_cancelled()`` - Bumped minimum version of trio to v0.23 - Exposed the ``ResourceGuard`` class in the public API - Fixed ``RuntimeError: Runner is closed`` when running higher-scoped async generator diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index a3eb6472..9827e550 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -59,7 +59,7 @@ import sniffio from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc -from .._core._eventloop import claim_worker_thread +from .._core._eventloop import claim_worker_thread, threadlocals from .._core._exceptions import ( BrokenResourceError, BusyResourceError, @@ -783,7 +783,7 @@ def __init__( self.idle_workers = idle_workers self.loop = root_task._loop self.queue: Queue[ - tuple[Context, Callable, tuple, asyncio.Future] | None + tuple[Context, Callable, tuple, asyncio.Future, CancelScope] | None ] = Queue(2) self.idle_since = AsyncIOBackend.current_time() self.stopping = False @@ -814,14 +814,17 @@ def run(self) -> None: # Shutdown command received return - context, func, args, future = item + context, func, args, future, cancel_scope = item if not future.cancelled(): result = None exception: BaseException | None = None + threadlocals.current_cancel_scope = cancel_scope try: result = context.run(func, *args) except BaseException as exc: exception = exc + finally: + del threadlocals.current_cancel_scope if not self.loop.is_closed(): self.loop.call_soon_threadsafe( @@ -2045,7 +2048,7 @@ async def run_sync_in_worker_thread( cls, func: Callable[..., T_Retval], args: tuple[Any, ...], - cancellable: bool = False, + abandon_on_cancel: bool = False, limiter: abc.CapacityLimiter | None = None, ) -> T_Retval: await cls.checkpoint() @@ -2062,7 +2065,7 @@ async def run_sync_in_worker_thread( _threadpool_workers.set(workers) async with limiter or cls.current_default_thread_limiter(): - with CancelScope(shield=not cancellable): + with CancelScope(shield=not abandon_on_cancel) as scope: future: asyncio.Future = asyncio.Future() root_task = find_root_task() if not idle_workers: @@ -2091,9 +2094,26 @@ async def run_sync_in_worker_thread( context = copy_context() context.run(sniffio.current_async_library_cvar.set, None) - worker.queue.put_nowait((context, func, args, future)) + if abandon_on_cancel or scope._parent_scope is None: + worker_scope = scope + else: + worker_scope = scope._parent_scope + + worker.queue.put_nowait((context, func, args, future, worker_scope)) return await future + @classmethod + def check_cancelled(cls) -> None: + scope: CancelScope | None = threadlocals.current_cancel_scope + while scope is not None: + if scope.cancel_called: + raise CancelledError(f"Cancelled by cancel scope {id(scope):x}") + + if scope.shield: + return + + scope = scope._parent_scope + @classmethod def run_async_from_thread( cls, @@ -2101,11 +2121,24 @@ def run_async_from_thread( args: tuple[Any, ...], token: object, ) -> T_Retval: + async def task_wrapper(scope: CancelScope) -> T_Retval: + __tracebackhide__ = True + task = cast(asyncio.Task, current_task()) + _task_states[task] = TaskState(None, scope) + scope._tasks.add(task) + try: + return await func(*args) + except CancelledError as exc: + raise concurrent.futures.CancelledError(str(exc)) from None + finally: + scope._tasks.discard(task) + loop = cast(AbstractEventLoop, token) context = copy_context() context.run(sniffio.current_async_library_cvar.set, "asyncio") + wrapper = task_wrapper(threadlocals.current_cancel_scope) f: concurrent.futures.Future[T_Retval] = context.run( - asyncio.run_coroutine_threadsafe, func(*args), loop + asyncio.run_coroutine_threadsafe, wrapper, loop ) return f.result() diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index 92eb75cc..3127140c 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -36,7 +36,6 @@ import trio.lowlevel from outcome import Error, Outcome, Value from trio.lowlevel import ( - TrioToken, current_root_task, current_task, wait_readable, @@ -869,7 +868,7 @@ async def run_sync_in_worker_thread( cls, func: Callable[..., T_Retval], args: tuple[Any, ...], - cancellable: bool = False, + abandon_on_cancel: bool = False, limiter: abc.CapacityLimiter | None = None, ) -> T_Retval: def wrapper() -> T_Retval: @@ -879,10 +878,14 @@ def wrapper() -> T_Retval: token = TrioBackend.current_token() return await run_sync( wrapper, - abandon_on_cancel=cancellable, + abandon_on_cancel=abandon_on_cancel, limiter=cast(trio.CapacityLimiter, limiter), ) + @classmethod + def check_cancelled(cls) -> None: + trio.from_thread.check_cancelled() + @classmethod def run_async_from_thread( cls, @@ -890,13 +893,13 @@ def run_async_from_thread( args: tuple[Any, ...], token: object, ) -> T_Retval: - return trio.from_thread.run(func, *args, trio_token=cast(TrioToken, token)) + return trio.from_thread.run(func, *args) @classmethod def run_sync_from_thread( cls, func: Callable[..., T_Retval], args: tuple[Any, ...], token: object ) -> T_Retval: - return trio.from_thread.run_sync(func, *args, trio_token=cast(TrioToken, token)) + return trio.from_thread.run_sync(func, *args) @classmethod def create_blocking_portal(cls) -> abc.BlockingPortal: diff --git a/src/anyio/_core/_fileio.py b/src/anyio/_core/_fileio.py index 40eabe67..f51bf450 100644 --- a/src/anyio/_core/_fileio.py +++ b/src/anyio/_core/_fileio.py @@ -205,7 +205,9 @@ class _PathIterator(AsyncIterator["Path"]): iterator: Iterator[PathLike[str]] async def __anext__(self) -> Path: - nextval = await to_thread.run_sync(next, self.iterator, None, cancellable=True) + nextval = await to_thread.run_sync( + next, self.iterator, None, abandon_on_cancel=True + ) if nextval is None: raise StopAsyncIteration from None @@ -386,17 +388,19 @@ async def cwd(cls) -> Path: return cls(path) async def exists(self) -> bool: - return await to_thread.run_sync(self._path.exists, cancellable=True) + return await to_thread.run_sync(self._path.exists, abandon_on_cancel=True) async def expanduser(self) -> Path: - return Path(await to_thread.run_sync(self._path.expanduser, cancellable=True)) + return Path( + await to_thread.run_sync(self._path.expanduser, abandon_on_cancel=True) + ) def glob(self, pattern: str) -> AsyncIterator[Path]: gen = self._path.glob(pattern) return _PathIterator(gen) async def group(self) -> str: - return await to_thread.run_sync(self._path.group, cancellable=True) + return await to_thread.run_sync(self._path.group, abandon_on_cancel=True) async def hardlink_to(self, target: str | pathlib.Path | Path) -> None: if isinstance(target, Path): @@ -413,31 +417,37 @@ def is_absolute(self) -> bool: return self._path.is_absolute() async def is_block_device(self) -> bool: - return await to_thread.run_sync(self._path.is_block_device, cancellable=True) + return await to_thread.run_sync( + self._path.is_block_device, abandon_on_cancel=True + ) async def is_char_device(self) -> bool: - return await to_thread.run_sync(self._path.is_char_device, cancellable=True) + return await to_thread.run_sync( + self._path.is_char_device, abandon_on_cancel=True + ) async def is_dir(self) -> bool: - return await to_thread.run_sync(self._path.is_dir, cancellable=True) + return await to_thread.run_sync(self._path.is_dir, abandon_on_cancel=True) async def is_fifo(self) -> bool: - return await to_thread.run_sync(self._path.is_fifo, cancellable=True) + return await to_thread.run_sync(self._path.is_fifo, abandon_on_cancel=True) async def is_file(self) -> bool: - return await to_thread.run_sync(self._path.is_file, cancellable=True) + return await to_thread.run_sync(self._path.is_file, abandon_on_cancel=True) async def is_mount(self) -> bool: - return await to_thread.run_sync(os.path.ismount, self._path, cancellable=True) + return await to_thread.run_sync( + os.path.ismount, self._path, abandon_on_cancel=True + ) def is_reserved(self) -> bool: return self._path.is_reserved() async def is_socket(self) -> bool: - return await to_thread.run_sync(self._path.is_socket, cancellable=True) + return await to_thread.run_sync(self._path.is_socket, abandon_on_cancel=True) async def is_symlink(self) -> bool: - return await to_thread.run_sync(self._path.is_symlink, cancellable=True) + return await to_thread.run_sync(self._path.is_symlink, abandon_on_cancel=True) def iterdir(self) -> AsyncIterator[Path]: gen = self._path.iterdir() @@ -450,7 +460,7 @@ async def lchmod(self, mode: int) -> None: await to_thread.run_sync(self._path.lchmod, mode) async def lstat(self) -> os.stat_result: - return await to_thread.run_sync(self._path.lstat, cancellable=True) + return await to_thread.run_sync(self._path.lstat, abandon_on_cancel=True) async def mkdir( self, mode: int = 0o777, parents: bool = False, exist_ok: bool = False @@ -493,7 +503,7 @@ async def open( return AsyncFile(fp) async def owner(self) -> str: - return await to_thread.run_sync(self._path.owner, cancellable=True) + return await to_thread.run_sync(self._path.owner, abandon_on_cancel=True) async def read_bytes(self) -> bytes: return await to_thread.run_sync(self._path.read_bytes) @@ -526,7 +536,7 @@ async def replace(self, target: str | pathlib.PurePath | Path) -> Path: async def resolve(self, strict: bool = False) -> Path: func = partial(self._path.resolve, strict=strict) - return Path(await to_thread.run_sync(func, cancellable=True)) + return Path(await to_thread.run_sync(func, abandon_on_cancel=True)) def rglob(self, pattern: str) -> AsyncIterator[Path]: gen = self._path.rglob(pattern) @@ -542,12 +552,12 @@ async def samefile( other_path = other_path._path return await to_thread.run_sync( - self._path.samefile, other_path, cancellable=True + self._path.samefile, other_path, abandon_on_cancel=True ) async def stat(self, *, follow_symlinks: bool = True) -> os.stat_result: func = partial(os.stat, follow_symlinks=follow_symlinks) - return await to_thread.run_sync(func, self._path, cancellable=True) + return await to_thread.run_sync(func, self._path, abandon_on_cancel=True) async def symlink_to( self, diff --git a/src/anyio/_core/_sockets.py b/src/anyio/_core/_sockets.py index 75e79618..08667ce9 100644 --- a/src/anyio/_core/_sockets.py +++ b/src/anyio/_core/_sockets.py @@ -693,9 +693,9 @@ async def setup_unix_local_socket( if path_str is not None: try: - await to_thread.run_sync(raw_socket.bind, path_str, cancellable=True) + await to_thread.run_sync(raw_socket.bind, path_str, abandon_on_cancel=True) if mode is not None: - await to_thread.run_sync(chmod, path_str, mode, cancellable=True) + await to_thread.run_sync(chmod, path_str, mode, abandon_on_cancel=True) except BaseException: raw_socket.close() raise diff --git a/src/anyio/abc/_eventloop.py b/src/anyio/abc/_eventloop.py index f9aae376..7ff05f7f 100644 --- a/src/anyio/abc/_eventloop.py +++ b/src/anyio/abc/_eventloop.py @@ -171,11 +171,16 @@ async def run_sync_in_worker_thread( cls, func: Callable[..., T_Retval], args: tuple[Any, ...], - cancellable: bool = False, + abandon_on_cancel: bool = False, limiter: CapacityLimiter | None = None, ) -> T_Retval: pass + @classmethod + @abstractmethod + def check_cancelled(cls) -> None: + pass + @classmethod @abstractmethod def run_async_from_thread( diff --git a/src/anyio/from_thread.py b/src/anyio/from_thread.py index cb6e695e..d8aa2f13 100644 --- a/src/anyio/from_thread.py +++ b/src/anyio/from_thread.py @@ -21,6 +21,7 @@ from ._core._eventloop import get_async_backend, get_cancelled_exc_class, threadlocals from ._core._synchronization import Event from ._core._tasks import CancelScope, create_task_group +from .abc import AsyncBackend from .abc._tasks import TaskStatus T_Retval = TypeVar("T_Retval") @@ -40,7 +41,9 @@ def run(func: Callable[..., Awaitable[T_Retval]], *args: object) -> T_Retval: async_backend = threadlocals.current_async_backend token = threadlocals.current_token except AttributeError: - raise RuntimeError("This function can only be run from an AnyIO worker thread") + raise RuntimeError( + "This function can only be run from an AnyIO worker thread" + ) from None return async_backend.run_async_from_thread(func, args, token=token) @@ -58,7 +61,9 @@ def run_sync(func: Callable[..., T_Retval], *args: object) -> T_Retval: async_backend = threadlocals.current_async_backend token = threadlocals.current_token except AttributeError: - raise RuntimeError("This function can only be run from an AnyIO worker thread") + raise RuntimeError( + "This function can only be run from an AnyIO worker thread" + ) from None return async_backend.run_sync_from_thread(func, args, token=token) @@ -422,3 +427,25 @@ async def run_portal() -> None: pass run_future.result() + + +def check_cancelled() -> None: + """ + Check if the cancel scope of the host task's running the current worker thread has + been cancelled. + + If the host task's current cancel scope has indeed been cancelled, the + backend-specific cancellation exception will be raised. + + :raises RuntimeError: if the current thread was not spawned by + :func:`.to_thread.run_sync` + + """ + try: + async_backend: AsyncBackend = threadlocals.current_async_backend + except AttributeError: + raise RuntimeError( + "This function can only be run from an AnyIO worker thread" + ) from None + + async_backend.check_cancelled() diff --git a/src/anyio/to_thread.py b/src/anyio/to_thread.py index a7fafcda..d9a632e8 100644 --- a/src/anyio/to_thread.py +++ b/src/anyio/to_thread.py @@ -2,6 +2,7 @@ from collections.abc import Callable from typing import TypeVar +from warnings import warn from ._core._eventloop import get_async_backend from .abc import CapacityLimiter @@ -12,7 +13,8 @@ async def run_sync( func: Callable[..., T_Retval], *args: object, - cancellable: bool = False, + abandon_on_cancel: bool = False, + cancellable: bool | None = None, limiter: CapacityLimiter | None = None, ) -> T_Retval: """ @@ -24,14 +26,28 @@ async def run_sync( :param func: a callable :param args: positional arguments for the callable - :param cancellable: ``True`` to allow cancellation of the operation + :param abandon_on_cancel: ``True`` to abandon the thread (leaving it to run + unchecked on own) if the host task is cancelled, ``False`` to ignore + cancellations in the host task until the operation has completed in the worker + thread + :param cancellable: deprecated alias of ``abandon_on_cancel``; will override + ``abandon_on_cancel`` if both parameters are passed :param limiter: capacity limiter to use to limit the total amount of threads running (if omitted, the default limiter is used) :return: an awaitable that yields the return value of the function. """ + if cancellable is not None: + abandon_on_cancel = cancellable + warn( + "The `cancellable=` keyword argument to `anyio.to_thread.run_sync` is " + "deprecated since AnyIO 4.1.0; use `abandon_on_cancel=` instead", + DeprecationWarning, + stacklevel=2, + ) + return await get_async_backend().run_sync_in_worker_thread( - func, args, cancellable=cancellable, limiter=limiter + func, args, abandon_on_cancel=abandon_on_cancel, limiter=limiter ) diff --git a/tests/test_from_thread.py b/tests/test_from_thread.py index ab577f7e..0e580462 100644 --- a/tests/test_from_thread.py +++ b/tests/test_from_thread.py @@ -6,7 +6,7 @@ import time from collections.abc import Awaitable, Callable from concurrent import futures -from concurrent.futures import CancelledError +from concurrent.futures import CancelledError, Future from contextlib import asynccontextmanager, suppress from contextvars import ContextVar from typing import Any, AsyncGenerator, Literal, NoReturn, TypeVar @@ -16,8 +16,10 @@ from _pytest.logging import LogCaptureFixture from anyio import ( + CancelScope, Event, create_task_group, + fail_after, from_thread, get_all_backends, get_cancelled_exc_class, @@ -65,6 +67,92 @@ def thread_worker_sync(func: Callable[..., T_Retval], *args: Any) -> T_Retval: return from_thread.run_sync(func, *args) +@pytest.mark.parametrize("cancel", [True, False]) +async def test_thread_cancelled(cancel: bool) -> None: + event = threading.Event() + thread_finished_future: Future[None] = Future() + + def sync_function() -> None: + event.wait(3) + try: + from_thread.check_cancelled() + except BaseException as exc: + thread_finished_future.set_exception(exc) + else: + thread_finished_future.set_result(None) + + async with create_task_group() as tg: + tg.start_soon(to_thread.run_sync, sync_function) + await wait_all_tasks_blocked() + if cancel: + tg.cancel_scope.cancel() + + event.set() + + if cancel: + with pytest.raises(get_cancelled_exc_class()): + thread_finished_future.result(3) + else: + thread_finished_future.result(3) + + +async def test_thread_cancelled_and_abandoned() -> None: + event = threading.Event() + thread_finished_future: Future[None] = Future() + + def sync_function() -> None: + event.wait(3) + try: + from_thread.check_cancelled() + except BaseException as exc: + thread_finished_future.set_exception(exc) + else: + thread_finished_future.set_result(None) + + async with create_task_group() as tg: + tg.start_soon(lambda: to_thread.run_sync(sync_function, abandon_on_cancel=True)) + await wait_all_tasks_blocked() + tg.cancel_scope.cancel() + + event.set() + with pytest.raises(get_cancelled_exc_class()): + thread_finished_future.result(3) + + +async def test_cancelscope_propagation() -> None: + async def async_time_bomb() -> None: + cancel_scope.cancel() + with fail_after(1): + await sleep(3) + + with CancelScope() as cancel_scope: + await to_thread.run_sync(from_thread.run, async_time_bomb) + + assert cancel_scope.cancelled_caught + + +async def test_cancelscope_propagation_when_abandoned() -> None: + host_cancelled_event = Event() + completed_event = Event() + + async def async_time_bomb() -> None: + cancel_scope.cancel() + with fail_after(3): + await host_cancelled_event.wait() + + completed_event.set() + + with CancelScope() as cancel_scope: + await to_thread.run_sync( + from_thread.run, async_time_bomb, abandon_on_cancel=True + ) + + assert cancel_scope.cancelled_caught + host_cancelled_event.set() + with fail_after(3): + await completed_event.wait() + + class TestRunAsyncFromThread: async def test_run_corofunc_from_thread(self) -> None: result = await to_thread.run_sync(thread_worker_async, async_add, 1, 2) diff --git a/tests/test_to_thread.py b/tests/test_to_thread.py index 223f3ecc..6dc46ba7 100644 --- a/tests/test_to_thread.py +++ b/tests/test_to_thread.py @@ -85,12 +85,14 @@ async def task_worker() -> None: @pytest.mark.parametrize( - "cancellable, expected_last_active", - [(False, "task"), (True, "thread")], - ids=["uncancellable", "cancellable"], + "abandon_on_cancel, expected_last_active", + [ + pytest.param(False, "task", id="noabandon"), + pytest.param(True, "thread", id="abandon"), + ], ) async def test_cancel_worker_thread( - cancellable: bool, expected_last_active: str + abandon_on_cancel: bool, expected_last_active: str ) -> None: """ Test that when a task running a worker thread is cancelled, the cancellation is not @@ -109,7 +111,7 @@ def thread_worker() -> None: async def task_worker() -> None: nonlocal last_active try: - await to_thread.run_sync(thread_worker, cancellable=cancellable) + await to_thread.run_sync(thread_worker, abandon_on_cancel=abandon_on_cancel) finally: last_active = "task" @@ -132,7 +134,7 @@ def wait_event() -> None: future.set_result(event.wait(1)) async with create_task_group() as tg: - tg.start_soon(partial(to_thread.run_sync, cancellable=True), wait_event) + tg.start_soon(partial(to_thread.run_sync, abandon_on_cancel=True), wait_event) await wait_all_tasks_blocked() tg.cancel_scope.cancel() @@ -140,6 +142,11 @@ def wait_event() -> None: assert future.result(1) +async def test_deprecated_cancellable_param() -> None: + with pytest.warns(DeprecationWarning, match="The `cancellable=`"): + await to_thread.run_sync(bool, cancellable=True) + + async def test_contextvar_propagation() -> None: var = ContextVar("var", default=1) var.set(6) @@ -158,7 +165,7 @@ async def test_asyncio_cancel_native_task() -> None: async def run_in_thread() -> None: nonlocal task task = asyncio.current_task() - await to_thread.run_sync(time.sleep, 0.2, cancellable=True) + await to_thread.run_sync(time.sleep, 0.2, abandon_on_cancel=True) async with create_task_group() as tg: tg.start_soon(run_in_thread)