Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Homogeneously schedule P2P's unpack tasks #8873

Merged
merged 4 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,14 +1425,8 @@ class TaskState:
#: be rejected.
run_id: int | None

#: Whether to consider this task rootish in the context of task queueing
#: True
#: Always consider this task rootish
#: False
#: Never consider this task rootish
#: None
#: Use a heuristic to determine whether this task should be considered rootish
_rootish: bool | None
#: Whether to allow queueing this task if it is rootish
_queueable: bool

#: Cached hash of :attr:`~TaskState.client_key`
_hash: int
Expand Down Expand Up @@ -1489,7 +1483,7 @@ def __init__(
self.metadata = None
self.annotations = None
self.erred_on = None
self._rootish = None
self._queueable = True
self.run_id = None
self.group = group
group.add(self)
Expand Down Expand Up @@ -2286,7 +2280,7 @@ def decide_worker_rootish_queuing_disabled(
"""
if self.validate:
# See root-ish-ness note below in `decide_worker_rootish_queuing_enabled`
assert math.isinf(self.WORKER_SATURATION)
assert math.isinf(self.WORKER_SATURATION) or not ts._queueable

pool = self.idle.values() if self.idle else self.running
if not pool:
Expand Down Expand Up @@ -2452,7 +2446,7 @@ def _transition_waiting_processing(self, key: Key, stimulus_id: str) -> RecsMsgs
# removed, there should only be one, which combines co-assignment and
# queuing. Eventually, special-casing root tasks might be removed entirely,
# with better heuristics.
if math.isinf(self.WORKER_SATURATION):
if math.isinf(self.WORKER_SATURATION) or not ts._queueable:
if not (ws := self.decide_worker_rootish_queuing_disabled(ts)):
return {ts.key: "no-worker"}, {}, {}
else:
Expand Down Expand Up @@ -3090,8 +3084,6 @@ def is_rootish(self, ts: TaskState) -> bool:
and have few or no dependencies. Tasks may also be explicitly marked as rootish
to override this heuristic.
"""
if ts._rootish is not None:
return ts._rootish
if ts.resource_restrictions or ts.worker_restrictions or ts.host_restrictions:
return False
tg = ts.group
Expand Down
2 changes: 1 addition & 1 deletion distributed/shuffle/_scheduler_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def _ensure_output_tasks_are_non_rootish(self, spec: ShuffleSpec) -> None:
"""
barrier = self.scheduler.tasks[barrier_key(spec.id)]
for dependent in barrier.dependents:
dependent._rootish = False
dependent._queueable = False

@log_errors()
def _set_restriction(self, ts: TaskState, worker: str) -> None:
Expand Down
36 changes: 36 additions & 0 deletions distributed/shuffle/tests/test_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
import math
import random
import warnings
from collections import defaultdict

import pytest

from distributed.diagnostics.plugin import SchedulerPlugin

np = pytest.importorskip("numpy")
da = pytest.importorskip("dask.array")

Expand Down Expand Up @@ -1488,3 +1491,36 @@ def test_calculate_prechunking_splitting(old, new, expected):
# _calculate_prechunking does not concatenate on object
actual = _calculate_prechunking(old, new, np.dtype(object), None)
assert actual == expected


@gen_cluster(client=True, nthreads=[("", 1)] * 4, config={"array.chunk-size": "1 B"})
async def test_homogeneously_schedule_unpack(c, s, *ws):
class SchedulingTrackerPlugin(SchedulerPlugin):
async def start(self, scheduler):
self.scheduler = scheduler
self.counts = defaultdict(int)
self.seen = set()

def transition(self, key, start, finish, *args, stimulus_id, **kwargs):
if key in self.seen:
return

if not isinstance(key, tuple) or not isinstance(key[0], str):
return

if not key[0].startswith("rechunk-p2p"):
return

if start != "waiting" or finish != "processing":
return

self.seen.add(key)
self.counts[self.scheduler.tasks[key].processing_on.address] += 1

await c.register_plugin(SchedulingTrackerPlugin(), name="tracker")
res = da.random.random((100, 100), chunks=(1, -1)).rechunk((-1, 1))
await c.compute(res)
counts = s.plugins["tracker"].counts
min_count = min(counts.values())
max_count = max(counts.values())
assert min_count >= max_count, counts
27 changes: 0 additions & 27 deletions distributed/shuffle/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2685,33 +2685,6 @@ async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
return await super().barrier(id, run_id, consistent)


@gen_cluster(client=True)
async def test_unpack_is_non_rootish(c, s, a, b):
with pytest.warns(UserWarning):
scheduler_plugin = BlockedBarrierShuffleSchedulerPlugin(s)
df = dask.datasets.timeseries(
start="2000-01-01",
end="2000-01-21",
dtypes={"x": float, "y": float},
freq="10 s",
)
df = df.shuffle("x")
result = c.compute(df)

await scheduler_plugin.in_barrier.wait()

unpack_tss = [ts for key, ts in s.tasks.items() if key_split(key) == UNPACK_PREFIX]
assert len(unpack_tss) == 20
assert not any(s.is_rootish(ts) for ts in unpack_tss)
del unpack_tss
scheduler_plugin.block_barrier.set()
result = await result

await assert_worker_cleanup(a)
await assert_worker_cleanup(b)
await assert_scheduler_cleanup(s)


class FlakyConnectionPool(ConnectionPool):
def __init__(self, *args, failing_connects=0, **kwargs):
self.attempts = 0
Expand Down
24 changes: 0 additions & 24 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,30 +284,6 @@ def random(**kwargs):
test_decide_worker_coschedule_order_neighbors_()


@gen_cluster(
client=True,
nthreads=[],
)
async def test_override_is_rootish(c, s):
x = c.submit(lambda x: x + 1, 1, key="x")
await async_poll_for(lambda: "x" in s.tasks, timeout=5)
ts_x = s.tasks["x"]
assert ts_x._rootish is None
assert s.is_rootish(ts_x)

ts_x._rootish = False
assert not s.is_rootish(ts_x)

y = c.submit(lambda y: y + 1, 1, key="y", workers=["not-existing"])
await async_poll_for(lambda: "y" in s.tasks, timeout=5)
ts_y = s.tasks["y"]
assert ts_y._rootish is None
assert not s.is_rootish(ts_y)

ts_y._rootish = True
assert s.is_rootish(ts_y)


@pytest.mark.skipif(
QUEUING_ON_BY_DEFAULT,
reason="Not relevant with queuing on; see https://github.com/dask/distributed/issues/7204",
Expand Down
Loading