rename HydroSampler to BasinBatchSampler; add _get_sampler func
OuyangWenyu committed Nov 4, 2024
1 parent 8d44460 commit 1fc92ff
Showing 18 changed files with 134 additions and 71 deletions.
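
Most of the changed files make the same one-line edit, renaming the `sampler` config value; the sampler registry added at the bottom of sampler.py (last file below) defines which names are now valid. For orientation, a sketch of the relevant config slice after this commit — the enclosing config call and any keys outside the diff context are assumptions:

# Hypothetical excerpt of a torchhydro experiment config after this commit;
# only the keys visible in the diff context below are taken from the source.
config_kwargs = dict(
    dataset="Seq2SeqDataset",
    sampler="BasinBatchSampler",  # was "HydroSampler"; "KuaiSampler" is the other registered choice
    scaler="DapengScaler",
)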
2 changes: 1 addition & 1 deletion experiments/evaluate_with_era5land.py
@@ -82,7 +82,7 @@ def get_config_data():
],
var_out=["streamflow", "sm_surface"],
dataset="Seq2SeqDataset",
- sampler="HydroSampler",
+ sampler="BasinBatchSampler",
scaler="DapengScaler",
loss_func="MultiOutLoss",
loss_param={
2 changes: 1 addition & 1 deletion experiments/evaluate_with_gpm.py
@@ -80,7 +80,7 @@ def get_config_data():
],
var_out=["streamflow", "sm_surface"],
dataset="Seq2SeqDataset",
- sampler="HydroSampler",
+ sampler="BasinBatchSampler",
scaler="DapengScaler",
loss_func="MultiOutLoss",
loss_param={
2 changes: 1 addition & 1 deletion experiments/evaluate_with_gpm_streamflow.py
@@ -82,7 +82,7 @@ def get_config_data():
],
var_out=["streamflow", "sm_surface"],
dataset="Seq2SeqDataset",
- sampler="HydroSampler",
+ sampler="BasinBatchSampler",
scaler="DapengScaler",
loss_func="MultiOutLoss",
loss_param={
2 changes: 1 addition & 1 deletion experiments/train_with_era5land.py
@@ -95,7 +95,7 @@ def config():
],
var_out=["streamflow", "sm_surface"],
dataset="Seq2SeqDataset",
- sampler="HydroSampler",
+ sampler="BasinBatchSampler",
scaler="DapengScaler",
train_epoch=100,
save_epoch=1,
2 changes: 1 addition & 1 deletion experiments/train_with_gpm.py
@@ -90,7 +90,7 @@ def config():
],
var_out=["streamflow", "sm_surface"],
dataset="Seq2SeqDataset",
- sampler="HydroSampler",
+ sampler="BasinBatchSampler",
scaler="DapengScaler",
train_epoch=100,
# train_mode=False,
1 change: 0 additions & 1 deletion experiments/train_with_gpm_dis.py
@@ -84,7 +84,6 @@ def create_config():
],
var_out=["streamflow", "sm_surface"],
dataset="Seq2SeqDataset",
- sampler="DistSampler",
scaler="DapengScaler",
train_epoch=2,
save_epoch=1,
2 changes: 1 addition & 1 deletion experiments/train_with_gpm_streamflow.py
@@ -90,7 +90,7 @@ def create_config():
],
var_out=["streamflow", "sm_surface"],
dataset="Seq2SeqDataset",
- sampler="HydroSampler",
+ sampler="BasinBatchSampler",
scaler="DapengScaler",
train_epoch=100,
save_epoch=1,
4 changes: 2 additions & 2 deletions tests/conftest.py
@@ -242,7 +242,7 @@ def s2s_args(basin4test):
],
var_out=["streamflow", "sm_surface"],
dataset="Seq2SeqDataset",
- sampler="HydroSampler",
+ sampler="BasinBatchSampler",
scaler="DapengScaler",
train_epoch=1,
save_epoch=1,
@@ -330,7 +330,7 @@ def trans_args(basin4test):
],
var_out=["streamflow", "sm_surface"],
dataset="TransformerDataset",
- sampler="HydroSampler",
+ sampler="BasinBatchSampler",
scaler="DapengScaler",
train_epoch=10,
save_epoch=1,
31 changes: 29 additions & 2 deletions tests/test_deep_hydro.py
@@ -1,7 +1,7 @@
"""
Author: Wenyu Ouyang
Date: 2024-05-01 10:34:15
- LastEditTime: 2024-09-18 16:47:26
+ LastEditTime: 2024-11-04 18:26:24
LastEditors: Wenyu Ouyang
Description: Unit tests for the DeepHydro class
FilePath: \torchhydro\tests\test_deep_hydro.py
@@ -12,12 +12,14 @@
import os
import torch
from torch.utils.data import Dataset
+ from torchhydro.datasets.sampler import BasinBatchSampler, KuaiSampler
from torchhydro.trainers.deep_hydro import DeepHydro
from torchhydro.datasets.data_dict import datasets_dict
from torchhydro.trainers.train_logger import TrainLogger
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR, ExponentialLR, ReduceLROnPlateau
+ from torch.utils.data import DistributedSampler


# Mock dataset class using random data
@@ -63,6 +65,7 @@ def dummy_data_cfgs():
"forecast_history": 0,
"forecast_length": 30,
"warmup_length": 0,
"object_ids": ["02051500", "21401550"],
}


@@ -98,7 +101,7 @@ def dummy_train_cfgs(dummy_data_cfgs):
"epochs": 12,
"start_epoch": 1,
"which_first_tensor": "batch",
"device": -1, # Assuming CPU device
"device": [-1], # Assuming CPU device
"train_mode": True,
"criterion": "RMSE",
"optimizer": "Adam",
@@ -201,3 +204,27 @@ def test_get_scheduler_lambda_lr_with_epochs_show_lr(deep_hydro, dummy_train_cfg
print(f"epoch:{epoch}, lr:{opt.param_groups[0]['lr']}")
scheduler.step()
assert isinstance(scheduler, LambdaLR)


def test_get_sampler_basin_batch_sampler(deep_hydro, dummy_train_cfgs):
dummy_train_cfgs["data_cfgs"]["sampler"] = {"name": "BasinBatchSampler"}
sampler = deep_hydro._get_sampler(
dummy_train_cfgs["data_cfgs"], deep_hydro.traindataset
)
assert isinstance(sampler, BasinBatchSampler)


def test_get_sampler_kuai_sampler(deep_hydro, dummy_train_cfgs):
dummy_train_cfgs["data_cfgs"]["sampler"] = {"name": "KuaiSampler"}
sampler = deep_hydro._get_sampler(
dummy_train_cfgs["data_cfgs"], deep_hydro.traindataset
)
assert isinstance(sampler, KuaiSampler)


def test_get_sampler_invalid_sampler(deep_hydro, dummy_train_cfgs):
dummy_train_cfgs["data_cfgs"]["sampler"] = {"name": "InvalidSampler"}
with pytest.raises(
NotImplementedError, match="Sampler InvalidSampler not implemented yet"
):
deep_hydro._get_sampler(dummy_train_cfgs["data_cfgs"], deep_hydro.traindataset)
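
The three tests above exercise a `DeepHydro._get_sampler` helper whose implementation in `torchhydro/trainers/deep_hydro.py` is not part of the hunks shown here. A minimal sketch of such a helper, assuming it dispatches through the new `data_sampler_dict` and reproduces the error message the last test matches; the KuaiSampler keyword names are assumptions, since part of its signature is collapsed in the sampler.py diff below:

# Sketch only: deep_hydro.py is not shown in this commit view, so this is an
# assumed implementation, not the one the repository actually contains.
from torchhydro.datasets.sampler import data_sampler_dict


def _get_sampler(self, data_cfgs, dataset):
    """Return the training sampler named in data_cfgs (illustrative sketch)."""
    sampler_cfg = data_cfgs["sampler"]
    # the tests above pass {"name": "..."}; experiment configs pass a plain string
    sampler_name = sampler_cfg["name"] if isinstance(sampler_cfg, dict) else sampler_cfg
    if sampler_name not in data_sampler_dict:
        raise NotImplementedError(f"Sampler {sampler_name} not implemented yet")
    if sampler_name == "KuaiSampler":
        # KuaiSampler needs sizing information to compute samples per epoch;
        # these keyword names are assumptions (part of its signature is collapsed below)
        return data_sampler_dict[sampler_name](
            dataset,
            batch_size=data_cfgs["batch_size"],
            warmup_length=data_cfgs["warmup_length"],
            rho_horizon=data_cfgs["forecast_length"],
        )
    return data_sampler_dict[sampler_name](dataset)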
2 changes: 1 addition & 1 deletion tests/test_evaluate_grid_lstm.py
@@ -83,7 +83,7 @@ def config_data():
"dor_pc_pva",  # degree of regulation
],
dataset="GridDataset",
- sampler="HydroSampler",
+ sampler="BasinBatchSampler",
scaler="MutiBasinScaler",
test_period=[
("2017-07-01", "2017-09-29"),
2 changes: 1 addition & 1 deletion tests/test_evaluate_seq2seq.py
@@ -88,7 +88,7 @@ def config_data():
],
var_out=["streamflow", "sm_surface"],
dataset="Seq2SeqDataset",
- sampler="HydroSampler",
+ sampler="BasinBatchSampler",
scaler="DapengScaler",
loss_func="MultiOutLoss",
loss_param={
2 changes: 1 addition & 1 deletion tests/test_pretrain_dataenhanced.py
@@ -80,7 +80,7 @@ def config():
("2023-06-01", "2023-09-30"),  # currently only one time period is supported
],
dataset="MultiSourceDataset",
- sampler="HydroSampler",
+ sampler="BasinBatchSampler",
scaler="DapengScaler",
weight_path=weight_path,
weight_path_add={
2 changes: 1 addition & 1 deletion tests/test_pretrain_datafusion.py
@@ -75,7 +75,7 @@ def config():
("2023-07-08", "2023-09-29"),
],
dataset="MultiSourceDataset",
- sampler="HydroSampler",
+ sampler="BasinBatchSampler",
scaler="DapengScaler",
weight_path=weight_path,
weight_path_add={
2 changes: 1 addition & 1 deletion tests/test_seq2seq.py
@@ -91,7 +91,7 @@ def seq2seq_config():
],
var_out=["streamflow", "sm_surface"],
dataset="Seq2SeqDataset",
- sampler="HydroSampler",
+ sampler="BasinBatchSampler",
scaler="DapengScaler",
train_epoch=2,
save_epoch=1,
2 changes: 1 addition & 1 deletion tests/test_train_grid_lstm.py
@@ -89,7 +89,7 @@ def config():
],
var_out=["streamflow"],
dataset="GridDataset",
- sampler="HydroSampler",
+ sampler="BasinBatchSampler",
scaler="MutiBasinScaler",
train_epoch=1,
save_epoch=1,
47 changes: 29 additions & 18 deletions torchhydro/datasets/sampler.py
@@ -1,7 +1,7 @@
"""
Author: Wenyu Ouyang
Date: 2023-09-25 08:21:27
- LastEditTime: 2024-05-27 15:59:09
+ LastEditTime: 2024-11-04 18:16:08
LastEditors: Wenyu Ouyang
Description: Some sampling class or functions
FilePath: \torchhydro\torchhydro\datasets\sampler.py
@@ -12,14 +12,14 @@
import numpy as np
from torch.utils.data import RandomSampler, Sampler
from torchhydro.datasets.data_sets import BaseDataset
- from typing import Iterator, Optional, Sized
+ from typing import Iterator, Optional
import torch


class KuaiSampler(RandomSampler):
def __init__(
self,
- data_source,
+ dataset,
batch_size,
warmup_length,
rho_horizon,
@@ -32,7 +32,7 @@ def __init__(
Parameters
----------
- data_source : torch.utils.data.Dataset
+ dataset : torch.utils.data.Dataset
just a object of dataset class inherited from torch.utils.data.Dataset
batch_size : int
we need batch_size to calculate the number of samples in an epoch
@@ -60,36 +60,39 @@ def __init__(
# __len__ means the number of all samples, then, the number of loops in an epoch is __len__()/batch_size = n_iter_ep
# hence we return n_iter_ep * batch_size
num_samples = n_iter_ep * batch_size
- super(KuaiSampler, self).__init__(data_source, num_samples=num_samples)
+ super(KuaiSampler, self).__init__(dataset, num_samples=num_samples)


- class HydroSampler(Sampler[int]):
+ class BasinBatchSampler(Sampler[int]):
"""
A custom sampler for hydrological modeling that iterates over a dataset in
a way tailored for batches of hydrological data. It ensures that each batch
contains data from a single randomly selected 'basin' out of several basins,
with batches constructed to respect the specified batch size and the unique
characteristics of hydrological datasets.
- TODO: made by Xinzhuo Wu, maybe need to be tested more
- Parameters:
- - data_source (Sized): The dataset to sample from, expected to have a `data_cfgs` attribute.
- - num_samples (Optional[int], default=None): The total number of samples to draw (optional).
- - generator: A PyTorch Generator object for random number generation (optional).
+ Parameters
+ ----------
+ dataset : BaseDataset
+     The dataset to sample from, expected to have a `data_cfgs` attribute.
+ num_samples : Optional[int], default=None
+     The total number of samples to draw (optional).
+ generator : Optional[torch.Generator]
+     A PyTorch Generator object for random number generation (optional).
The sampler divides the dataset by the number of basins, then iterates through
each basin's range in shuffled order, ensuring non-overlapping, basin-specific
batches suitable for models that predict hydrological outcomes.
"""

- data_source: Sized

def __init__(
self,
- data_source: Sized,
+ dataset,
num_samples: Optional[int] = None,
generator=None,
) -> None:
- self.data_source = data_source
+ self.dataset = dataset
self._num_samples = num_samples
self.generator = generator

@@ -100,12 +103,12 @@ def __init__(

@property
def num_samples(self) -> int:
- return len(self.data_source)
+ return len(self.dataset)

def __iter__(self) -> Iterator[int]:
- n = self.data_source.data_cfgs["batch_size"]
- basin_number = len(self.data_source.data_cfgs["object_ids"])
- basin_range = len(self.data_source) // basin_number
+ n = self.dataset.data_cfgs["batch_size"]
+ basin_number = len(self.dataset.data_cfgs["object_ids"])
+ basin_range = len(self.dataset) // basin_number
if n > basin_range:
raise ValueError(
f"batch_size should equal or less than basin_range={basin_range} "
@@ -200,3 +203,11 @@ def fl_sample_region(dataset: BaseDataset):
(dict_users[i], idxs[rand * num_imgs : (rand + 1) * num_imgs]), axis=0
)
return dict_users


data_sampler_dict = {
"KuaiSampler": KuaiSampler,
"BasinBatchSampler": BasinBatchSampler,
# TODO: DistributedSampler need more test
# "DistSampler": DistributedSampler,
}
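
For orientation, a sketch of how the renamed sampler could be wired into a DataLoader, assuming `train_dataset` is a torchhydro dataset whose `data_cfgs` carries the `batch_size` and `object_ids` keys that `BasinBatchSampler.__iter__` reads; the dataset itself is a placeholder here, not something this diff constructs:

# Placeholder usage sketch; `train_dataset` must be supplied by the caller.
from torch.utils.data import DataLoader

from torchhydro.datasets.sampler import BasinBatchSampler

sampler = BasinBatchSampler(train_dataset)
loader = DataLoader(
    train_dataset,
    batch_size=train_dataset.data_cfgs["batch_size"],
    sampler=sampler,  # per the class docstring, each batch then stays within one basin
)
for batch in loader:
    ...  # one training step per single-basin batch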