Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update cmethods.adjust's type annotations #155

Merged
merged 6 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/_codecov.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ jobs:
run: python -m pip install ".[dev,test]"

- name: Generate coverage report
run: pytest --cov --cov-report=xml
run: pytest --retries 1 --cov --cov-report=xml

- name: Upload coverage to Codecov
uses: codecov/codecov-action@1e68e06f1dbfde0e4cefc87efeba9e4643565303 #v5.1.2
Expand Down
5 changes: 2 additions & 3 deletions .github/workflows/_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,10 @@ jobs:
python-version: ${{ inputs.python-version }}

- name: Install dependencies
run: |
python -m pip install --user --upgrade pip
run: python -m pip install --user --upgrade pip

- name: Install package
run: python -m pip install --user ".[dev,test]"

- name: Run unit tests
run: pytest -vv tests
run: pytest -vv --retries 1 tests
4 changes: 2 additions & 2 deletions .github/workflows/cicd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
with:
os: ${{ matrix.os }}
python-version: ${{ matrix.python-version }}
Expand All @@ -65,7 +65,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest]
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
with:
os: ${{ matrix.os }}
python-version: ${{ matrix.python-version }}
Expand Down
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.0
rev: v0.9.3
hooks:
- id: ruff
args:
Expand All @@ -23,12 +23,12 @@ repos:
# - --install-types
# - --non-interactive
- repo: https://github.com/codespell-project/codespell
rev: v2.3.0
rev: v2.4.0
hooks:
- id: codespell
additional_dependencies: [tomli]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
# all available hooks can be found here: https://github.com/pre-commit/pre-commit-hooks/blob/main/.pre-commit-hooks.yaml
- id: check-yaml
Expand Down Expand Up @@ -72,7 +72,7 @@ repos:
- id: isort
args: [--profile=black]
- repo: https://github.com/PyCQA/bandit
rev: 1.7.10
rev: 1.8.2
hooks:
- id: bandit
exclude: "^tests/.*|examples/.*"
Expand Down
9 changes: 8 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,16 @@ install:
test:
$(PYTHON) -m pytest $(PYTEST_OPTS) $(TESTS)

.PHONY: tests
.PHONY: test
tests: test

## retest Rerun tests that failed before
##
.PHONY: retest
retest:
$(PYTHON) -m pytest $(PYTEST_OPTS) --lf $(TESTS)


## wip Run tests marked as wip
##
.PHONY: wip
Expand Down
12 changes: 6 additions & 6 deletions cmethods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ def cli(**kwargs) -> None:
datefmt="%Y/%m/%d %H:%M:%S",
level=logging.INFO,
)

logging.info("Loading data sets ...")
log = logging.getLogger(__name__)
log.info("Loading data sets ...")
try:
for key, message in zip(
("obs", "simh", "simp"),
Expand All @@ -194,15 +194,15 @@ def cli(**kwargs) -> None:
)
kwargs[key] = kwargs[key][kwargs["variable"]]
except (TypeError, KeyError) as exc:
logging.error(exc)
log.error(exc)
sys.exit(1)

logging.info("Data sets loaded ...")
log.info("Data sets loaded ...")
kwargs["n_quantiles"] = kwargs["quantiles"]
del kwargs["quantiles"]

logging.info("Applying %s ..." % kwargs["method"])
log.info("Applying %s ...", kwargs["method"])
result = adjust(**kwargs)

logging.info("Saving result to %s ..." % kwargs["output"])
log.info("Saving result to %s ...", kwargs["output"])
result.to_netcdf(kwargs["output"])
32 changes: 17 additions & 15 deletions cmethods/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from cmethods.scaling import linear_scaling as __linear_scaling
from cmethods.scaling import variance_scaling as __variance_scaling
from cmethods.static import SCALING_METHODS
from cmethods.utils import UnknownMethodError, check_xr_types
from cmethods.utils import UnknownMethodError, ensure_xr_dataarray

if TYPE_CHECKING:
from cmethods.types import XRData
Expand All @@ -37,16 +37,16 @@

def apply_ufunc(
method: str,
obs: XRData,
simh: XRData,
simp: XRData,
obs: xr.xarray.core.dataarray.DataArray,
simh: xr.xarray.core.dataarray.DataArray,
simp: xr.xarray.core.dataarray.DataArray,
**kwargs: dict,
) -> XRData:
) -> xr.xarray.core.dataarray.DataArray:
"""
Internal function used to apply the bias correction technique to the
passed input data.
"""
check_xr_types(obs=obs, simh=simh, simp=simp)
ensure_xr_dataarray(obs=obs, simh=simh, simp=simp)
if method not in __METHODS_FUNC__:
raise UnknownMethodError(method, __METHODS_FUNC__.keys())

Expand Down Expand Up @@ -96,11 +96,11 @@ def apply_ufunc(

def adjust(
method: str,
obs: XRData,
simh: XRData,
simp: XRData,
obs: xr.xarray.core.dataarray.DataArray,
simh: xr.xarray.core.dataarray.DataArray,
simp: xr.xarray.core.dataarray.DataArray,
**kwargs,
) -> XRData:
) -> xr.xarray.core.dataarray.DataArray | xr.xarray.core.dataarray.Dataset:
"""
Function to apply a bias correction technique on single and multidimensional
data sets. For more information please refer to the method specific
Expand All @@ -119,19 +119,19 @@ def adjust(
:param method: Technique to apply
:type method: str
:param obs: The reference/observational data set
:type obs: XRData
:type obs: xr.xarray.core.dataarray.DataArray
:param simh: The modeled data of the control period
:type simh: XRData
:type simh: xr.xarray.core.dataarray.DataArray
:param simp: The modeled data of the period to adjust
:type simp: XRData
:type simp: xr.xarray.core.dataarray.DataArray
:param kwargs: Any other method-specific parameter (like
``n_quantiles`` and ``kind``)
:type kwargs: dict
:return: The bias corrected/adjusted data set
:rtype: XRData
:rtype: xr.xarray.core.dataarray.DataArray | xr.xarray.core.dataarray.Dataset
"""
kwargs["adjust_called"] = True
check_xr_types(obs=obs, simh=simh, simp=simp)
ensure_xr_dataarray(obs=obs, simh=simh, simp=simp)

if method == "detrended_quantile_mapping": # noqa: PLR2004
raise ValueError(
Expand Down Expand Up @@ -169,6 +169,8 @@ def adjust(
obs_group = group["obs"]
simh_group = group["simh"]
simp_group = group["simp"]
else:
raise ValueError("'group' must be a string or a dict!")

del kwargs["group"]

Expand Down
8 changes: 4 additions & 4 deletions cmethods/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ def check_adjust_called(
)


def check_xr_types(obs: XRData, simh: XRData, simp: XRData) -> None:
def ensure_xr_dataarray(obs: XRData, simh: XRData, simp: XRData) -> None:
"""
Checks if the parameters are in the correct type. **only used internally**
"""
phrase: str = "must be type xarray.core.dataarray.Dataset or xarray.core.dataarray.DataArray"
phrase: str = "must be type 'xarray.core.dataarray.DataArray'."

if not isinstance(obs, XRData_t):
raise TypeError(f"'obs' {phrase}")
Expand All @@ -73,7 +73,7 @@ def check_np_types(
"""
Checks if the parameters are in the correct type. **only used internally**
"""
phrase: str = "must be type list, np.ndarray or np.generic"
phrase: str = "must be type list, np.ndarray, or np.generic"

if not isinstance(obs, NPData_t):
raise TypeError(f"'obs' {phrase}")
Expand Down Expand Up @@ -246,8 +246,8 @@ def get_adjusted_scaling_factor(
"UnknownMethodError",
"check_adjust_called",
"check_np_types",
"check_xr_types",
"ensure_dividable",
"ensure_xr_dataarray",
"get_adjusted_scaling_factor",
"get_cdf",
"get_inverse_of_cdf",
Expand Down
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,11 @@ keywords = [
classifiers = [
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
"Programming Language :: Python",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Utilities",
"Topic :: Scientific/Engineering",
Expand Down Expand Up @@ -99,20 +102,21 @@ dev = [
# linting
"pylint",
"flake8",
"ruff==0.3.5",
"ruff",
# typing
"mypy",
]
test = [
# testing
"pytest",
"pytest-cov",
"pytest-retry",
"zarr",
"dask[distributed]",
"scikit-learn",
"scipy",
]
examples = ["click", "matplotlib"]
examples = ["matplotlib"]

[tool.codespell]
check-filenames = true
Expand Down
4 changes: 2 additions & 2 deletions tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_not_implemented_errors(
def test_adjust_failing_dqm(datasets: dict) -> None:
with pytest.raises(
ValueError,
match="This function is not available for detrended quantile mapping. "
match=r"This function is not available for detrended quantile mapping. "
"Please use cmethods.CMethods.detrended_quantile_mapping",
):
adjust(
Expand All @@ -118,7 +118,7 @@ def test_adjust_failing_dqm(datasets: dict) -> None:
def test_adjust_failing_no_group_for_distribution(datasets: dict) -> None:
with pytest.raises(
ValueError,
match="Can't use group for distribution based methods.",
match=r"Can't use group for distribution based methods.",
):
adjust(
method="quantile_mapping",
Expand Down
14 changes: 7 additions & 7 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
from cmethods.static import MAX_SCALING_FACTOR
from cmethods.utils import (
check_np_types,
check_xr_types,
ensure_dividable,
ensure_xr_dataarray,
get_adjusted_scaling_factor,
get_pdf,
nan_or_equal,
Expand Down Expand Up @@ -133,7 +133,7 @@ def test_xr_type_check() -> None:
correct. No error should occur.
"""
ds: xr.core.dataarray.Dataset = xr.core.dataarray.Dataset()
check_xr_types(obs=ds, simh=ds, simp=ds)
ensure_xr_dataarray(obs=ds, simh=ds, simp=ds)


def test_type_check_failing() -> None:
Expand All @@ -142,7 +142,7 @@ def test_type_check_failing() -> None:
have the correct type.
"""

phrase: str = "must be type list, np.ndarray or np.generic"
phrase: str = "must be type list, np.ndarray, or np.generic"
with pytest.raises(TypeError, match=f"'obs' {phrase}"):
check_np_types(obs=1, simh=[], simp=[])

Expand Down Expand Up @@ -177,7 +177,7 @@ def test_detrended_quantile_mapping_type_check_simp_failing(datasets: dict) -> N
"""n_quantiles must by type int"""
with pytest.raises(
TypeError,
match="'simp' must be type xarray.core.dataarray.DataArray",
match=r"'simp' must be type xarray.core.dataarray.DataArray",
):
detrended_quantile_mapping( # type: ignore[attr-defined]
obs=datasets["+"]["obsh"][:, 0, 0],
Expand Down Expand Up @@ -222,7 +222,7 @@ def test_adjust_type_checking_failing() -> None:
)
with pytest.raises(
TypeError,
match="'obs' must be type xarray.core.dataarray.Dataset or xarray.core.dataarray.DataArray",
match=r"'obs' must be type 'xarray.core.dataarray.DataArray'.",
):
adjust(
method="linear_scaling",
Expand All @@ -233,7 +233,7 @@ def test_adjust_type_checking_failing() -> None:
)
with pytest.raises(
TypeError,
match="'simh' must be type xarray.core.dataarray.Dataset or xarray.core.dataarray.DataArray",
match=r"'simh' must be type 'xarray.core.dataarray.DataArray'.",
):
adjust(
method="linear_scaling",
Expand All @@ -245,7 +245,7 @@ def test_adjust_type_checking_failing() -> None:

with pytest.raises(
TypeError,
match="'simp' must be type xarray.core.dataarray.Dataset or xarray.core.dataarray.DataArray",
match=r"'simp' must be type 'xarray.core.dataarray.DataArray'.",
):
adjust(
method="linear_scaling",
Expand Down