Skip to content

Commit

Permalink
fix a bug for prcp_norm for streamflow
Browse files Browse the repository at this point in the history
  • Loading branch information
OuyangWenyu committed Nov 5, 2024
1 parent 5b4da2b commit baac385
Showing 1 changed file with 24 additions and 16 deletions.
40 changes: 24 additions & 16 deletions torchhydro/datasets/data_scalers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""
Author: Wenyu Ouyang
Date: 2024-04-08 18:17:44
LastEditTime: 2024-07-10 19:41:48
LastEditTime: 2024-11-05 09:21:24
LastEditors: Wenyu Ouyang
Description: normalize the data
FilePath: /torchhydro/torchhydro/datasets/data_scalers.py
FilePath: \torchhydro\torchhydro\datasets\data_scalers.py
Copyright (c) 2024-2024 Wenyu Ouyang. All rights reserved.
"""

Expand All @@ -13,6 +13,7 @@
import os
import pickle as pkl
import shutil
from typing import Optional
import pint_xarray # noqa: F401
import xarray as xr
import numpy as np
Expand Down Expand Up @@ -53,11 +54,11 @@ class ScalerHub(object):

def __init__(
self,
target_vars: np.array,
relevant_vars: np.array,
constant_vars: np.array = None,
data_cfgs: dict = None,
is_tra_val_te: str = None,
target_vars: np.ndarray,
relevant_vars: np.ndarray,
constant_vars: Optional[np.ndarray] = None,
data_cfgs: Optional[dict] = None,
is_tra_val_te: Optional[str] = None,
data_source: object = None,
**kwargs,
):
Expand Down Expand Up @@ -199,12 +200,12 @@ def __init__(
class DapengScaler(object):
def __init__(
self,
target_vars: np.array,
relevant_vars: np.array,
constant_vars: np.array,
target_vars: np.ndarray,
relevant_vars: np.ndarray,
constant_vars: np.ndarray,
data_cfgs: dict,
is_tra_val_te: str,
other_vars: dict = None,
other_vars: Optional[dict] = None,
prcp_norm_cols=None,
gamma_norm_cols=None,
pbm_norm=False,
Expand Down Expand Up @@ -288,12 +289,19 @@ def __init__(

@property
def mean_prcp(self):
return (
self.data_source.read_mean_prcp(self.t_s_dict["sites_id"])
.to_array()
.to_numpy()
.T # TODO: check why T is needed
"""This property is used to be divided by streamflow to normalize streamflow,
hence, its unit is same as streamflow
Returns
-------
np.ndarray
mean_prcp with the same unit as streamflow
"""
final_unit = self.data_target.attrs["units"]["streamflow"]
mean_prcp = self.data_source.read_mean_prcp(
self.t_s_dict["sites_id"], unit=final_unit
)
return mean_prcp.to_array().transpose("basin", "variable").to_numpy()

def inverse_transform(self, target_values):
"""
Expand Down

0 comments on commit baac385

Please sign in to comment.