Skip to content

Commit

Permalink
Explicit support for curvefit (#337)
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian authored Jun 3, 2022
1 parent f24c1c3 commit 1fbc074
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 3 deletions.
47 changes: 47 additions & 0 deletions cf_xarray/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
List,
Mapping,
MutableMapping,
Sequence,
TypeVar,
Union,
cast,
Expand Down Expand Up @@ -1059,6 +1060,52 @@ def _get_all_cell_measures(self):

return self._all_cell_measures

def curvefit(
self,
coords: str | DataArray | Iterable[str | DataArray],
func: Callable[..., Any],
reduce_dims: Hashable | Iterable[Hashable] = None,
skipna: bool = True,
p0: dict[str, Any] = None,
bounds: dict[str, Any] = None,
param_names: Sequence[str] = None,
kwargs: dict[str, Any] = None,
):

if coords is not None:
if isinstance(coords, str):
coords = (coords,)
coords = [
apply_mapper( # type: ignore
[_single(_get_coords)], self._obj, v, error=False, default=[v] # type: ignore
)[
0
] # type: ignore
for v in coords
]
if reduce_dims is not None:
if isinstance(reduce_dims, Hashable):
reduce_dims: Iterable[Hashable] = (reduce_dims,) # type: ignore
reduce_dims = [
apply_mapper( # type: ignore
[_single(_get_dims)], self._obj, v, error=False, default=[v] # type: ignore
)[
0
] # type: ignore
for v in reduce_dims # type: ignore
]

return self._obj.curvefit(
coords=coords,
func=func,
reduce_dims=reduce_dims,
skipna=skipna,
p0=p0,
bounds=bounds,
param_names=param_names,
kwargs=kwargs,
)

def _process_signature(
self,
func: Callable,
Expand Down
5 changes: 3 additions & 2 deletions cf_xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def LooseVersion(vstring):
return version.LooseVersion(vstring)


has_pint, requires_pint = _importorskip("pint")
has_shapely, requires_shapely = _importorskip("shapely")
has_cftime, requires_cftime = _importorskip("cftime")
has_scipy, requires_scipy = _importorskip("scipy")
has_shapely, requires_shapely = _importorskip("shapely")
has_pint, requires_pint = _importorskip("pint")
25 changes: 24 additions & 1 deletion cf_xarray/tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
romsds,
vert,
)
from . import raise_if_dask_computes, requires_cftime, requires_pint
from . import raise_if_dask_computes, requires_cftime, requires_pint, requires_scipy

mpl.use("Agg")

Expand Down Expand Up @@ -1619,3 +1619,26 @@ def test_cf_role():

dsg.foo.cf.plot(x="profile_id")
dsg.foo.cf.plot(x="trajectory_id")


@requires_scipy
def test_curvefit():
from cf_xarray.datasets import airds

def line(time, slope):
t = (time - time[0]).astype(float)
return slope * t

actual = airds.air.cf.isel(lat=4, lon=5).curvefit(coords=("time",), func=line)
expected = airds.air.cf.isel(lat=4, lon=5).cf.curvefit(coords="T", func=line)
assert_identical(expected, actual)

def plane(coords, slopex, slopey):
x, y = coords
return slopex * (x - x.mean()) + slopey * (y - y.mean())

actual = airds.air.isel(time=0).curvefit(coords=("lat", "lon"), func=plane)
expected = airds.air.isel(time=0).cf.curvefit(
coords=("latitude", "longitude"), func=plane
)
assert_identical(expected, actual)
1 change: 1 addition & 0 deletions ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ dependencies:
- pandas
- pint
- pooch
- scipy
- shapely
- xarray
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ known_first_party = cf_xarray
known_third_party = dask,matplotlib,numpy,pandas,pint,pytest,setuptools,sphinx_autosummary_accessors,xarray

# Most of the numerical computing stack doesn't have type annotations yet.
[mypy]
allow_redefinition = True
[mypy-affine.*]
ignore_missing_imports = True
[mypy-bottleneck.*]
Expand Down

0 comments on commit 1fbc074

Please sign in to comment.