Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure client_desires_keys does not corrupt Scheduler state #8827

Merged
merged 10 commits into from
Aug 20, 2024
2 changes: 1 addition & 1 deletion distributed/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _try_bind_worker_client(self):
if not self._client:
try:
self._client = get_client()
self._future = Future(self._key, inform=False)
self._future = Future(self._key, self._client)
# ^ When running on a worker, only hold a weak reference to the key, otherwise the key could become unreleasable.
except ValueError:
self._client = None
Expand Down
51 changes: 22 additions & 29 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,13 +297,12 @@ class Future(WrappedKey):
# Make sure this stays unique even across multiple processes or hosts
_uid = uuid.uuid4().hex

def __init__(self, key, client=None, inform=True, state=None, _id=None):
def __init__(self, key, client=None, state=None, _id=None):
self.key = key
self._cleared = False
self._client = client
self._id = _id or (Future._uid, next(Future._counter))
self._input_state = state
self._inform = inform
self._state = None
self._bind_late()

Expand All @@ -312,13 +311,11 @@ def client(self):
self._bind_late()
return self._client

def bind_client(self, client):
self._client = client
self._bind_late()

def _bind_late(self):
if not self._client:
try:
client = get_client()
except ValueError:
client = None
self._client = client
if self._client and not self._state:
self._client._inc_ref(self.key)
self._generation = self._client.generation
Expand All @@ -328,15 +325,6 @@ def _bind_late(self):
else:
self._state = self._client.futures[self.key] = FutureState(self.key)

if self._inform:
self._client._send_to_scheduler(
{
"op": "client-desires-keys",
"keys": [self.key],
"client": self._client.id,
}
)

if self._input_state is not None:
try:
handler = self._client._state_handlers[self._input_state]
Expand Down Expand Up @@ -588,13 +576,8 @@ def release(self):
except TypeError: # pragma: no cover
pass # Shutting down, add_callback may be None

@staticmethod
def make_future(key, id):
# Can't use kwargs in pickle __reduce__ methods
return Future(key=key, _id=id)

def __reduce__(self) -> str | tuple[Any, ...]:
return Future.make_future, (self.key, self._id)
return Future, (self.key,)

def __dask_tokenize__(self):
return (type(self).__name__, self.key, self._id)
Expand Down Expand Up @@ -2161,7 +2144,7 @@ def submit(

with self._refcount_lock:
if key in self.futures:
return Future(key, self, inform=False)
return Future(key, self)

if allow_other_workers and workers is None:
raise ValueError("Only use allow_other_workers= if using workers=")
Expand Down Expand Up @@ -2661,7 +2644,7 @@ async def _scatter(
timeout=timeout,
)

out = {k: Future(k, self, inform=False) for k in data}
out = {k: Future(k, self) for k in data}
for key, typ in types.items():
self.futures[key].finish(type=typ)

Expand Down Expand Up @@ -2969,12 +2952,14 @@ def list_datasets(self, **kwargs):
async def _get_dataset(self, name, default=no_default):
with self.as_current():
out = await self.scheduler.publish_get(name=name, client=self.id)

if out is None:
if default is no_default:
raise KeyError(f"Dataset '{name}' not found")
else:
return default
for fut in futures_of(out["data"]):
fut.bind_client(self)
self._inform_scheduler_of_futures()
return out["data"]

def get_dataset(self, name, default=no_default, **kwargs):
Expand Down Expand Up @@ -3300,6 +3285,14 @@ def _get_computation_code(

return tuple(reversed(code))

def _inform_scheduler_of_futures(self):
self._send_to_scheduler(
{
"op": "client-desires-keys",
"keys": list(self.refcount),
}
)

def _graph_to_futures(
self,
dsk,
Expand Down Expand Up @@ -3348,7 +3341,7 @@ def _graph_to_futures(
validate_key(key)

# Create futures before sending graph (helps avoid contention)
futures = {key: Future(key, self, inform=False) for key in keyset}
futures = {key: Future(key, self) for key in keyset}
# Circular import
from distributed.protocol import serialize
from distributed.protocol.serialize import ToPickle
Expand Down Expand Up @@ -3507,7 +3500,7 @@ def _optimize_insert_futures(self, dsk, keys):
if not changed:
changed = True
dsk = ensure_dict(dsk)
dsk[key] = Future(key, self, inform=False)
dsk[key] = Future(key, self)

if changed:
dsk, _ = dask.optimization.cull(dsk, keys)
Expand Down Expand Up @@ -6092,7 +6085,7 @@ def futures_of(o, client=None):
stack.extend(x.values())
elif type(x) is SubgraphCallable:
stack.extend(x.dsk.values())
elif isinstance(x, Future):
elif isinstance(x, WrappedKey):
if x not in seen:
seen.add(x)
futures.append(x)
Expand Down
13 changes: 10 additions & 3 deletions distributed/queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dask.utils import parse_timedelta

from distributed.client import Future
from distributed.utils import wait_for
from distributed.utils import Deadline, wait_for
from distributed.worker import get_client

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -67,15 +67,22 @@ def release(self, name=None, client=None):
self.scheduler.client_releases_keys(keys=keys, client="queue-%s" % name)

async def put(self, name=None, key=None, data=None, client=None, timeout=None):
deadline = Deadline.after(timeout)
if key is not None:
while key not in self.scheduler.tasks:
await asyncio.sleep(0.01)
if deadline.expired:
raise TimeoutError(f"Task {key} unknown to scheudler.")
fjetter marked this conversation as resolved.
Show resolved Hide resolved

record = {"type": "Future", "value": key}
self.future_refcount[name, key] += 1
self.scheduler.client_desires_keys(keys=[key], client="queue-%s" % name)
else:
record = {"type": "msgpack", "value": data}
await wait_for(self.queues[name].put(record), timeout=timeout)
await wait_for(self.queues[name].put(record), timeout=deadline.remaining)

def future_release(self, name=None, key=None, client=None):
self.scheduler.client_desires_keys(keys=[key], client=client)
self.future_refcount[name, key] -= 1
if self.future_refcount[name, key] == 0:
self.scheduler.client_releases_keys(keys=[key], client="queue-%s" % name)
Expand Down Expand Up @@ -265,7 +272,7 @@ async def _get(self, timeout=None, batch=False):

def process(d):
if d["type"] == "Future":
value = Future(d["value"], self.client, inform=True, state=d["state"])
value = Future(d["value"], self.client, state=d["state"])
if d["state"] == "erred":
value._state.set_error(d["exception"], d["traceback"])
self.client._send_to_scheduler(
Expand Down
10 changes: 4 additions & 6 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,9 +670,7 @@ def clean(self) -> WorkerState:
)
ws._occupancy_cache = self.occupancy

ws.executing = {
ts.key: duration for ts, duration in self.executing.items() # type: ignore
}
ws.executing = {ts.key: duration for ts, duration in self.executing.items()} # type: ignore
return ws

def __repr__(self) -> str:
Expand Down Expand Up @@ -5595,8 +5593,8 @@ def client_desires_keys(self, keys: Collection[Key], client: str) -> None:
for k in keys:
ts = self.tasks.get(k)
if ts is None:
# For publish, queues etc.
ts = self.new_task(k, None, "released")
warnings.warn(f"Client desires key {k!r} but key is unknown.")
continue
if ts.who_wants is None:
ts.who_wants = set()
ts.who_wants.add(cs)
Expand Down Expand Up @@ -9345,7 +9343,7 @@ def transition(
def _materialize_graph(
graph: HighLevelGraph, global_annotations: dict[str, Any], validate: bool
) -> tuple[dict[Key, T_runspec], dict[Key, set[Key]], dict[str, dict[Key, Any]]]:
dsk = ensure_dict(graph)
dsk: dict = ensure_dict(graph)
if validate:
for k in dsk:
validate_key(k)
Expand Down
97 changes: 15 additions & 82 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@
dec,
div,
double,
ensure_no_new_clients,
gen_cluster,
gen_test,
get_cert,
Expand Down Expand Up @@ -2639,24 +2638,32 @@ def test_futures_of_class():
@gen_cluster(client=True)
async def test_futures_of_cancelled_raises(c, s, a, b):
x = c.submit(inc, 1)
await c.cancel([x])

with pytest.raises(CancelledError):
while x.key not in s.tasks:
await asyncio.sleep(0.01)
await c.cancel([x], reason="testreason")

# Note: The scheduler currently doesn't remember the reason but rather
# forgets the task immediately. The reason is currently. only raised if the
# client checks on it. Therefore, we expect an unknown reason and definitely
# not a scheduler disconnected which would otherwise indicate a bug, e.g. an
# AssertionError during transitioning.
with pytest.raises(CancelledError, match="(reason: unknown|testreason)"):
await x
while x.key in s.tasks:
await asyncio.sleep(0.01)
with pytest.raises(CancelledError):

with pytest.raises(CancelledError, match="(reason: unknown|testreason)"):
get_obj = c.get({"x": (inc, x), "y": (inc, 2)}, ["x", "y"], sync=False)
gather_obj = c.gather(get_obj)
await gather_obj

with pytest.raises(CancelledError):
with pytest.raises(CancelledError, match="(reason: unknown|testreason)"):
await c.submit(inc, x)

with pytest.raises(CancelledError):
with pytest.raises(CancelledError, match="(reason: unknown|testreason)"):
await c.submit(add, 1, y=x)

with pytest.raises(CancelledError):
with pytest.raises(CancelledError, match="(reason: unknown|testreason)"):
await c.gather(c.map(add, [1], y=x))


Expand Down Expand Up @@ -3027,14 +3034,6 @@ async def test_rebalance_unprepared(c, s, a, b):
s.validate_state()


@gen_cluster(client=True, config=NO_AMM)
async def test_rebalance_raises_on_explicit_missing_data(c, s, a, b):
"""rebalance() raises KeyError if explicitly listed futures disappear"""
f = Future("x", client=c, state="memory")
with pytest.raises(KeyError, match="Could not rebalance keys:"):
await c.rebalance(futures=[f])


@gen_cluster(client=True)
async def test_receive_lost_key(c, s, a, b):
x = c.submit(inc, 1, workers=[a.address])
Expand Down Expand Up @@ -4141,51 +4140,6 @@ async def test_scatter_compute_store_lose_processing(c, s, a, b):
assert z.status == "cancelled"


@gen_cluster()
async def test_serialize_future(s, a, b):
async with (
Client(s.address, asynchronous=True) as c1,
Client(s.address, asynchronous=True) as c2,
):
future = c1.submit(lambda: 1)
result = await future

for ci in (c1, c2):
with ensure_no_new_clients():
with ci.as_current():
future2 = pickle.loads(pickle.dumps(future))
assert future2.client is ci
assert future2.key in ci.futures
result2 = await future2
assert result == result2
with temp_default_client(ci):
future2 = pickle.loads(pickle.dumps(future))


@gen_cluster()
async def test_serialize_future_without_client(s, a, b):
# Do not use a ctx manager to avoid having this being set as a current and/or default client
c1 = await Client(s.address, asynchronous=True, set_as_default=False)
try:
with ensure_no_new_clients():

def do_stuff():
return 1

future = c1.submit(do_stuff)
pickled = pickle.dumps(future)
unpickled_fut = pickle.loads(pickled)

with pytest.raises(RuntimeError):
await unpickled_fut

with c1.as_current():
unpickled_fut_ctx = pickle.loads(pickled)
assert await unpickled_fut_ctx == 1
finally:
await c1.close()


@gen_cluster()
async def test_temp_default_client(s, a, b):
async with (
Expand Down Expand Up @@ -5827,27 +5781,6 @@ async def test_client_with_name(s, a, b):
assert "foo" in text


@gen_cluster(client=True)
async def test_future_defaults_to_default_client(c, s, a, b):
x = c.submit(inc, 1)
await wait(x)

future = Future(x.key)
assert future.client is c


@gen_cluster(client=True)
async def test_future_auto_inform(c, s, a, b):
x = c.submit(inc, 1)
await wait(x)

async with Client(s.address, asynchronous=True) as client:
future = Future(x.key, client)

while future.status != "finished":
await asyncio.sleep(0.01)


def test_client_async_before_loop_starts(cleanup):
with pytest.raises(
RuntimeError,
Expand Down
11 changes: 11 additions & 0 deletions distributed/tests/test_queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,14 @@ async def test_unpickle_without_client(s):
q3 = pickle.loads(pickled)
await q3.put(1)
assert await q3.get() == 1


@gen_cluster(client=True, nthreads=[])
async def test_set_cancelled_future(c, s):
x = c.submit(inc, 1)
await x.cancel()
q = Queue("x")
# FIXME: This is a TimeoutError but pytest doesn't appear to recognize it as
# such
with pytest.raises(Exception, match="unknown to scheudler"):
fjetter marked this conversation as resolved.
Show resolved Hide resolved
await q.put(x, timeout="100ms")
Loading
Loading