Reduce memory footprint of culling P2P rechunking (#8845)
hendrikmakait authored Aug 30, 2024
1 parent 4aeed40 · commit e19c630
Showing 1 changed file: distributed/shuffle/_rechunk.py (28 additions, 27 deletions)
@@ -370,28 +370,31 @@ def cull(
         indices_to_keep = self._keys_to_indices(keys)
         _old_to_new = old_to_new(self.chunks_input, self.chunks)
 
-        culled_deps: defaultdict[Key, set[Key]] = defaultdict(set)
-        for nindex in indices_to_keep:
-            old_indices_per_axis = []
-            keepmap[nindex] = True
-            for index, new_axis in zip(nindex, _old_to_new):
-                old_indices_per_axis.append(
-                    [old_chunk_index for old_chunk_index, _ in new_axis[index]]
-                )
-            for old_nindex in product(*old_indices_per_axis):
-                culled_deps[(self.name,) + nindex].add((self.name_input,) + old_nindex)
+        for ndindex in indices_to_keep:
+            keepmap[ndindex] = True
 
-        # Protect against mutations later on with frozenset
-        frozen_deps = {
-            output_task: frozenset(input_tasks)
-            for output_task, input_tasks in culled_deps.items()
-        }
+        culled_deps = {}
+        # Identify the individual partial rechunks
+        for ndpartial in _split_partials(_old_to_new):
+            # Cull partials for which we do not keep any output tasks
+            if not np.any(keepmap[ndpartial.new]):
+                continue
+
+            # Within partials, we have all-to-all communication.
+            # Thus, all output tasks share the same input tasks.
+            deps = frozenset(
+                (self.name_input,) + ndindex
+                for ndindex in _ndindices_of_slice(ndpartial.old)
+            )
+
+            for ndindex in _ndindices_of_slice(ndpartial.new):
+                culled_deps[(self.name,) + ndindex] = deps
 
         if np.array_equal(keepmap, self.keepmap):
-            return self, frozen_deps
+            return self, culled_deps
         else:
            culled_layer = self._cull(keepmap)
-            return culled_layer, frozen_deps
+            return culled_layer, culled_deps
 
     def _construct_graph(self) -> _T_LowLevelGraph:
         import numpy as np
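
The gist of this hunk, as a standalone sketch: previously, cull built one mutable set of input keys per kept output chunk, duplicating the same dependencies across output tasks. Because communication within a partial rechunk is all-to-all, every output task of a partial has the same inputs, so a single frozenset shared by the whole partial suffices. The names below (partials, keepmap, the "input"/"output" key prefixes) are simplified stand-ins, not the actual dask structures:

    # Hedged sketch: per-partial shared dependency sets (toy 1-D layout).
    import numpy as np

    # Hypothetical layout: old chunks 0-2 feed new chunks 0-1,
    # old chunks 3-5 feed new chunks 2-3.
    partials = [
        {"old": range(0, 3), "new": range(0, 2)},
        {"old": range(3, 6), "new": range(2, 4)},
    ]
    keepmap = np.array([True, False, False, True])  # keep new chunks 0 and 3

    culled_deps = {}
    for partial in partials:
        # Cull partials for which no output chunk is kept.
        if not keepmap[list(partial["new"])].any():
            continue
        # One frozenset per partial, shared by all of its output tasks,
        # instead of one mutable set of keys per output task.
        deps = frozenset(("input", i) for i in partial["old"])
        for j in partial["new"]:
            culled_deps[("output", j)] = deps

    # Every output task of a partial aliases the same frozenset object:
    assert culled_deps[("output", 0)] is culled_deps[("output", 1)]
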
@@ -695,14 +698,12 @@ def _slice_new_chunks_into_partials(
     return tuple(sliced_axes)
 
 
-def _partial_ndindex(ndslice: NDSlice) -> np.ndindex:
-    import numpy as np
-
-    return np.ndindex(tuple(slice.stop - slice.start for slice in ndslice))
+def _ndindices_of_slice(ndslice: NDSlice) -> Iterator[NDIndex]:
+    return product(*(range(slc.start, slc.stop) for slc in ndslice))
 
 
-def _global_index(partial_index: NDIndex, partial_offset: NDIndex) -> NDIndex:
-    return tuple(index + offset for index, offset in zip(partial_index, partial_offset))
+def _partial_index(global_index: NDIndex, partial_offset: NDIndex) -> NDIndex:
+    return tuple(index - offset for index, offset in zip(global_index, partial_offset))
 
 
 def partial_concatenate(
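
For reference, the two new helpers can be exercised standalone. In this hedged sketch the NDSlice/NDIndex aliases are re-declared as plain tuples (an assumption; the real aliases live elsewhere in the module): _ndindices_of_slice lazily yields the global indices covered by a partial via itertools.product instead of materializing np.ndindex over the partial's shape, and _partial_index inverts the removed _global_index by subtracting the partial's offset.

    # Hedged sketch: the replaced helpers, reimplemented standalone.
    from itertools import product
    from typing import Iterator

    NDSlice = tuple[slice, ...]  # assumption: tuple-of-slices alias
    NDIndex = tuple[int, ...]    # assumption: tuple-of-ints alias


    def _ndindices_of_slice(ndslice: NDSlice) -> Iterator[NDIndex]:
        # Lazily yield the global indices covered by the partial.
        return product(*(range(slc.start, slc.stop) for slc in ndslice))


    def _partial_index(global_index: NDIndex, partial_offset: NDIndex) -> NDIndex:
        # Inverse of the removed _global_index: global -> partial-local.
        return tuple(index - offset for index, offset in zip(global_index, partial_offset))


    ndslice = (slice(2, 4), slice(5, 7))  # a 2x2 partial offset at (2, 5)
    offset = tuple(slc.start for slc in ndslice)
    for g in _ndindices_of_slice(ndslice):
        p = _partial_index(g, offset)
        # Adding the offset back recovers the global index.
        assert tuple(i + o for i, o in zip(p, offset)) == g
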
@@ -802,8 +803,8 @@ def partial_rechunk(
     )
 
     transfer_keys = []
-    for partial_index in _partial_ndindex(ndpartial.old):
-        global_index = _global_index(partial_index, old_partial_offset)
+    for global_index in _ndindices_of_slice(ndpartial.old):
+        partial_index = _partial_index(global_index, old_partial_offset)
 
         input_key = (input_name,) + global_index
 
@@ -822,8 +823,8 @@
     dsk[_barrier_key] = (shuffle_barrier, partial_token, transfer_keys)
 
     new_partial_offset = tuple(axis.start for axis in ndpartial.new)
-    for partial_index in _partial_ndindex(ndpartial.new):
-        global_index = _global_index(partial_index, new_partial_offset)
+    for global_index in _ndindices_of_slice(ndpartial.new):
+        partial_index = _partial_index(global_index, new_partial_offset)
         if keepmap[global_index]:
             dsk[(unpack_group,) + global_index] = (
                 rechunk_unpack,
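
The last two hunks apply the same inversion inside partial_rechunk: iterate global indices directly (the coordinates that graph keys and keepmap are addressed by) and derive the partial-local index from them, rather than the reverse. A simplified sketch of the unpack loop, where keepmap, dsk, and the key names are toy stand-ins for the real structures:

    # Hedged sketch of the inverted unpack loop, with toy stand-ins.
    from itertools import product

    import numpy as np

    ndpartial_new = (slice(1, 3), slice(0, 2))  # this partial's output region
    keepmap = np.zeros((4, 4), dtype=bool)
    keepmap[1, 0] = keepmap[2, 1] = True        # outputs that survived culling

    new_partial_offset = tuple(slc.start for slc in ndpartial_new)
    dsk = {}
    for global_index in product(*(range(s.start, s.stop) for s in ndpartial_new)):
        # Derive the partial-local index from the global one (the inverse
        # of the old direction, which mapped partial-local to global).
        partial_index = tuple(i - o for i, o in zip(global_index, new_partial_offset))
        if keepmap[global_index]:  # keepmap is addressed by global indices
            dsk[("unpack",) + global_index] = ("rechunk_unpack", partial_index)

    assert set(dsk) == {("unpack", 1, 0), ("unpack", 2, 1)}
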
