Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: support callable as stats in zonal_stats #55

Merged
merged 3 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 31 additions & 26 deletions xvec/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings
from collections.abc import Hashable, Mapping, Sequence
from typing import Any
from typing import Any, Callable

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -921,10 +921,10 @@ def to_geodataframe(

def zonal_stats(
self,
polygons: Sequence[shapely.Geometry],
geometry: Sequence[shapely.Geometry],
x_coords: Hashable,
y_coords: Hashable,
stats: str = "mean",
stats: str | Callable = "mean",
name: Hashable = "geometry",
index: bool = None,
method: str = "rasterize",
Expand All @@ -934,37 +934,43 @@ def zonal_stats(
):
"""Extract the values from a dataset indexed by a set of geometries

The CRS of the raster and that of polygons need to be equal.
The CRS of the raster and that of geometry need to be equal.
Xvec does not verify their equality.

Parameters
----------
polygons : Sequence[shapely.Geometry]
geometry : Sequence[shapely.Geometry]
An arrray-like (1-D) of shapely geometries, like a numpy array or
:class:`geopandas.GeoSeries`.
:class:`geopandas.GeoSeries`. Polygon and LineString geometry types are
supported.
x_coords : Hashable
name of the coordinates containing ``x`` coordinates (i.e. the first value
in the coordinate pair encoding the vertex of the polygon)
y_coords : Hashable
name of the coordinates containing ``y`` coordinates (i.e. the second value
in the coordinate pair encoding the vertex of the polygon)
stats : string
Spatial aggregation statistic method, by default "mean". It supports the
following statistcs: ['mean', 'median', 'min', 'max', 'sum']
stats : string | Callable
Spatial aggregation statistic method, by default "mean". Any of the
aggregations available as :class:`xarray.DataArray` or
:class:`xarray.DataArrayGroupBy` methods like
:meth:`~xarray.DataArray.mean`, :meth:`~xarray.DataArray.min`,
:meth:`~xarray.DataArray.max`, or :meth:`~xarray.DataArray.quantile`
are available. Alternatively, you can pass a ``Callable`` supported
by :meth:`~xarray.DataArray.reduce`.
name : Hashable, optional
Name of the dimension that will hold the ``polygons``, by default "geometry"
Name of the dimension that will hold the ``geometry``, by default "geometry"
index : bool, optional
If `polygons` is a GeoSeries, ``index=True`` will attach its index as another
If ``geometry`` is a :class:`~geopandas.GeoSeries`, ``index=True`` will attach its index as another
coordinate to the geometry dimension in the resulting object. If
``index=None``, the index will be stored if the `polygons.index` is a named
``index=None``, the index will be stored if the `geometry.index` is a named
or non-default index. If ``index=False``, it will never be stored. This is
useful as an attribute link between the resulting array and the GeoPandas
object from which the polygons are sourced.
object from which the geometry is sourced.
method : str, optional
The method of data extraction. The default is ``"rasterize"``, which uses
:func:`rasterio.features.rasterize` and is faster, but can lead to loss
of information in case of small polygons. Other option is ``"iterate"``, which
iterates over polygons and uses :func:`rasterio.features.geometry_mask`.
of information in case of small polygons or lines. Other option is ``"iterate"``, which
iterates over geometries and uses :func:`rasterio.features.geometry_mask`.
all_touched : bool, optional
If True, all pixels touched by geometries will be considered. If False, only
pixels whose center is within the polygon or that are selected by
Expand All @@ -975,22 +981,21 @@ def zonal_stats(
only if ``method="iterate"``.
**kwargs : optional
Keyword arguments to be passed to the aggregation function
(e.g., ``Dataset.mean(**kwargs)``).
(e.g., ``Dataset.quantile(**kwargs)``).

Returns
-------
Dataset
Dataset or DataArray
A subset of the original object with N-1 dimensions indexed by
the the GeometryIndex.
the :class:`GeometryIndex` of ``geometry``.

"""
# TODO: allow multiple stats at the same time (concat along a new axis),
# TODO: possibly as a list of tuples to include names?
# TODO: allow callable in stat (via .reduce())
if method == "rasterize":
result = _zonal_stats_rasterize(
self,
polygons=polygons,
geometry=geometry,
x_coords=x_coords,
y_coords=y_coords,
stats=stats,
Expand All @@ -1001,7 +1006,7 @@ def zonal_stats(
elif method == "iterate":
result = _zonal_stats_iterative(
self,
polygons=polygons,
geometry=geometry,
x_coords=x_coords,
y_coords=y_coords,
stats=stats,
Expand All @@ -1017,15 +1022,15 @@ def zonal_stats(
)

# save the index as a data variable
if isinstance(polygons, pd.Series):
if isinstance(geometry, pd.Series):
if index is None:
if polygons.index.name is not None or not polygons.index.equals(
pd.RangeIndex(0, len(polygons))
if geometry.index.name is not None or not geometry.index.equals(
pd.RangeIndex(0, len(geometry))
):
index = True
if index:
index_name = polygons.index.name if polygons.index.name else "index"
result = result.assign_coords({index_name: (name, polygons.index)})
index_name = geometry.index.name if geometry.index.name else "index"
result = result.assign_coords({index_name: (name, geometry.index)})

# standardize the shape - each method comes with a different one
return result.transpose(
Expand Down
26 changes: 26 additions & 0 deletions xvec/tests/test_zonal_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,29 @@ def test_crs(method):

actual = da.xvec.zonal_stats(polygons, "x", "y", stats="sum", method=method)
xr.testing.assert_identical(actual, expected)


@pytest.mark.parametrize("method", ["rasterize", "iterate"])
def test_callable(method):
ds = xr.tutorial.open_dataset("eraint_uvz")
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
ds_agg = ds.xvec.zonal_stats(
world.geometry, "longitude", "latitude", method=method, stats=np.nanstd
)
ds_std = ds.xvec.zonal_stats(
world.geometry, "longitude", "latitude", method=method, stats="std"
)
xr.testing.assert_identical(ds_agg, ds_std)

da_agg = ds.z.xvec.zonal_stats(
world.geometry,
"longitude",
"latitude",
method=method,
stats=np.nanstd,
n_jobs=1,
)
da_std = ds.z.xvec.zonal_stats(
world.geometry, "longitude", "latitude", method=method, stats="std"
)
xr.testing.assert_identical(da_agg, da_std)
62 changes: 38 additions & 24 deletions xvec/zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import gc
from collections.abc import Hashable, Sequence
from typing import Callable

import numpy as np
import shapely
Expand All @@ -10,16 +11,16 @@

def _zonal_stats_rasterize(
acc,
polygons: Sequence[shapely.Geometry],
geometry: Sequence[shapely.Geometry],
x_coords: Hashable,
y_coords: Hashable,
stats: str = "mean",
stats: str | Callable = "mean",
name: str = "geometry",
all_touched: bool = False,
**kwargs,
):
try:
import rasterio # noqa: F401
import rasterio
import rioxarray # noqa: F401
except ImportError as err:
raise ImportError(
Expand All @@ -28,15 +29,15 @@ def _zonal_stats_rasterize(
"'pip install rioxarray'."
) from err

if hasattr(polygons, "crs"):
crs = polygons.crs
if hasattr(geometry, "crs"):
crs = geometry.crs
else:
crs = None

transform = acc._obj.rio.transform()

labels = rasterio.features.rasterize(
zip(polygons, range(len(polygons))),
zip(geometry, range(len(geometry))),
out_shape=(
acc._obj[y_coords].shape[0],
acc._obj[x_coords].shape[0],
Expand All @@ -46,10 +47,13 @@ def _zonal_stats_rasterize(
all_touched=all_touched,
)
groups = acc._obj.groupby(xr.DataArray(labels, dims=(y_coords, x_coords)))
agg = getattr(groups, stats)(**kwargs)
if isinstance(stats, str):
agg = getattr(groups, stats)(**kwargs)
else:
agg = groups.reduce(stats, keep_attrs=True, **kwargs)
vec_cube = (
agg.reindex(group=range(len(polygons)))
.assign_coords(group=polygons)
agg.reindex(group=range(len(geometry)))
.assign_coords(group=geometry)
.rename(group=name)
).xvec.set_geom_indexes(name, crs=crs)

Expand All @@ -61,23 +65,23 @@ def _zonal_stats_rasterize(

def _zonal_stats_iterative(
acc,
polygons: Sequence[shapely.Geometry],
geometry: Sequence[shapely.Geometry],
x_coords: Hashable,
y_coords: Hashable,
stats: str = "mean",
stats: str | Callable = "mean",
name: str = "geometry",
all_touched: bool = False,
n_jobs: int = -1,
**kwargs,
):
"""Extract the values from a dataset indexed by a set of geometries

The CRS of the raster and that of polygons need to be equal.
The CRS of the raster and that of geometry need to be equal.
Xvec does not verify their equality.

Parameters
----------
polygons : Sequence[shapely.Geometry]
geometry : Sequence[shapely.Geometry]
An arrray-like (1-D) of shapely geometries, like a numpy array or
:class:`geopandas.GeoSeries`.
x_coords : Hashable
Expand All @@ -87,10 +91,14 @@ def _zonal_stats_iterative(
name of the coordinates containing ``y`` coordinates (i.e. the second value
in the coordinate pair encoding the vertex of the polygon)
stats : Hashable
Spatial aggregation statistic method, by default "mean". It supports the
following statistcs: ['mean', 'median', 'min', 'max', 'sum']
Spatial aggregation statistic method, by default "mean". Any of the
aggregations available as DataArray or DataArrayGroupBy like
:meth:`~xarray.DataArray.mean`, :meth:`~xarray.DataArray.min`,
:meth:`~xarray.DataArray.max`, or :meth:`~xarray.DataArray.quantile`,
methods are available. Alternatively, you can pass a ``Callable`` supported
by :meth:`~xarray.DataArray.reduce`.
name : Hashable, optional
Name of the dimension that will hold the ``polygons``, by default "geometry"
Name of the dimension that will hold the ``geometry``, by default "geometry"
all_touched : bool, optional
If True, all pixels touched by geometries will be considered. If False, only
pixels whose center is within the polygon or that are selected by
Expand Down Expand Up @@ -140,14 +148,14 @@ def _zonal_stats_iterative(
all_touched=all_touched,
**kwargs,
)
for geom in polygons
for geom in geometry
)
if hasattr(polygons, "crs"):
crs = polygons.crs
if hasattr(geometry, "crs"):
crs = geometry.crs
else:
crs = None
vec_cube = xr.concat(
zonal, dim=xr.DataArray(polygons, name=name, dims=name)
zonal, dim=xr.DataArray(geometry, name=name, dims=name)
).xvec.set_geom_indexes(name, crs=crs)
gc.collect()

Expand All @@ -160,7 +168,7 @@ def _agg_geom(
trans,
x_coords: str = None,
y_coords: str = None,
stats: str = "mean",
stats: str | Callable = "mean",
all_touched=False,
**kwargs,
):
Expand Down Expand Up @@ -207,9 +215,15 @@ def _agg_geom(
invert=True,
all_touched=all_touched,
)
result = getattr(
acc._obj.where(xr.DataArray(mask, dims=(y_coords, x_coords))), stats
)(dim=(y_coords, x_coords), keep_attrs=True, **kwargs)
masked = acc._obj.where(xr.DataArray(mask, dims=(y_coords, x_coords)))
if isinstance(stats, str):
result = getattr(masked, stats)(
dim=(y_coords, x_coords), keep_attrs=True, **kwargs
)
else:
result = masked.reduce(
stats, dim=(y_coords, x_coords), keep_attrs=True, **kwargs
)

del mask
gc.collect()
Expand Down