diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index e7c3d88c..ac24fc66 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -15,6 +15,7 @@ List, Mapping, MutableMapping, + Sequence, TypeVar, Union, cast, @@ -1059,6 +1060,52 @@ def _get_all_cell_measures(self): return self._all_cell_measures + def curvefit( + self, + coords: str | DataArray | Iterable[str | DataArray], + func: Callable[..., Any], + reduce_dims: Hashable | Iterable[Hashable] = None, + skipna: bool = True, + p0: dict[str, Any] = None, + bounds: dict[str, Any] = None, + param_names: Sequence[str] = None, + kwargs: dict[str, Any] = None, + ): + + if coords is not None: + if isinstance(coords, str): + coords = (coords,) + coords = [ + apply_mapper( # type: ignore + [_single(_get_coords)], self._obj, v, error=False, default=[v] # type: ignore + )[ + 0 + ] # type: ignore + for v in coords + ] + if reduce_dims is not None: + if isinstance(reduce_dims, Hashable): + reduce_dims: Iterable[Hashable] = (reduce_dims,) # type: ignore + reduce_dims = [ + apply_mapper( # type: ignore + [_single(_get_dims)], self._obj, v, error=False, default=[v] # type: ignore + )[ + 0 + ] # type: ignore + for v in reduce_dims # type: ignore + ] + + return self._obj.curvefit( + coords=coords, + func=func, + reduce_dims=reduce_dims, + skipna=skipna, + p0=p0, + bounds=bounds, + param_names=param_names, + kwargs=kwargs, + ) + def _process_signature( self, func: Callable, diff --git a/cf_xarray/tests/__init__.py b/cf_xarray/tests/__init__.py index 6b8d7683..eca25d79 100644 --- a/cf_xarray/tests/__init__.py +++ b/cf_xarray/tests/__init__.py @@ -63,6 +63,7 @@ def LooseVersion(vstring): return version.LooseVersion(vstring) -has_pint, requires_pint = _importorskip("pint") -has_shapely, requires_shapely = _importorskip("shapely") has_cftime, requires_cftime = _importorskip("cftime") +has_scipy, requires_scipy = _importorskip("scipy") +has_shapely, requires_shapely = _importorskip("shapely") +has_pint, requires_pint = _importorskip("pint") diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index caaa586e..24811f2b 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -30,7 +30,7 @@ romsds, vert, ) -from . import raise_if_dask_computes, requires_cftime, requires_pint +from . import raise_if_dask_computes, requires_cftime, requires_pint, requires_scipy mpl.use("Agg") @@ -1619,3 +1619,26 @@ def test_cf_role(): dsg.foo.cf.plot(x="profile_id") dsg.foo.cf.plot(x="trajectory_id") + + +@requires_scipy +def test_curvefit(): + from cf_xarray.datasets import airds + + def line(time, slope): + t = (time - time[0]).astype(float) + return slope * t + + actual = airds.air.cf.isel(lat=4, lon=5).curvefit(coords=("time",), func=line) + expected = airds.air.cf.isel(lat=4, lon=5).cf.curvefit(coords="T", func=line) + assert_identical(expected, actual) + + def plane(coords, slopex, slopey): + x, y = coords + return slopex * (x - x.mean()) + slopey * (y - y.mean()) + + actual = airds.air.isel(time=0).curvefit(coords=("lat", "lon"), func=plane) + expected = airds.air.isel(time=0).cf.curvefit( + coords=("latitude", "longitude"), func=plane + ) + assert_identical(expected, actual) diff --git a/ci/environment.yml b/ci/environment.yml index 21362624..5d85c0a5 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -11,5 +11,6 @@ dependencies: - pandas - pint - pooch + - scipy - shapely - xarray diff --git a/setup.cfg b/setup.cfg index 802c2c35..fdcf23e8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -45,6 +45,8 @@ known_first_party = cf_xarray known_third_party = dask,matplotlib,numpy,pandas,pint,pytest,setuptools,sphinx_autosummary_accessors,xarray # Most of the numerical computing stack doesn't have type annotations yet. +[mypy] +allow_redefinition = True [mypy-affine.*] ignore_missing_imports = True [mypy-bottleneck.*]