diff --git a/distributed/nanny.py b/distributed/nanny.py
index 7a14ee6576..859b9f22dc 100644
--- a/distributed/nanny.py
+++ b/distributed/nanny.py
@@ -516,7 +516,7 @@ async def _():
             await self.instantiate()
 
         try:
-            await wait_for(_(), timeout)
+            await wait_for(asyncio.shield(_()), timeout)
         except asyncio.TimeoutError:
             logger.error(
                 f"Restart timed out after {timeout}s; returning before finished"
@@ -745,26 +745,30 @@ async def start(self) -> Status:
         os.environ.update(self.pre_spawn_env)
 
         try:
-            await self.process.start()
-        except OSError:
-            logger.exception("Nanny failed to start process", exc_info=True)
-            # NOTE: doesn't wait for process to terminate, just for terminate signal to be sent
-            await self.process.terminate()
-            self.status = Status.failed
-        try:
-            msg = await self._wait_until_connected(uid)
-        except Exception:
-            # NOTE: doesn't wait for process to terminate, just for terminate signal to be sent
-            await self.process.terminate()
-            self.status = Status.failed
-            raise
+            try:
+                await self.process.start()
+            except OSError:
+                # This can only happen if the actual process creation failed, e.g.
+                # multiprocessing.Process.start failed. This is not tested!
+                logger.exception("Nanny failed to start process", exc_info=True)
+                # NOTE: doesn't wait for process to terminate, just for terminate signal to be sent
+                await self.process.terminate()
+                self.status = Status.failed
+            try:
+                msg = await self._wait_until_connected(uid)
+            except Exception:
+                # NOTE: doesn't wait for process to terminate, just for terminate signal to be sent
+                await self.process.terminate()
+                self.status = Status.failed
+                raise
+        finally:
+            self.running.set()
         if not msg:
             return self.status
         self.worker_address = msg["address"]
         self.worker_dir = msg["dir"]
         assert self.worker_address
         self.status = Status.running
-        self.running.set()
 
         return self.status
 
@@ -799,6 +803,7 @@ def mark_stopped(self):
                 msg = self._death_message(self.process.pid, r)
                 logger.info(msg)
             self.status = Status.stopped
+            self.running.clear()
             self.stopped.set()
             # Release resources
             self.process.close()
@@ -830,11 +835,6 @@ async def kill(
         """
         deadline = time() + timeout
 
-        if self.status == Status.stopped:
-            return
-        if self.status == Status.stopping:
-            await self.stopped.wait()
-            return
         # If the process is not properly up it will not watch the closing queue
         # and we may end up leaking this process
         # Therefore wait for it to be properly started before killing it
@@ -842,10 +842,17 @@ async def kill(
             await self.running.wait()
 
         assert self.status in (
+            Status.stopping,
+            Status.stopped,
             Status.running,
             Status.failed,  # process failed to start, but hasn't been joined yet
             Status.closing_gracefully,
         ), self.status
+        if self.status == Status.stopped:
+            return
+        if self.status == Status.stopping:
+            await self.stopped.wait()
+            return
         self.status = Status.stopping
         logger.info("Nanny asking worker to close. Reason: %s", reason)
 
diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py
index 52073c13d2..b05b7dc90c 100644
--- a/distributed/tests/test_nanny.py
+++ b/distributed/tests/test_nanny.py
@@ -208,25 +208,47 @@ async def test_scheduler_file():
         s.stop()
 
 
-@pytest.mark.xfail(
-    os.environ.get("MINDEPS") == "true",
-    reason="Timeout errors with mindeps environment",
-)
-@gen_cluster(client=True, Worker=Nanny, nthreads=[("127.0.0.1", 2)])
-async def test_nanny_timeout(c, s, a):
+@gen_cluster(client=True, Worker=Nanny, nthreads=[("", 1)])
+async def test_nanny_restart(c, s, a):
+    x = await c.scatter(123)
+    assert await c.submit(lambda: 1) == 1
+
+    await a.restart()
+
+    while x.status != "cancelled":
+        await asyncio.sleep(0.1)
+
+    assert await c.submit(lambda: 1) == 1
+
+
+@gen_cluster(client=True, Worker=Nanny, nthreads=[("", 1)])
+async def test_nanny_restart_timeout(c, s, a):
     x = await c.scatter(123)
     with captured_logger(
         logging.getLogger("distributed.nanny"), level=logging.ERROR
     ) as logger:
-        await a.restart(timeout=0.1)
+        await a.restart(timeout=0)
 
     out = logger.getvalue()
     assert "timed out" in out.lower()
 
-    start = time()
     while x.status != "cancelled":
         await asyncio.sleep(0.1)
-    assert time() < start + 7
+
+    assert await c.submit(lambda: 1) == 1
+
+
+@gen_cluster(client=True, Worker=Nanny, nthreads=[("", 1)])
+async def test_nanny_restart_timeout_stress(c, s, a):
+    x = await c.scatter(123)
+    restarts = [a.restart(timeout=random.random()) for _ in range(100)]
+    await asyncio.gather(*restarts)
+
+    while x.status != "cancelled":
+        await asyncio.sleep(0.1)
+
+    assert await c.submit(lambda: 1) == 1
+    assert len(s.workers) == 1
 
 
 @gen_cluster(
@@ -582,6 +604,34 @@ async def test_worker_start_exception(s):
     assert logs.getvalue().count("ValueError: broken") == 1, logs.getvalue()
 
 
+@gen_cluster(nthreads=[])
+async def test_worker_start_exception_while_killing(s):
+    nanny = Nanny(s.address, worker_class=BrokenWorker)
+
+    async def try_to_kill_nanny():
+        while not nanny.process or nanny.process.status != Status.starting:
+            await asyncio.sleep(0)
+        await nanny.kill()
+
+    kill_task = asyncio.create_task(try_to_kill_nanny())
+    with captured_logger(logger="distributed.nanny", level=logging.WARNING) as logs:
+        with raises_with_cause(
+            RuntimeError,
+            "Nanny failed to start",
+            RuntimeError,
+            "BrokenWorker failed to start",
+        ):
+            async with nanny:
+                pass
+    await kill_task
+    assert nanny.status == Status.failed
+    # ^ NOTE: `Nanny.close` sets it to `closed`, then `Server.start._close_on_failure` sets it to `failed`
+    assert nanny.process is None
+    assert "Restarting worker" not in logs.getvalue()
+    # Avoid excessive spewing. (It's also printed once extra within the subprocess, which is okay.)
+    assert logs.getvalue().count("ValueError: broken") == 1, logs.getvalue()
+
+
 @gen_cluster(nthreads=[])
 async def test_failure_during_worker_initialization(s):
     with captured_logger(logger="distributed.nanny", level=logging.WARNING) as logs: