
Commit

Merge remote-tracking branch 'downstream/dev' into dev
OuyangWenyu committed May 22, 2024
2 parents 5bbe272 + 28c1a5b commit 090b695
Showing 4 changed files with 196 additions and 90 deletions.
35 changes: 21 additions & 14 deletions tests/test_train_seq2seq.py
@@ -15,10 +15,14 @@
import pandas as pd
import pytest
import hydrodatasource.configs.config as hdscc
import torch
import xarray as xr
import torch.multiprocessing as mp

from torchhydro.configs.config import cmd, default_config_file, update_cfg
from torchhydro.trainers.trainer import train_and_evaluate, ensemble_train_and_evaluate
from torchhydro.trainers.deep_hydro import train_worker

# from torchhydro.trainers.trainer import train_and_evaluate, ensemble_train_and_evaluate

logging.basicConfig(level=logging.INFO)
for logger_name in logging.root.manager.loggerDict:
@@ -58,7 +62,7 @@ def config():
"attributes": "basins-origin/attributes.nc",
},
},
ctx=[1],
ctx=[0, 1, 2],
model_name="Seq2Seq",
model_hyperparam={
"input_size": 20,
@@ -70,17 +74,17 @@
model_loader={"load_way": "best"},
gage_id=[
"21401550", # 碧流河
"01181000",
"01411300",
"01414500",
"02016000",
"02018000",
"02481510",
"03070500",
# "01181000",
# "01411300",
# "01414500",
# "02016000",
# "02018000",
# "02481510",
# "03070500",
# "08324000",
"11266500",
# "11266500",
# "11523200",
"12020000",
# "12020000",
# "12167000",
"14185000",
"14306500",
@@ -119,7 +123,7 @@ def config():
train_epoch=2,
save_epoch=1,
train_period=[
("2015-06-01", "2022-12-20"),
("2015-06-01", "2016-12-20"),
],
test_period=[
("2023-02-01", "2023-11-30"),
@@ -150,12 +154,15 @@ def config():
"kfold": 9,
"batch_sizes": [1024],
},
model_type="MTL",
model_type="DDP_MTL",
fill_nan=['no', 'no']
)
update_cfg(config_data, args)
return config_data


def test_seq2seq(config):
train_and_evaluate(config)
world_size = torch.cuda.device_count()
mp.spawn(train_worker, args=(world_size, config), nprocs=world_size, join=True)
# train_and_evaluate(config)
# ensemble_train_and_evaluate(config)
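
For orientation, the test now launches one training process per visible GPU instead of a single in-process run. A minimal sketch of that launch pattern, assuming `config` is the dict returned by the fixture above:

import torch
import torch.multiprocessing as mp

from torchhydro.trainers.deep_hydro import train_worker

# mp.spawn starts nprocs processes and passes each one its rank
# (0 .. world_size-1) as the first positional argument of train_worker.
world_size = torch.cuda.device_count()
if world_size == 0:
    raise RuntimeError("DDP_MTL training requires at least one CUDA device")
mp.spawn(train_worker, args=(world_size, config), nprocs=world_size, join=True)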
22 changes: 22 additions & 0 deletions torchhydro/configs/config.py
@@ -207,6 +207,8 @@ def default_config_file():
"static": True,
},
"training_cfgs": {
"master_addr": "localhost",
"port": "12335",
# if train_mode is False, don't train and evaluate
"train_mode": True,
"criterion": "RMSE",
@@ -292,6 +294,8 @@ def cmd(
fl_local_ep=None,
fl_local_bs=None,
fl_frac=None,
master_addr=None,
port=None,
ctx=None,
rs=None,
gage_id_file=None,
@@ -422,6 +426,20 @@
default=fl_frac,
type=float,
)
parser.add_argument(
"--master_addr",
dest="master_addr",
help="Master node address for DistributedDataParallel; defaults to localhost when training on a single machine",
default=master_addr,
type=str,
)
parser.add_argument(
"--port",
dest="port",
help="Master node port for the DistributedDataParallel rendezvous",
default=port,
type=str,
)
parser.add_argument(
"--ctx",
dest="ctx",
@@ -848,6 +866,10 @@ def update_cfg(cfg_file, new_args):
cfg_file["model_cfgs"]["fl_hyperparam"]["fl_local_bs"] = new_args.fl_local_bs
if new_args.fl_frac is not None:
cfg_file["model_cfgs1"]["fl_hyperparam"]["fl_frac"] = new_args.fl_frac
if new_args.master_addr is not None:
cfg_file["training_cfgs"]["master_addr"] = new_args.master_addr
if new_args.port is not None:
cfg_file["training_cfgs"]["port"] = new_args.port
if new_args.ctx is not None:
cfg_file["training_cfgs"]["device"] = new_args.ctx
if new_args.rs is not None:
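
As a usage note, the new options travel the same route as the existing ones: cmd() registers them on the parser and update_cfg() copies them into training_cfgs, where the DDP trainer reads them. A hedged sketch mirroring the test fixture above (the values are illustrative):

from torchhydro.configs.config import cmd, default_config_file, update_cfg

config_data = default_config_file()
# --master_addr/--port choose the DDP rendezvous point; --ctx lists the GPU ids
args = cmd(master_addr="localhost", port="12335", ctx=[0, 1, 2])
update_cfg(config_data, args)
assert config_data["training_cfgs"]["master_addr"] == "localhost"
assert config_data["training_cfgs"]["port"] == "12335"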
135 changes: 124 additions & 11 deletions torchhydro/trainers/deep_hydro.py
@@ -7,35 +7,38 @@
FilePath: \torchhydro\torchhydro\trainers\deep_hydro.py
Copyright (c) 2024-2024 Wenyu Ouyang. All rights reserved.
"""

import copy
import os
from abc import ABC, abstractmethod
from collections import defaultdict
import copy
from functools import reduce
import os
from typing import Dict, Tuple

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.distributed as dist
import torch.nn as nn
from hydroutils.hydro_file import get_lastest_file_in_a_dir
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import *
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm

from hydroutils.hydro_file import get_lastest_file_in_a_dir
from torchhydro.configs.config import update_nested_dict
from torchhydro.datasets.data_dict import datasets_dict
from torchhydro.datasets.sampler import (
KuaiSampler,
fl_sample_basin,
fl_sample_region,
HydroSampler,
)
from torchhydro.datasets.data_dict import datasets_dict
from torchhydro.models.model_dict_function import (
pytorch_criterion_dict,
pytorch_model_dict,
pytorch_opt_dict,
)
from torchhydro.models.model_utils import get_the_device
from torchhydro.trainers.train_logger import TrainLogger
from torchhydro.trainers.train_utils import (
EarlyStopper,
average_weights,
@@ -46,7 +49,6 @@
torch_single_train,
cellstates_when_inference,
)
from torchhydro.trainers.train_logger import TrainLogger


class DeepHydroInterface(ABC):
@@ -484,9 +486,7 @@ def _get_loss_func(self, training_cfgs):
if loss_param is not None:
for key in loss_param.keys():
if key == "loss_funcs":
criterion_init_params[key] = pytorch_criterion_dict[
loss_param[key]
]()
criterion_init_params[key] = pytorch_criterion_dict[loss_param[key]]()
else:
criterion_init_params[key] = loss_param[key]
return pytorch_criterion_dict[training_cfgs["criterion"]](
@@ -522,6 +522,8 @@ def _get_dataloader(self, training_cfgs, data_cfgs):
ngrid=ngrid,
nt=nt,
)
elif data_cfgs['sampler'] == "DistSampler":
sampler = DistributedSampler(train_dataset)
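# NOTE: DistributedSampler shards batches across ranks; for a different
# shuffle each epoch the training loop would also need to call
# sampler.set_epoch(epoch), which this diff does not show.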
else:
raise NotImplementedError("This sampler not implemented yet")
data_loader = DataLoader(
@@ -832,9 +834,120 @@ def _get_loss_func(self, training_cfgs):
)


class DistributedDeepHydro(MultiTaskHydro):
def __init__(self, world_size, cfgs: Dict):
super().__init__(cfgs, cfgs['model_cfgs']['weight_path'])
self.world_size = world_size
# self.cfgs = cfgs

def setup(self, rank):
os.environ['MASTER_ADDR'] = self.cfgs['training_cfgs']['master_addr']
os.environ['MASTER_PORT'] = self.cfgs['training_cfgs']['port']
dist.init_process_group(backend='nccl', init_method='env://', world_size=self.world_size, rank=rank)
torch.cuda.set_device(rank)
self.device = torch.device(rank)
self.rank = rank

def cleanup(self):
dist.destroy_process_group()

def load_model(self, mode='train'):
if mode == "infer":
if self.weight_path is None:
# for continue training
self.weight_path = self._get_trained_model()
else:
pass
elif mode != "train":
raise ValueError("Invalid mode; must be 'train' or 'infer'")
model_cfgs = self.cfgs["model_cfgs"]
model_name = model_cfgs["model_name"]
if model_name not in pytorch_model_dict:
raise NotImplementedError(
f"Error the model {model_name} was not found in the model dict. Please add it."
)
if self.pre_model is not None:
model = self._load_pretrain_model()
elif self.weight_path is not None:
# load model from pth file (saved weights and biases)
model = self._load_model_from_pth()
else:
model = pytorch_model_dict[model_name](**model_cfgs["model_hyperparam"])
return model

def model_train(self):
model = self.load_model().to(self.device)
self.model = DDP(model, device_ids=[self.rank])
training_cfgs = self.cfgs["training_cfgs"]
# output directory where checkpoints and training logs are saved (data_cfgs["test_path"])
model_filepath = self.cfgs["data_cfgs"]["test_path"]
data_cfgs = self.cfgs["data_cfgs"]
es = None
if training_cfgs["early_stopping"]:
es = EarlyStopper(training_cfgs["patience"])
criterion = self._get_loss_func(training_cfgs)
opt = self._get_optimizer(training_cfgs)
scheduler = self._get_scheduler(training_cfgs, opt)
max_epochs = training_cfgs["epochs"]
start_epoch = training_cfgs["start_epoch"]
# use PyTorch's DataLoader to load the data into batches in each epoch
data_loader, validation_data_loader = self._get_dataloader(
training_cfgs, data_cfgs
)
logger = TrainLogger(model_filepath, self.cfgs, opt)
for epoch in range(start_epoch, max_epochs + 1):
with logger.log_epoch_train(epoch) as train_logs:
total_loss, n_iter_ep = torch_single_train(
self.model,
opt,
criterion,
data_loader,
device=self.device,
which_first_tensor=training_cfgs["which_first_tensor"],
)
train_logs["train_loss"] = total_loss
train_logs["model"] = self.model

valid_loss = None
valid_metrics = None
if data_cfgs["t_range_valid"] is not None:
with logger.log_epoch_valid(epoch) as valid_logs:
valid_loss, valid_metrics = self._1epoch_valid(
training_cfgs, criterion, validation_data_loader, valid_logs
)

self._scheduler_step(training_cfgs, scheduler, valid_loss)
logger.save_session_param(
epoch, total_loss, n_iter_ep, valid_loss, valid_metrics
)
logger.save_model_and_params(self.model, epoch, self.cfgs)
if es and not es.check_loss(
self.model,
valid_loss,
self.cfgs["data_cfgs"]["test_path"],
):
print("Stopping model now")
break
# if self.rank == 0:
# logging.info(f"Training complete in: {time.time() - start_time:.2f} seconds")

def run(self):
self.setup(self.rank)
self.model_train()
self.model_evaluate()
self.cleanup()


def train_worker(rank, world_size, cfgs):
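# Entry point for torch.multiprocessing.spawn: mp.spawn prepends the
# process rank to the (world_size, cfgs) tuple supplied via args.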
trainer = DistributedDeepHydro(world_size, cfgs)
trainer.rank = rank
trainer.run()


model_type_dict = {
"Normal": DeepHydro,
"FedLearn": FedLearnHydro,
"TransLearn": TransLearnHydro,
"MTL": MultiTaskHydro,
"DDP_MTL": DistributedDeepHydro
}
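
For reference, the per-rank lifecycle implemented by DistributedDeepHydro.run() boils down to the standard PyTorch DDP pattern below. This is a simplified sketch, not the class itself; model, make_loader, and train_one_epoch are placeholders:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def ddp_lifecycle(rank, world_size, model, make_loader, train_one_epoch, epochs):
    # 1. rendezvous: every rank joins the same process group via env://
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12335"
    dist.init_process_group("nccl", init_method="env://", world_size=world_size, rank=rank)
    torch.cuda.set_device(rank)
    # 2. wrap the model so gradients are all-reduced across ranks
    ddp_model = DDP(model.to(rank), device_ids=[rank])
    loader = make_loader(rank, world_size)  # expected to use DistributedSampler
    for epoch in range(epochs):
        train_one_epoch(ddp_model, loader)
    # 3. tear down the process group before the worker exits
    dist.destroy_process_group()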
