Internal cleanup of P2P code (#8907)
hendrikmakait authored Oct 25, 2024
1 parent 928d770 commit aad1178
Showing 6 changed files with 43 additions and 48 deletions.
24 changes: 19 additions & 5 deletions distributed/shuffle/_core.py
@@ -357,6 +357,7 @@ def add_partition(
         if self.transferred:
             raise RuntimeError(f"Cannot add more partitions to {self}")
         # Log metrics both in the "execute" and in the "p2p" contexts
+        self.validate_data(data)
         with self._capture_metrics("foreground"):
             with (
                 context_meter.meter("p2p-shard-partition-noncpu"),
@@ -402,6 +403,9 @@ def read(self, path: Path) -> tuple[Any, int]:
     def deserialize(self, buffer: Any) -> Any:
         """Deserialize shards"""
 
+    def validate_data(self, data: Any) -> None:
+        """Validate payload data before shuffling"""
+
 
 def get_worker_plugin() -> ShuffleWorkerPlugin:
     from distributed import get_worker
@@ -475,9 +479,6 @@ def create_new_run(
             participating_workers=set(worker_for.values()),
         )
 
-    def validate_data(self, data: Any) -> None:
-        """Validate payload data before shuffling"""
-
     @abc.abstractmethod
     def create_run_on_worker(
         self,
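Taken together, the two hunks above move the validate_data hook from ShuffleSpec onto ShuffleRun, so validation now happens inside ShuffleRun.add_partition rather than in the worker plugin. A toy sketch of the resulting hook pattern (simplified stand-in classes, not the real ShuffleRun):

from typing import Any

class Run:
    def add_partition(self, data: Any) -> None:
        self.validate_data(data)  # validate before any work is done
        ...  # shard and buffer the partition

    def validate_data(self, data: Any) -> None:
        """Validate payload data before shuffling"""  # base class: no-op

class DataFrameRun(Run):
    def validate_data(self, data: Any) -> None:
        if not hasattr(data, "columns"):
            raise TypeError(f"Expected a DataFrame, got {type(data)}")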
@@ -522,7 +523,7 @@ def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
     except P2POutOfDiskError:
         raise
     except Exception as e:
-        raise RuntimeError(f"P2P shuffling {id} failed during transfer phase") from e
+        raise RuntimeError(f"P2P {id} failed during transfer phase") from e
 
 
 @contextlib.contextmanager
@@ -538,7 +539,7 @@ def handle_unpack_errors(id: ShuffleId) -> Iterator[None]:
     except P2POutOfDiskError:
         raise
     except Exception as e:
-        raise RuntimeError(f"P2P shuffling {id} failed during unpack phase") from e
+        raise RuntimeError(f"P2P {id} failed during unpack phase") from e
 
 
 def _handle_datetime(buf: Any) -> Any:
@@ -561,3 +562,16 @@ def _mean_shard_size(shards: Iterable) -> int:
             if count == 10:
                 break
     return size // count if count else 0
+
+
+def p2p_barrier(id: ShuffleId, run_ids: list[int]) -> int:
+    try:
+        return get_worker_plugin().barrier(id, run_ids)
+    except Reschedule as e:
+        raise e
+    except P2PConsistencyError:
+        raise
+    except P2POutOfDiskError:
+        raise
+    except Exception as e:
+        raise RuntimeError(f"P2P {id} failed during barrier phase") from e
13 changes: 9 additions & 4 deletions distributed/shuffle/_merge.py
@@ -11,8 +11,13 @@
 from dask.tokenize import tokenize
 
 from distributed.shuffle._arrow import check_minimal_arrow_version
-from distributed.shuffle._core import ShuffleId, barrier_key, get_worker_plugin
-from distributed.shuffle._shuffle import shuffle_barrier, shuffle_transfer
+from distributed.shuffle._core import (
+    ShuffleId,
+    barrier_key,
+    get_worker_plugin,
+    p2p_barrier,
+)
+from distributed.shuffle._shuffle import shuffle_transfer
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -411,8 +416,8 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
 
         _barrier_key_left = barrier_key(ShuffleId(token_left))
         _barrier_key_right = barrier_key(ShuffleId(token_right))
-        dsk[_barrier_key_left] = (shuffle_barrier, token_left, transfer_keys_left)
-        dsk[_barrier_key_right] = (shuffle_barrier, token_right, transfer_keys_right)
+        dsk[_barrier_key_left] = (p2p_barrier, token_left, transfer_keys_left)
+        dsk[_barrier_key_right] = (p2p_barrier, token_right, transfer_keys_right)
 
         name = self.name
         for part_out in self.parts_out:
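_merge.py still builds its graph with dask's legacy task-tuple convention: dsk[key] = (callable, *args), where list arguments containing keys are resolved to computed values before the callable runs, which is how the barrier comes to depend on every transfer task. A hedged toy illustration (hypothetical keys and a stand-in barrier function, executed with the synchronous scheduler):

import dask

def barrier(token, transfer_results):
    # stand-in for p2p_barrier: runs only after every transfer has finished
    return token

dsk = {
    "transfer-0": (str.upper, "a"),
    "transfer-1": (str.upper, "b"),
    # the list of keys is recursed into, making both transfers dependencies
    "barrier-token": (barrier, "token", ["transfer-0", "transfer-1"]),
}
assert dask.get(dsk, "barrier-token") == "token"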
5 changes: 3 additions & 2 deletions distributed/shuffle/_rechunk.py
@@ -138,10 +138,11 @@
     get_worker_plugin,
     handle_transfer_errors,
     handle_unpack_errors,
+    p2p_barrier,
 )
 from distributed.shuffle._limiter import ResourceLimiter
 from distributed.shuffle._pickle import unpickle_bytestream
-from distributed.shuffle._shuffle import barrier_key, shuffle_barrier
+from distributed.shuffle._shuffle import barrier_key
 from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
 from distributed.sizeof import sizeof
 
@@ -823,7 +824,7 @@ def partial_rechunk(
         transfer_keys.append(t.ref())
 
     dsk[_barrier_key] = barrier = Task(
-        _barrier_key, shuffle_barrier, partial_token, transfer_keys
+        _barrier_key, p2p_barrier, partial_token, transfer_keys
     )
 
     new_partial_offset = tuple(axis.start for axis in ndpartial.new)
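_rechunk.py, by contrast, already uses the newer Task spec, where each task object carries its own key and dependencies are expressed via references instead of raw keys. A sketch of that construction (the dask._task_spec import path and exact signatures are assumptions based on the usage visible above; this is internal API and may change):

from dask._task_spec import Task  # internal API; path assumed

def add(x, y):
    return x + y

t1 = Task("t1", add, 1, 2)
# t1.ref() marks an explicit dependency on key "t1", mirroring
# transfer_keys.append(t.ref()) in partial_rechunk above
t2 = Task("t2", add, t1.ref(), 10)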
35 changes: 9 additions & 26 deletions distributed/shuffle/_shuffle.py
@@ -29,7 +29,6 @@
 from dask.utils import is_dataframe_like
 
 from distributed.core import PooledRPCCall
-from distributed.exceptions import Reschedule
 from distributed.metrics import context_meter
 from distributed.shuffle._arrow import (
     buffers_to_table,
@@ -49,12 +48,9 @@
     get_worker_plugin,
     handle_transfer_errors,
     handle_unpack_errors,
+    p2p_barrier,
 )
-from distributed.shuffle._exceptions import (
-    DataUnavailable,
-    P2PConsistencyError,
-    P2POutOfDiskError,
-)
+from distributed.shuffle._exceptions import DataUnavailable
 from distributed.shuffle._limiter import ResourceLimiter
 from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
 from distributed.sizeof import sizeof
@@ -106,19 +102,6 @@ def shuffle_unpack(
         )
 
 
-def shuffle_barrier(id: ShuffleId, run_ids: list[int]) -> int:
-    try:
-        return get_worker_plugin().barrier(id, run_ids)
-    except Reschedule as e:
-        raise e
-    except P2PConsistencyError:
-        raise
-    except P2POutOfDiskError:
-        raise
-    except Exception as e:
-        raise RuntimeError(f"shuffle_barrier failed during shuffle {id}") from e
-
-
 def rearrange_by_column_p2p(
     df: DataFrame,
     column: str,
@@ -306,7 +289,7 @@ def _construct_graph(self) -> _T_LowLevelGraph:
             dsk[t.key] = t
             transfer_keys.append(t.ref())
 
-        barrier = Task(_barrier_key, shuffle_barrier, token, transfer_keys)
+        barrier = Task(_barrier_key, p2p_barrier, token, transfer_keys)
         dsk[barrier.key] = barrier
 
         name = self.name
@@ -570,6 +553,12 @@ def read(self, path: Path) -> tuple[pa.Table, int]:
     def deserialize(self, buffer: Any) -> Any:
         return deserialize_table(buffer)
 
+    def validate_data(self, data: pd.DataFrame) -> None:
+        if not is_dataframe_like(data):
+            raise TypeError(f"Expected {data=} to be a DataFrame, got {type(data)}.")
+        if set(data.columns) != set(self.meta.columns):
+            raise ValueError(f"Expected {self.meta.columns=} to match {data.columns=}.")
+
 
 @dataclass(frozen=True)
 class DataFrameShuffleSpec(ShuffleSpec[int]):
@@ -586,12 +575,6 @@ def output_partitions(self) -> Generator[int]:
     def pick_worker(self, partition: int, workers: Sequence[str]) -> str:
         return _get_worker_for_range_sharding(self.npartitions, partition, workers)
 
-    def validate_data(self, data: pd.DataFrame) -> None:
-        if not is_dataframe_like(data):
-            raise TypeError(f"Expected {data=} to be a DataFrame, got {type(data)}.")
-        if set(data.columns) != set(self.meta.columns):
-            raise ValueError(f"Expected {self.meta.columns=} to match {data.columns=}.")
-
     def create_run_on_worker(
         self,
         run_id: int,
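The net change in this file: validate_data moves from DataFrameShuffleSpec onto DataFrameShuffleRun, so each worker-side run checks its own incoming partitions. A self-contained illustration of what that check rejects (plain pandas, with isinstance standing in for is_dataframe_like; meta is a hypothetical empty skeleton of the shuffled frame):

import pandas as pd

meta = pd.DataFrame({"a": [0], "b": [0.0]})  # hypothetical shuffle meta

def validate_data(data):
    if not isinstance(data, pd.DataFrame):
        raise TypeError(f"Expected {data=} to be a DataFrame, got {type(data)}.")
    if set(data.columns) != set(meta.columns):
        raise ValueError(f"Expected {meta.columns=} to match {data.columns=}.")

validate_data(pd.DataFrame({"b": [1.5], "a": [1]}))  # passes: same column set
try:
    validate_data(pd.DataFrame({"a": [1], "c": [2]}))  # wrong column set
except ValueError as err:
    print(err)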
8 changes: 0 additions & 8 deletions distributed/shuffle/_worker_plugin.py
@@ -358,7 +358,6 @@ def add_partition(
         spec: ShuffleSpec,
         **kwargs: Any,
     ) -> int:
-        spec.validate_data(data)
         shuffle_run = self.get_or_create_shuffle(spec)
         return shuffle_run.add_partition(
             data=data,
@@ -387,13 +386,6 @@ async def _get_shuffle_run(
             shuffle_id=shuffle_id, run_id=run_id
         )
 
-    async def _get_or_create_shuffle(
-        self,
-        spec: ShuffleSpec,
-        key: Key,
-    ) -> ShuffleRun:
-        return await self.shuffle_runs.get_or_create(spec=spec, key=key)
-
     async def teardown(self, worker: Worker) -> None:
         assert not self.closed
 
6 changes: 3 additions & 3 deletions distributed/shuffle/tests/test_shuffle.py
@@ -2546,7 +2546,7 @@ def make_partition(i):
 
     with raises_with_cause(
         RuntimeError,
-        r"(shuffling \w*|shuffle_barrier) failed",
+        r"P2P \w* failed",
         pa.ArrowTypeError,
         "incompatible types",
     ):
@@ -2744,7 +2744,7 @@ async def test_flaky_connect_fails_without_retry(c, s, a, b):
     with mock.patch.object(a, "rpc", rpc):
         with raises_with_cause(
             expected_exception=RuntimeError,
-            match="P2P shuffling.*transfer",
+            match="P2P.*transfer",
             expected_cause=OSError,
             match_cause=None,
         ):
@@ -2899,7 +2899,7 @@ def data_gen():
 
     with raises_with_cause(
         RuntimeError,
-        r"shuffling \w* failed",
+        r"P2P \w* failed",
         ValueError,
         "meta",
     ):
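With every phase now raising RuntimeError(f"P2P {id} failed during <phase> phase"), a single pattern covers transfer, unpack, and barrier failures in these tests. A quick standalone sanity check of the updated regex (the shuffle id below is made up):

import re

pattern = re.compile(r"P2P \w* failed")
messages = [
    "P2P 8dbb26f6 failed during transfer phase",
    "P2P 8dbb26f6 failed during unpack phase",
    "P2P 8dbb26f6 failed during barrier phase",
]
assert all(pattern.search(m) for m in messages)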
