Skip to content

Commit

Permalink
add cpu package for env file; delete hydromeandataset and only use s…
Browse files Browse the repository at this point in the history
…eq2seqdataset
  • Loading branch information
OuyangWenyu committed Nov 3, 2024
1 parent 65b3c7a commit 8d44460
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 202 deletions.
1 change: 1 addition & 0 deletions env-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies:
- pytorch<2.5
- torchvision
- torchaudio
# - cpuonly # cpu version for pcs without cuda
- pytorch-cuda=11.8
- tensorboard
- kaggle
Expand Down
16 changes: 8 additions & 8 deletions tests/test_seq2seq.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""
Author: Wenyu Ouyang
Date: 2024-04-17 12:55:24
LastEditTime: 2024-11-01 18:02:29
LastEditTime: 2024-11-02 21:47:38
LastEditors: Wenyu Ouyang
Description: Test funcs for seq2seq model
FilePath: \torchhydro\tests\test_seq2seq.py
FilePath: /torchhydro/tests/test_seq2seq.py
Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.
"""

Expand Down Expand Up @@ -40,7 +40,7 @@


@pytest.fixture()
def config():
def seq2seq_config():
project_name = os.path.join("train_with_gpm", "ex_test")
config_data = default_config_file()
args = cmd(
Expand Down Expand Up @@ -96,9 +96,9 @@ def config():
train_epoch=2,
save_epoch=1,
train_mode=True,
train_period=[("2016-06-01-01", "2016-08-01-01")],
test_period=[("2015-06-01-01", "2015-08-01-01")],
valid_period=[("2015-06-01-01", "2015-08-01-01")],
train_period=["2016-06-01-01", "2016-08-01-01"],
test_period=["2015-06-01-01", "2015-08-01-01"],
valid_period=["2015-06-01-01", "2015-08-01-01"],
loss_func="MultiOutLoss",
loss_param={
"loss_funcs": "RMSESum",
Expand Down Expand Up @@ -130,10 +130,10 @@ def config():
return config_data


def test_seq2seq(config):
def test_seq2seq(seq2seq_config):
# world_size = len(config["training_cfgs"]["device"])
# mp.spawn(train_worker, args=(world_size, config), nprocs=world_size, join=True)
train_and_evaluate(config)
train_and_evaluate(seq2seq_config)
# ensemble_train_and_evaluate(config)


Expand Down
104 changes: 0 additions & 104 deletions tests/test_train_mean_lstm.py

This file was deleted.

8 changes: 3 additions & 5 deletions torchhydro/datasets/data_dict.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""
Author: Wenyu Ouyang
Date: 2021-12-31 11:08:29
LastEditTime: 2024-05-24 14:52:00
LastEditTime: 2024-11-02 21:17:44
LastEditors: Wenyu Ouyang
Description: A dict used for data source and data loader
FilePath: \torchhydro\torchhydro\datasets\data_dict.py
FilePath: /torchhydro/torchhydro/datasets/data_dict.py
Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.
"""

Expand All @@ -13,7 +13,6 @@
BasinSingleFlowDataset,
DplDataset,
FlexibleDataset,
HydroMeanDataset,
Seq2SeqDataset,
TransformerDataset,
)
Expand All @@ -23,8 +22,7 @@
"StreamflowDataset": BaseDataset,
"SingleflowDataset": BasinSingleFlowDataset,
"DplDataset": DplDataset,
"MeanDataset": HydroMeanDataset,
"FlexDataset": FlexibleDataset,
"Seq2SeqDataset": Seq2SeqDataset,
"TransformerDataset":TransformerDataset,
"TransformerDataset": TransformerDataset,
}
90 changes: 5 additions & 85 deletions torchhydro/datasets/data_sets.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""
Author: Wenyu Ouyang
Date: 2024-04-08 18:16:53
LastEditTime: 2024-09-10 19:45:46
LastEditTime: 2024-11-02 21:59:06
LastEditors: Wenyu Ouyang
Description: A pytorch dataset class; references to https://github.com/neuralhydrology/neuralhydrology
FilePath: \torchhydro\torchhydro\datasets\data_sets.py
FilePath: /torchhydro/torchhydro/datasets/data_sets.py
Copyright (c) 2024-2024 Wenyu Ouyang. All rights reserved.
"""

Expand Down Expand Up @@ -611,9 +611,9 @@ def _normalize(self):
return scaler_hub.x, scaler_hub.y, scaler_hub.c


class HydroMeanDataset(BaseDataset):
class Seq2SeqDataset(BaseDataset):
def __init__(self, data_cfgs: dict, is_tra_val_te: str):
super(HydroMeanDataset, self).__init__(data_cfgs, is_tra_val_te)
super(Seq2SeqDataset, self).__init__(data_cfgs, is_tra_val_te)

@property
def data_source(self):
Expand All @@ -627,92 +627,12 @@ def data_source(self):

def _normalize(self):
    """Normalize x/y/c via the parent class, then materialize lazy arrays.

    ``.compute()`` forces evaluation of the (presumably dask-backed) arrays
    returned by ``super()._normalize()`` so downstream code gets in-memory data.
    """
    x, y, c = super()._normalize()
    # TODO: does this work for minio? maybe better to move to basedataset
    return x.compute(), y.compute(), c.compute()

def _read_xyc(self):
    """Read forcing (x), target (y) and attribute (c) data for the dataset.

    Results are stored on ``self.x_origin``, ``self.y_origin`` and
    ``self.c_origin``; any of them may be ``None`` when the corresponding
    configuration entry is empty.
    """
    # Target (y): observed variables configured under "target_cols".
    data_target_ds = self._prepare_target()
    if data_target_ds is not None:
        y_origin = self._trans2da_and_setunits(data_target_ds)
    else:
        y_origin = None

    # Forcing (x): dynamic input variables configured under "relevant_cols".
    data_forcing_ds = self._prepare_forcing()
    if data_forcing_ds is not None:
        x_origin = self._trans2da_and_setunits(data_forcing_ds)
    else:
        x_origin = None

    # Attributes (c): static basin attributes, only read when configured.
    if self.data_cfgs["constant_cols"]:
        data_attr_ds = self.data_source.read_attr_xrdataset(
            self.t_s_dict["sites_id"],
            self.data_cfgs["constant_cols"],
            # self.data_cfgs["source_cfgs"]["source_path"]["attributes"],
        )
        # NOTE(review): "c_orgin" is a typo for "c_origin"; kept byte-identical here.
        c_orgin = self._trans2da_and_setunits(data_attr_ds)
    else:
        c_orgin = None
    self.x_origin, self.y_origin, self.c_origin = x_origin, y_origin, c_orgin

def __len__(self):
    # Number of samples available; self.num_samples is computed elsewhere
    # (presumably from the lookup table) — not visible in this diff.
    return self.num_samples

def _prepare_forcing(self):
    """Read the dynamic forcing variables ("relevant_cols") from the source."""
    return self._read_from_minio(self.data_cfgs["relevant_cols"])

def _prepare_target(self):
    """Read the target (observation) variables ("target_cols") from the source."""
    return self._read_from_minio(self.data_cfgs["target_cols"])

def _read_from_minio(self, var_lst):
    """Read time-series variables for all selected sites over every configured period.

    Parameters
    ----------
    var_lst
        Names of the variables to read.

    Returns
    -------
    The per-period datasets concatenated along the "time" dimension.
    """
    gage_id_lst = self.t_s_dict["sites_id"]
    t_range = self.t_s_dict["t_final_range"]
    interval = self.data_cfgs["min_time_interval"]
    # e.g. interval 3 with unit "h" -> "3h"; selects which time-resolution
    # variant of the dataset to read.
    time_unit = (
        str(self.data_cfgs["min_time_interval"]) + self.data_cfgs["min_time_unit"]
    )

    subset_list = []
    for start_date, end_date in t_range:
        # Extend the end of the period by one interval so the final timestep
        # is included in the half-open read range.
        # NOTE(review): timedelta(hours=interval) hard-codes hours even though
        # min_time_unit is configurable — confirm non-hourly units are unused.
        adjusted_end_date = (
            datetime.strptime(end_date, "%Y-%m-%d-%H") + timedelta(hours=interval)
        ).strftime("%Y-%m-%d-%H")
        subset = self.data_source.read_ts_xrdataset(
            gage_id_lst,
            t_range=[start_date, adjusted_end_date],
            var_lst=var_lst,
            time_units=[time_unit],
        )
        subset_list.append(subset[time_unit])
    return xr.concat(subset_list, dim="time")


class Seq2SeqDataset(HydroMeanDataset):
def __init__(self, data_cfgs: dict, is_tra_val_te: str):
super(Seq2SeqDataset, self).__init__(data_cfgs, is_tra_val_te)

def _read_xyc(self):
    """Read forcing (x), target (y) and attribute (c) data for the dataset.

    Differs from the base implementation by masking negative forcing values
    to NaN. Results are stored on ``self.x_origin``, ``self.y_origin`` and
    ``self.c_origin``; any may be ``None`` when unconfigured.
    """
    # Target (y): observed variables configured under "target_cols".
    data_target_ds = self._prepare_target()
    if data_target_ds is not None:
        y_origin = self._trans2da_and_setunits(data_target_ds)
    else:
        y_origin = None

    # Forcing (x): dynamic inputs; negative values are replaced with NaN
    # (presumably missing-data sentinels in the source — confirm).
    data_forcing_ds = self._prepare_forcing()
    if data_forcing_ds is not None:
        x_origin = self._trans2da_and_setunits(data_forcing_ds)
        x_origin = xr.where(x_origin < 0, float("nan"), x_origin)
    else:
        x_origin = None

    # Attributes (c): static basin attributes, only read when configured.
    if self.data_cfgs["constant_cols"]:
        data_attr_ds = self.data_source.read_attr_xrdataset(
            self.t_s_dict["sites_id"],
            self.data_cfgs["constant_cols"],
        )
        # NOTE(review): "c_orgin" is a typo for "c_origin"; kept byte-identical here.
        c_orgin = self._trans2da_and_setunits(data_attr_ds)
    else:
        c_orgin = None
    self.x_origin, self.y_origin, self.c_origin = x_origin, y_origin, c_orgin

def __getitem__(self, item: int):
basin, idx = self.lookup_table[item]
rho = self.rho
Expand Down

0 comments on commit 8d44460

Please sign in to comment.