Merge remote-tracking branch 'downstream/dev' into dev
OuyangWenyu committed Apr 21, 2024
2 parents ccac84b + bddbe0f commit 62dd813
Showing 10 changed files with 503 additions and 5 deletions.
83 changes: 83 additions & 0 deletions tests/test_model.py
@@ -0,0 +1,83 @@
"""
Author: silencesoup [email protected]
Date: 2024-04-20 11:30:16
LastEditors: silencesoup [email protected]
LastEditTime: 2024-04-20 11:41:05
FilePath: /torchhydro/tests/test_model.py
"""

import torch
from torchhydro.models import seq2seq
from torchhydro.models import model_dict_function


def test_model():
# Put your model parameters here; this is just an example for the Seq2Seq model.
model_configs = {
"Seq2Seq": {
"input_size": 16,
"output_size": 1,
"hidden_size": 256,
"forecast_length": 168,
"cnn_size": 120,
"model_mode": "single",
},
"Seq2Seq_dual": {
"input_size": 16,
"output_size": 1,
"hidden_size": 256,
"forecast_length": 168,
"cnn_size": 120,
"model_mode": "dual",
"input_size_encoder2": 1,
},
}
batch_size = 2

# Initialize the model
model_config = model_configs["Seq2Seq_dual"]

model = model_dict_function.pytorch_model_dict["Seq2Seq"](
input_size=model_config["input_size"],
output_size=model_config["output_size"],
hidden_size=model_config["hidden_size"],
forecast_length=model_config["forecast_length"],
cnn_size=model_config["cnn_size"],
model_mode=model_config["model_mode"],
)

# Generate random inputs for testing
# sourcery skip: no-conditionals-in-tests
if model_config["model_mode"] == "single":
src1 = torch.rand(
batch_size,
model_config["forecast_length"],
model_config["input_size"] - 1,
)
src2 = torch.rand(
batch_size, model_config["forecast_length"], model_config["cnn_size"]
)
else: # "dual"
src1 = torch.rand(
batch_size, model_config["forecast_length"], model_config["input_size"]
)
src2 = torch.rand(
batch_size,
model_config["forecast_length"],
model_config["input_size_encoder2"],
)

trg_start_token = torch.rand(batch_size, 1, model_config["output_size"])

# Execute the model
outputs = model(src1, src2, trg_start_token) # Adjusted to pass parameters directly

# Print outputs to verify
print("Outputs shape:", outputs.shape)
assert outputs.shape == (
batch_size,
model_config["forecast_length"],
model_config["output_size"],
), "Output shape is incorrect"

print("Test passed with output shape:", outputs.shape)
2 changes: 1 addition & 1 deletion tests/test_train_mean_lstm.py
@@ -35,6 +35,7 @@ def config():
"hidden_size": 256,
"forecast_length": 24,
},
model_loader={"load_way": "best"},
gage_id=[
"21401550",
],
@@ -64,7 +65,6 @@ def config():
scaler="DapengScaler",
train_epoch=50,
save_epoch=1,
te=50,
train_period=[
("2017-07-01", "2017-09-29"),
("2018-07-01", "2018-09-29"),
112 changes: 112 additions & 0 deletions tests/test_train_seq2seq.py
@@ -0,0 +1,112 @@
"""
Author: Wenyu Ouyang
Date: 2024-04-17 12:55:24
LastEditTime: 2024-04-17 13:31:16
LastEditors: Xinzhuo Wu
Description:
FilePath: /torchhydro/tests/test_train_seq2seq.py
Copyright (c) 2021-2024 Wenyu Ouyang. All rights reserved.
"""

import pytest
from torchhydro.configs.config import cmd, default_config_file, update_cfg
from torchhydro.trainers.trainer import train_and_evaluate, ensemble_train_and_evaluate

import logging

logging.basicConfig(level=logging.INFO)
for logger_name in logging.root.manager.loggerDict:
logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO)


@pytest.fixture()
def config():
project_name = "test_mean_seq2seq/ex1"
config_data = default_config_file()
args = cmd(
sub=project_name,
source_cfgs={
"source": "HydroMean",
"source_path": {
"forcing": "basins-origin/hour_data/1h/mean_data/mean_data_forcing",
"target": "basins-origin/hour_data/1h/mean_data/mean_data_target",
"attributes": "basins-origin/attributes.nc",
},
},
ctx=[1],
model_name="Seq2Seq",
model_hyperparam={
"input_size": 16,
"output_size": 1,
"hidden_size": 256,
"cnn_size": 120,
"forecast_length": 24,
"model_mode": "dual",
},
model_loader={"load_way": "best"},
gage_id=[
"21401550",
],
batch_size=256,
rho=168,
var_t=["gpm_tp", "gfs_tp"],
var_c=[
"area", # 面积
"ele_mt_smn", # 海拔(空间平均)
"slp_dg_sav", # 地形坡度 (空间平均)
"sgr_dk_sav", # 河流坡度 (平均)
"for_pc_sse", # 森林覆盖率
"glc_cl_smj", # 土地覆盖类型
"run_mm_syr", # 陆面径流 (流域径流的空间平均值)
"inu_pc_slt", # 淹没范围 (长期最大)
"cmi_ix_syr", # 气候湿度指数
"aet_mm_syr", # 实际蒸散发 (年平均)
"snw_pc_syr", # 雪盖范围 (年平均)
"swc_pc_syr", # 土壤水含量
"gwt_cm_sav", # 地下水位深度
"cly_pc_sav", # 土壤中的黏土、粉砂、砂粒含量
"dor_pc_pva", # 调节程度
],
var_out=["streamflow"],
dataset="MultiSourceDataset",
sampler="HydroSampler",
scaler="DapengScaler",
train_epoch=1,
save_epoch=1,
train_period=[
("2017-07-01", "2017-09-29"),
("2018-07-01", "2018-09-29"),
("2019-07-01", "2019-09-29"),
("2020-07-01", "2020-09-29"),
],
test_period=[
("2021-07-01", "2021-09-29"),
],
valid_period=[
("2021-07-01", "2021-09-29"),
],
loss_func="RMSESum",
opt="Adam",
lr_scheduler={
"lr": 0.001,
"lr_factor": 0.96,
},
which_first_tensor="batch",
rolling=True,
static=False,
early_stopping=True,
patience=4,
ensemble=True,
ensemble_items={
"kfold": 5,
"batch_sizes": [256],
},
)
update_cfg(config_data, args)
return config_data


def test_seq2seq(config):
# train_and_evaluate(config)
ensemble_train_and_evaluate(config)
6 changes: 6 additions & 0 deletions torchhydro/configs/config.py
@@ -999,6 +999,12 @@ def update_cfg(cfg_file, new_args):
cfg_file["data_cfgs"]["forecast_length"] = new_args.model_hyperparam[
"forecast_length"
]
if "model_mode" in new_args.model_hyperparam.keys():
cfg_file["data_cfgs"]["model_mode"] = new_args.model_hyperparam[
"model_mode"
]
if "cnn_size" in new_args.model_hyperparam.keys():
cfg_file["data_cfgs"]["cnn_size"] = new_args.model_hyperparam["cnn_size"]
if new_args.metrics is not None:
cfg_file["evaluation_cfgs"]["metrics"] = new_args.metrics
if new_args.fill_nan is not None:
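
For illustration, a minimal sketch (not from this commit) of what this hunk achieves, assuming cmd accepts this subset of the keyword arguments used in tests/test_train_seq2seq.py above: model_mode and cnn_size supplied via model_hyperparam are mirrored into data_cfgs, where HydroMultiSourceDataset (next file) reads them.

from torchhydro.configs.config import cmd, default_config_file, update_cfg

cfg = default_config_file()
args = cmd(
    sub="test_mean_seq2seq/ex1",
    model_name="Seq2Seq",
    model_hyperparam={
        "input_size": 16,
        "output_size": 1,
        "hidden_size": 256,
        "cnn_size": 120,
        "forecast_length": 24,
        "model_mode": "dual",
    },
)
update_cfg(cfg, args)
# The dataset-facing config now carries the two new keys alongside forecast_length.
assert cfg["data_cfgs"]["model_mode"] == "dual"
assert cfg["data_cfgs"]["cnn_size"] == 120
assert cfg["data_cfgs"]["forecast_length"] == 24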
2 changes: 2 additions & 0 deletions torchhydro/datasets/data_dict.py
@@ -15,6 +15,7 @@
FlexibleDataset,
HydroGridDataset,
HydroMeanDataset,
HydroMultiSourceDataset,
)


@@ -25,4 +26,5 @@
"GridDataset": HydroGridDataset,
"MeanDataset": HydroMeanDataset,
"FlexDataset": FlexibleDataset,
"MultiSourceDataset": HydroMultiSourceDataset,
}
128 changes: 128 additions & 0 deletions torchhydro/datasets/data_sets.py
@@ -929,3 +929,131 @@ def _prepare_forcing(self, data_type):
data_dict[basin] = merged_dataset.to_array(dim="variable")

return data_dict


class HydroMultiSourceDataset(HydroMeanDataset):
def __init__(self, data_cfgs: dict, is_tra_val_te: str):
super(HydroMultiSourceDataset, self).__init__(data_cfgs, is_tra_val_te)

def get_x(self, variable, basin, time, seq_length, offset1=0, offset2=0):
return (
self.x.sel(
variable=variable,
basin=basin,
time=slice(
time
- np.timedelta64(seq_length - 1, "h")
+ np.timedelta64(offset1, "h"),
(time if offset2 == 0 else time + np.timedelta64(offset2, "h")),
),
)
.to_numpy()
.T.reshape(-1, 1)
)

def get_y(self, basin, time, seq_length):
return (
self.y.sel(
basin=basin,
time=slice(
time,
time + np.timedelta64(seq_length - 1, "h"),
),
)
.to_numpy()
.T
)

# Todo
def get_token(self, basin, time):
data = (
self.y.sel(
basin=basin,
time=slice(
time - np.timedelta64(1, "h"), time - np.timedelta64(1, "h")
),
)
.to_numpy()
.T
)

if data.shape[0] == 0:
data = self.y.sel(basin=basin, time=slice(time, time)).to_numpy().T

return data

def __getitem__(self, item: int):
basin, time = self.lookup_table[item]
seq_length = self.rho
sm_length = seq_length - self.data_cfgs["cnn_size"]
x, y, s = None, None, None

var_lst = self.data_cfgs["relevant_cols"]
station_tp_present = "station_tp" in var_lst

if station_tp_present:
gpm_tp = self.get_x("gpm_tp", basin, time, seq_length, 0, -6)
station_tp = self.get_x("station_tp", basin, time, seq_length)
expanded_gpm_tp = np.vstack([gpm_tp, station_tp[-6:, :]])
x = np.hstack([expanded_gpm_tp, station_tp])
if "streamflow" in var_lst:
yy = self.get_y(basin, time, seq_length)  # get_y takes (basin, time, seq_length); self.y holds the target series
x = np.hstack([yy, x])
else:
x = self.get_x("gpm_tp", basin, time, seq_length)

if self.c is not None and self.c.shape[-1] > 0:
c = self.c.sel(basin=basin).values
c = np.tile(c, (seq_length, 1))
x = np.concatenate((x, c), axis=1)

mode = self.data_cfgs["model_mode"]
if mode == "dual":
s = self.get_x("gfs_tp", basin, time, seq_length, 0, -sm_length)
else:
all_features = [
self.get_x(
"gfs_tp", basin, time, seq_length, -offset, -sm_length - offset
)
for offset in range(seq_length)
]
s = np.array(
[
feat.squeeze()
for feat in all_features
if feat.size == seq_length - sm_length
]
)
start_token = self.get_token(basin, time)
y = self.get_y(basin, time, self.forecast_length)
return [
torch.from_numpy(x).float(),
torch.from_numpy(s).float(),
torch.from_numpy(start_token).float(),
], torch.from_numpy(y).float()

def _prepare_forcing(self):
gage_id_lst = self.t_s_dict["sites_id"]
t_range = self.t_s_dict["t_final_range"]
var_lst = self.data_cfgs["relevant_cols"]
path = self.data_cfgs["source_cfgs"]["source_path"]["forcing"]

if var_lst is None:
return None

data = self.data_source.merge_nc_minio_datasets(path, gage_id_lst, var_lst)

var_subset_list = []
for start_date, end_date in t_range:
adjusted_start_date = (
datetime.strptime(start_date, "%Y-%m-%d")
- timedelta(hours=self.rho + self.data_cfgs["cnn_size"])
).strftime("%Y-%m-%d")
adjusted_end_date = (
datetime.strptime(end_date, "%Y-%m-%d")
+ timedelta(hours=self.forecast_length)
).strftime("%Y-%m-%d")
subset = data.sel(time=slice(adjusted_start_date, adjusted_end_date))
var_subset_list.append(subset)

return xr.concat(var_subset_list, dim="time")
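
Because the offset arithmetic in get_x is easy to misread, here is a self-contained numpy/xarray sketch (synthetic data, not part of the commit) of the window it selects: start = time - (seq_length - 1) hours + offset1, end = time + offset2 (or simply time when offset2 is 0).

import numpy as np
import pandas as pd
import xarray as xr

# Synthetic hourly series for one basin and one variable.
times = pd.date_range("2021-07-01", periods=48, freq="h")
da = xr.DataArray(
    np.arange(48, dtype=float).reshape(1, 1, 48),
    dims=("variable", "basin", "time"),
    coords={"variable": ["gpm_tp"], "basin": ["21401550"], "time": times},
)


def window(data, variable, basin, time, seq_length, offset1=0, offset2=0):
    # Standalone version of the slice used in HydroMultiSourceDataset.get_x.
    start = time - np.timedelta64(seq_length - 1, "h") + np.timedelta64(offset1, "h")
    end = time if offset2 == 0 else time + np.timedelta64(offset2, "h")
    return data.sel(variable=variable, basin=basin, time=slice(start, end)).to_numpy().T.reshape(-1, 1)


t = np.datetime64("2021-07-02T12:00")
full = window(da, "gpm_tp", "21401550", t, 24)            # t-23h .. t    -> shape (24, 1)
trimmed = window(da, "gpm_tp", "21401550", t, 24, 0, -6)  # t-23h .. t-6h -> shape (18, 1)
print(full.shape, trimmed.shape)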
3 changes: 2 additions & 1 deletion torchhydro/models/model_dict_function.py
@@ -18,7 +18,7 @@
)

from torchhydro.models.simple_lstm import SimpleLSTMForecast

from torchhydro.models.seq2seq import GeneralSeq2Seq
from torch.optim import Adam, SGD, Adadelta
from torchhydro.models.crits import (
RMSELoss,
@@ -49,6 +49,7 @@
"SPPLSTM": SPP_LSTM_Model,
"SimpleLSTMForecast": SimpleLSTMForecast,
"SPPLSTM2": SPP_LSTM_Model_2,
"Seq2Seq": GeneralSeq2Seq,
}

pytorch_criterion_dict = {
