Skip to content

Commit

Permalink
Register arrow import/export dispatch to make p2p shuffle work (#295)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorisvandenbossche authored Jun 24, 2024
1 parent 7679bf9 commit fcbddf2
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 0 deletions.
34 changes: 34 additions & 0 deletions dask_geopandas/backends.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import uuid
from packaging.version import Version

import dask
from dask import config

# Check if dask-dataframe is using dask-expr (default of None means True as well)
Expand Down Expand Up @@ -84,3 +85,36 @@ def get_pyarrow_schema_geopandas(obj):
for col in obj.columns[obj.dtypes == "geometry"]:
df[col] = obj[col].to_wkb()
return pa.Schema.from_pandas(df)


if Version(dask.__version__) >= Version("2023.6.1"):
from dask.dataframe.dispatch import (
from_pyarrow_table_dispatch,
to_pyarrow_table_dispatch,
)

@to_pyarrow_table_dispatch.register((geopandas.GeoDataFrame,))
def get_pyarrow_table_from_geopandas(obj, **kwargs):
# `kwargs` must be supported by `pyarrow.Table.from_pandas`
import pyarrow as pa

if Version(geopandas.__version__).major < 1:
return pa.Table.from_pandas(obj.to_wkb(), **kwargs)
else:
# TODO handle kwargs?
return pa.table(obj.to_arrow())

@from_pyarrow_table_dispatch.register((geopandas.GeoDataFrame,))
def get_geopandas_geodataframe_from_pyarrow(meta, table, **kwargs):
# `kwargs` must be supported by `pyarrow.Table.to_pandas`
if Version(geopandas.__version__).major < 1:
df = table.to_pandas(**kwargs)

for col in meta.columns[meta.dtypes == "geometry"]:
df[col] = geopandas.GeoSeries.from_wkb(df[col], crs=meta[col].crs)

return df

else:
# TODO handle kwargs?
return geopandas.GeoDataFrame.from_arrow(table)
38 changes: 38 additions & 0 deletions dask_geopandas/tests/test_distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from packaging.version import Version

import geopandas

import dask_geopandas

import pytest
from geopandas.testing import assert_geodataframe_equal

distributed = pytest.importorskip("distributed")


from distributed import Client, LocalCluster


@pytest.mark.skipif(
Version(distributed.__version__) < Version("2024.6.0"),
reason="distributed < 2024.6 has a wrong assertion",
# https://github.com/dask/distributed/pull/8667
)
@pytest.mark.skipif(
Version(distributed.__version__) < Version("0.13"),
reason="geopandas < 0.13 does not implement sorting geometries",
)
def test_spatial_shuffle(naturalearth_cities):
df_points = geopandas.read_file(naturalearth_cities)

with LocalCluster(n_workers=1) as cluster:
with Client(cluster):
ddf_points = dask_geopandas.from_geopandas(df_points, npartitions=4)

ddf_result = ddf_points.spatial_shuffle(
by="hilbert", calculate_partitions=False
)
result = ddf_result.compute()

expected = df_points.sort_values("geometry").reset_index(drop=True)
assert_geodataframe_equal(result.reset_index(drop=True), expected)

0 comments on commit fcbddf2

Please sign in to comment.