Skip to content

Commit

Permalink
delete data_source property in seq2seq dataset, and directly use base…
Browse files Browse the repository at this point in the history
…dataset's, but note that the config needs to have a time_unit setting
  • Loading branch information
OuyangWenyu committed Nov 5, 2024
1 parent baac385 commit e557287
Show file tree
Hide file tree
Showing 15 changed files with 194 additions and 49 deletions.
10 changes: 6 additions & 4 deletions experiments/evaluate_with_era5land.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
Author: Wenyu Ouyang
Author: Xinzhuo Wu
Date: 2024-05-22 09:36:49
LastEditTime: 2024-05-27 15:47:53
LastEditTime: 2024-11-05 11:37:48
LastEditors: Wenyu Ouyang
Description:
Description: evaluate with era5land
FilePath: \torchhydro\experiments\evaluate_with_era5land.py
Copyright (c) 2021-2024 Wenyu Ouyang. All rights reserved.
"""
Expand Down Expand Up @@ -32,13 +32,15 @@ def get_config_data():
train_path = os.path.join(os.getcwd(), "results", "train_with_era5land", "ex20")
args = cmd(
sub=project_name,
# TODO: Update the source_path to the correct path
source_cfgs={
"source": "HydroMean",
"source_name": "selfmadehydrodataset",
"source_path": {
"forcing": "basins-origin/hour_data/1h/mean_data/data_forcing_era5land_streamflow",
"target": "basins-origin/hour_data/1h/mean_data/data_forcing_era5land_streamflow",
"attributes": "basins-origin/attributes.nc",
},
"other_settings": {"time_unit": ["3h"]},
},
ctx=[0],
model_name="Seq2Seq",
Expand Down
4 changes: 3 additions & 1 deletion experiments/evaluate_with_gpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@ def get_config_data():
train_path = os.path.join(os.getcwd(), "results", "train_with_gpm", "ex31")
args = cmd(
sub=project_name,
# TODO: Update the source_path to the correct path
source_cfgs={
"source": "HydroMean",
"source_name": "selfmadehydrodataset",
"source_path": {
"forcing": "basins-origin/hour_data/1h/mean_data/data_forcing_gpm_streamflow",
"target": "basins-origin/hour_data/1h/mean_data/data_forcing_gpm_streamflow",
"attributes": "basins-origin/attributes.nc",
},
"other_settings": {"time_unit": ["3h"]},
},
ctx=[0],
model_name="Seq2Seq",
Expand Down
4 changes: 3 additions & 1 deletion experiments/evaluate_with_gpm_streamflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ def get_config_data():
)
args = cmd(
sub=project_name,
# TODO: Update the source_path to the correct path
source_cfgs={
"source": "HydroMean",
"source_name": "selfmadehydrodataset",
"source_path": {
"forcing": "basins-origin/hour_data/1h/mean_data/data_forcing_gpm_streamflow",
"target": "basins-origin/hour_data/1h/mean_data/data_forcing_gpm_streamflow",
"attributes": "basins-origin/attributes.nc",
},
"other_settings": {"time_unit": ["3h"]},
},
ctx=[0],
model_name="Seq2Seq",
Expand Down
8 changes: 5 additions & 3 deletions experiments/train_with_era5land.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""
Author: Wenyu Ouyang
Date: 2024-04-17 12:55:24
LastEditTime: 2024-05-27 10:10:35
LastEditTime: 2024-11-05 11:40:28
LastEditors: Wenyu Ouyang
Description:
FilePath: \torchhydro\tests\test_train_seq2seq.py
FilePath: \torchhydro\experiments\train_with_era5land.py
Copyright (c) 2021-2024 Wenyu Ouyang. All rights reserved.
"""

Expand Down Expand Up @@ -48,9 +48,11 @@ def config():
# 填充测试所需的命令行参数
args = cmd(
sub=project_name,
# TODO: Update the source_path to the correct path
source_cfgs={
"source": "HydroMean",
"source_name": "selfmadehydrodataset",
"source_path": "/ftproot/basins-interim/",
"other_settings": {"time_unit": ["3h"]},
},
ctx=[2],
model_name="Seq2Seq",
Expand Down
4 changes: 3 additions & 1 deletion experiments/train_with_gpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ def config():
# 填充测试所需的命令行参数
args = cmd(
sub=project_name,
# TODO: Update the source_path to the correct path
source_cfgs={
"source": "HydroMean",
"source_name": "selfmadehydrodataset",
"source_path": "/ftproot/basins-interim/",
"other_settings": {"time_unit": ["3h"]},
},
ctx=[0],
model_name="Seq2Seq",
Expand Down
4 changes: 3 additions & 1 deletion experiments/train_with_gpm_dis.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@ def create_config():
config_data = default_config_file()
args = cmd(
sub=project_name,
# TODO: Update the source_path to the correct path
source_cfgs={
"source": "HydroMean",
"source_name": "selfmadehydrodataset",
"source_path": {
"forcing": "basins-origin/hour_data/1h/mean_data/data_forcing_gpm_streamflow",
"target": "basins-origin/hour_data/1h/mean_data/data_forcing_gpm_streamflow",
"attributes": "basins-origin/attributes.nc",
},
"other_settings": {"time_unit": ["3h"]},
},
ctx=[0, 1, 2],
model_name="Seq2Seq",
Expand Down
4 changes: 3 additions & 1 deletion experiments/train_with_gpm_streamflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,15 @@ def create_config():
# 填充测试所需的命令行参数
args = cmd(
sub=project_name,
# TODO: Update the source_path to the correct path
source_cfgs={
"source": "HydroMean",
"source_name": "selfmadehydrodataset",
"source_path": {
"forcing": "basins-origin/hour_data/1h/mean_data/data_forcing_gpm_streamflow",
"target": "basins-origin/hour_data/1h/mean_data/data_forcing_gpm_streamflow",
"attributes": "basins-origin/attributes.nc",
},
"other_settings": {"time_unit": ["3h"]},
},
ctx=[1],
model_name="Seq2Seq",
Expand Down
13 changes: 10 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,15 @@ def s2s_args(basin4test):
project_name = os.path.join("test_seq2seq", "gpmsmapexp1")
return cmd(
sub=project_name,
# TODO: Update the source_path to the correct path
source_cfgs={
"source": "HydroMean",
"source_name": "selfmadehydrodataset",
"source_path": {
"forcing": "basins-origin/hour_data/1h/mean_data/data_forcing_gpm_streamflow",
"target": "basins-origin/hour_data/1h/mean_data/data_forcing_gpm_streamflow",
"attributes": "basins-origin/attributes.nc",
},
"other_settings": {"time_unit": ["3h"]},
},
ctx=[0],
model_name="Seq2Seq",
Expand Down Expand Up @@ -280,13 +282,15 @@ def trans_args(basin4test):
project_name = os.path.join("test_trans", "gpmsmapexp1")
return cmd(
sub=project_name,
# TODO: Update the source_path to the correct path
source_cfgs={
"source": "HydroMean",
"source_name": "selfmadehydrodataset",
"source_path": {
"forcing": "basins-origin/hour_data/1h/mean_data/data_forcing_gpm_streamflow",
"target": "basins-origin/hour_data/1h/mean_data/data_forcing_gpm_streamflow",
"attributes": "basins-origin/attributes.nc",
},
"other_settings": {"time_unit": ["3h"]},
},
ctx=[0],
model_name="Transformer",
Expand Down Expand Up @@ -464,8 +468,11 @@ def seq2seq_config():
args = cmd(
sub=project_name,
source_cfgs={
"source": "HydroMean",
"source_name": "selfmadehydrodataset",
"source_path": SETTING["local_data_path"]["datasets-interim"],
"other_settings": {
"time_unit": ["3h"],
},
},
ctx=[0],
model_name="Seq2Seq",
Expand Down
138 changes: 135 additions & 3 deletions tests/test_data_sets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Author: Wenyu Ouyang
Date: 2024-05-27 13:33:08
LastEditTime: 2024-05-27 15:31:14
LastEditTime: 2024-11-05 11:46:01
LastEditors: Wenyu Ouyang
Description: Unit test for datasets
FilePath: \torchhydro\tests\test_data_sets.py
Expand All @@ -12,15 +12,16 @@
import os
import numpy as np
import pandas as pd
import torch
import xarray as xr
import pickle
from sklearn.preprocessing import StandardScaler
from torchhydro.datasets.data_sets import BaseDataset
from torchhydro.datasets.data_sets import BaseDataset, Seq2SeqDataset
from torchhydro.datasets.data_sources import data_sources_dict


class MockDatasource:
def __init__(self, source_cfgs):
def __init__(self, source_cfgs, time_unit="1D"):
self.ngrid = 2
self.nt = 366
self.data_cfgs = source_cfgs
Expand Down Expand Up @@ -148,3 +149,134 @@ def test_create_lookup_table(tmp_path):
isinstance(key, int) and isinstance(value, tuple)
for key, value in lookup_table.items()
)


def test_seq2seqdataset_getitem_train_mode(monkeypatch):
    """Check the structure of a single training sample from Seq2SeqDataset.

    The dataset is built on top of ``MockDatasource`` and is asked for item 0
    in ``"train"`` mode. In that mode an item is expected to unpack as
    ``((x, x_h, y), y_out)``, all of them ``torch.Tensor`` objects, with:

    * ``x.shape[1] == len(relevant_cols) + len(constant_cols) + 1``
    * ``x_h.shape[1] == len(constant_cols) + 1``
    * ``y.shape[1] == len(target_cols)``
    * ``y_out.shape == y.shape``

    Filesystem helpers and ``pandas.read_csv`` are replaced with in-memory
    stand-ins for the duration of the test.
    NOTE(review): it is not visible from this test whether Seq2SeqDataset or
    MockDatasource actually reach these helpers -- confirm before removing
    any of the patches.
    """

    # Stand-in for os.listdir: fake directory contents keyed on the path.
    def mock_listdir(path):
        if "timeseries" in path:
            return ["1D"]
        elif "attributes" in path:
            return ["attributes.csv"]
        return []

    # Stand-in for os.path.isdir: every path is reported as a directory.
    def mock_isdir(path):
        return True

    # Stand-in for pandas.read_csv: small synthetic tables chosen by file name.
    def mock_read_csv(filepath, *args, **kwargs):
        if "attributes.csv" in filepath:
            return pd.DataFrame(
                {
                    "basin_id": ["01013500", "01013501"],
                    "area": [100.0, 200.0],
                    "attr1": [1.0, 2.0],
                    "attr2": [3.0, 4.0],
                }
            )
        elif "timeseries" in filepath:
            return pd.DataFrame(
                {
                    "date": pd.date_range(start="2001-01-01", periods=365, freq="D"),
                    "prcp": [0.1] * 365,
                    "pet": [0.05] * 365,
                    "streamflow": [1.0] * 365,
                    "surface_sm": [0.3] * 365,
                }
            )
        return pd.DataFrame()

    # Stand-in for os.path.exists: every path is reported as existing.
    def mock_exists(path):
        return True

    # Stand-in for os.path.join: always joins with "/". Joining a single
    # component returns it unchanged (no trailing separator), matching what
    # the real os.path.join does for a one-argument call.
    def mock_join(a, *p):
        return "/".join((a, *p))

    # Patch os.listdir, os.path.isdir, os.path.exists, os.path.join and
    # pandas.read_csv; monkeypatch undoes all of these at teardown.
    monkeypatch.setattr(os, "listdir", mock_listdir)
    monkeypatch.setattr(os.path, "isdir", mock_isdir)
    monkeypatch.setattr(os.path, "exists", mock_exists)
    monkeypatch.setattr(os.path, "join", mock_join)
    monkeypatch.setattr(pd, "read_csv", mock_read_csv)

    # Register the mock through monkeypatch so the shared registry is restored
    # after the test instead of being mutated permanently.
    monkeypatch.setitem(data_sources_dict, "mockdatasource", MockDatasource)
    data_cfgs = {
        "source_cfgs": {
            "source_name": "mockdatasource",
            "source_path": "mock_path",
            "other_settings": {"time_unit": ["1D"]},
        },
        "object_ids": ["01013500", "01013501"],
        "t_range_train": ["2001-01-01", "2002-01-01"],
        "t_range_test": ["2002-01-01", "2003-01-01"],
        "relevant_cols": ["prcp", "pet"],
        "target_cols": ["streamflow", "surface_sm"],
        "constant_cols": ["geol_1st_class", "geol_2nd_class"],
        "forecast_history": 7,
        "warmup_length": 14,
        "forecast_length": 1,
        "min_time_unit": "D",
        "min_time_interval": 1,
        "target_rm_nan": True,
        "relevant_rm_nan": True,
        "constant_rm_nan": True,
        "prec_window": 3,
        "en_output_size": 5,
    }
    is_tra_val_te = "train"
    dataset = Seq2SeqDataset(data_cfgs, is_tra_val_te)
    item = 0
    (x, x_h, y), y_out = dataset[item]
    assert isinstance(x, torch.Tensor)
    assert isinstance(x_h, torch.Tensor)
    assert isinstance(y, torch.Tensor)
    assert isinstance(y_out, torch.Tensor)
    assert (
        x.shape[1]
        == len(data_cfgs["relevant_cols"]) + len(data_cfgs["constant_cols"]) + 1
    )
    assert x_h.shape[1] == len(data_cfgs["constant_cols"]) + 1
    assert y.shape[1] == len(data_cfgs["target_cols"])
    assert y_out.shape == y.shape


def test_seq2seqdataset_getitem_test_mode():
    """Check the structure of a single test-mode sample from Seq2SeqDataset.

    The dataset is built on top of ``MockDatasource`` and is asked for item 0
    in ``"test"`` mode. In that mode an item is expected to unpack as
    ``((x, x_h), y)``, all of them ``torch.Tensor`` objects, with:

    * ``x.shape[1] == len(relevant_cols) + len(constant_cols) + 1``
    * ``x_h.shape[1] == len(constant_cols) + 1``
    * ``y.shape[1] == len(target_cols)``
    """
    data_sources_dict.update({"mockdatasource": MockDatasource})
    data_cfgs = {
        "source_cfgs": {
            "source_name": "mockdatasource",
            "source_path": "mock_path",
            # Pass the time unit explicitly, as the train-mode test does,
            # rather than relying on MockDatasource's default.
            "other_settings": {"time_unit": ["1D"]},
        },
        "object_ids": ["01013500", "01013501"],
        "t_range_train": ["2001-01-01", "2002-01-01"],
        "t_range_test": ["2002-01-01", "2003-01-01"],
        "relevant_cols": ["prcp", "pet"],
        "target_cols": ["streamflow", "surface_sm"],
        "constant_cols": ["geol_1st_class", "geol_2nd_class"],
        "forecast_history": 7,
        "warmup_length": 14,
        "forecast_length": 1,
        "min_time_unit": "D",
        "min_time_interval": 1,
        "target_rm_nan": True,
        "relevant_rm_nan": True,
        "constant_rm_nan": True,
        "prec_window": 3,
        "en_output_size": 5,
    }
    is_tra_val_te = "test"
    dataset = Seq2SeqDataset(data_cfgs, is_tra_val_te)
    item = 0
    (x, x_h), y = dataset[item]
    assert isinstance(x, torch.Tensor)
    assert isinstance(x_h, torch.Tensor)
    assert isinstance(y, torch.Tensor)
    assert (
        x.shape[1]
        == len(data_cfgs["relevant_cols"]) + len(data_cfgs["constant_cols"]) + 1
    )
    assert x_h.shape[1] == len(data_cfgs["constant_cols"]) + 1
    assert y.shape[1] == len(data_cfgs["target_cols"])
4 changes: 3 additions & 1 deletion tests/test_evaluate_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ def config_data():
train_path = os.path.join(os.getcwd(), "results", "train_with_era5land", "ex20")
args = cmd(
sub=project_name,
# TODO: Update the source_path to the correct path
source_cfgs={
"source": "HydroMean",
"source_name": "selfmadehydrodataset",
"source_path": {
"forcing": "basins-origin/hour_data/1h/mean_data/data_forcing_era5land_streamflow",
"target": "basins-origin/hour_data/1h/mean_data/data_forcing_era5land_streamflow",
"attributes": "basins-origin/attributes.nc",
},
"other_settings": {"time_unit": ["3h"]},
},
ctx=[0],
model_name="Seq2Seq",
Expand Down
4 changes: 3 additions & 1 deletion tests/test_pretrain_dataenhanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@ def config():
project_name = "test_pretrain_enhanced/ex1"
args = cmd(
sub=project_name,
# TODO: Update the source_path to the correct path
source_cfgs={
"source": "HydroMean",
"source_name": "selfmadehydrodataset",
"source_path": {
"forcing": "basins-origin/hour_data/1h/mean_data/data_forcing_gpm",
"target": "basins-origin/hour_data/1h/mean_data/streamflow_basin",
"attributes": "basins-origin/attributes.nc",
},
"other_settings": {"time_unit": ["3h"]},
},
ctx=[1],
model_type="TransLearn",
Expand Down
4 changes: 3 additions & 1 deletion tests/test_pretrain_datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@ def config():
project_name = "test_pretrain_fusion/exp1"
args = cmd(
sub=project_name,
# TODO: Update the source_path to the correct path
source_cfgs={
"source": "HydroMean",
"source_name": "selfmadehydrodataset",
"source_path": {
"forcing": "basins-origin/hour_data/1h/mean_data/mean_data_merged",
"target": "basins-origin/hour_data/1h/mean_data/mean_data_merged",
"attributes": "basins-origin/attributes.nc",
},
"other_settings": {"time_unit": ["3h"]},
},
ctx=[1],
model_type="TransLearn",
Expand Down
Loading

0 comments on commit e557287

Please sign in to comment.