Skip to content

Commit

Permalink
shield kill
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Sep 2, 2024
1 parent fe0a339 commit e5fb18d
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 11 deletions.
4 changes: 2 additions & 2 deletions distributed/nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ async def _():
await self.instantiate()

try:
await wait_for(_(), timeout)
await wait_for(asyncio.shield(_()), timeout)
except asyncio.TimeoutError:
logger.error(
f"Restart timed out after {timeout}s; returning before finished"
Expand Down Expand Up @@ -769,7 +769,6 @@ async def start(self) -> Status:
self.worker_dir = msg["dir"]
assert self.worker_address
self.status = Status.running
self.running.set()

return self.status

Expand Down Expand Up @@ -804,6 +803,7 @@ def mark_stopped(self):
msg = self._death_message(self.process.pid, r)
logger.info(msg)
self.status = Status.stopped
self.running.clear()
self.stopped.set()
# Release resources
self.process.close()
Expand Down
40 changes: 31 additions & 9 deletions distributed/tests/test_nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,25 +208,47 @@ async def test_scheduler_file():
s.stop()


@pytest.mark.xfail(
os.environ.get("MINDEPS") == "true",
reason="Timeout errors with mindeps environment",
)
@gen_cluster(client=True, Worker=Nanny, nthreads=[("127.0.0.1", 2)])
async def test_nanny_timeout(c, s, a):
@gen_cluster(client=True, Worker=Nanny, nthreads=[("", 1)])
async def test_nanny_restart(c, s, a):
x = await c.scatter(123)
assert await c.submit(lambda: 1) == 1

await a.restart()

while x.status != "cancelled":
await asyncio.sleep(0.1)

assert await c.submit(lambda: 1) == 1


@gen_cluster(client=True, Worker=Nanny, nthreads=[("", 1)])
async def test_nanny_restart_timeout(c, s, a):
x = await c.scatter(123)
with captured_logger(
logging.getLogger("distributed.nanny"), level=logging.ERROR
) as logger:
await a.restart(timeout=0.1)
await a.restart(timeout=0)

out = logger.getvalue()
assert "timed out" in out.lower()

start = time()
while x.status != "cancelled":
await asyncio.sleep(0.1)
assert time() < start + 7

assert await c.submit(lambda: 1) == 1


@gen_cluster(client=True, Worker=Nanny, nthreads=[("", 1)])
async def test_nanny_restart_timeout_stress(c, s, a):
x = await c.scatter(123)
restarts = [a.restart(timeout=random.random() * 0.1) for _ in range(100)]
await asyncio.gather(*restarts)

while x.status != "cancelled":
await asyncio.sleep(0.1)

assert await c.submit(lambda: 1) == 1
assert len(s.workers) == 1


@gen_cluster(
Expand Down

0 comments on commit e5fb18d

Please sign in to comment.