diff --git a/tests/test_train_camels_lstm.py b/tests/test_train_camels_lstm.py index 94874cf..cc05c89 100644 --- a/tests/test_train_camels_lstm.py +++ b/tests/test_train_camels_lstm.py @@ -1,13 +1,14 @@ """ Author: Wenyu Ouyang Date: 2023-07-25 16:47:19 -LastEditTime: 2024-04-10 21:00:10 +LastEditTime: 2024-10-29 14:29:55 LastEditors: Wenyu Ouyang Description: Test a full training and evaluating process FilePath: \torchhydro\tests\test_train_camels_lstm.py Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved. """ +import os from torchhydro.configs.config import update_cfg from torchhydro.trainers.trainer import train_and_evaluate @@ -15,3 +16,27 @@ def test_train_evaluate(args, config_data): update_cfg(config_data, args) train_and_evaluate(config_data) + + +def test_train_evaluate_continue(args, config_data): + """We test the training and evaluation process with the continue_train + flag set to 1 and the start_epoch set to 2. This will load a pretrained + model and continue training. + This pattern is useful for training a model + when its training is interrupted + + Parameters + ---------- + args : _type_ + basic args in conftest.py + config_data : _type_ + default config data + """ + args.continue_train = 1 + args.start_epoch = 2 + args.train_mode = 1 + update_cfg(config_data, args) + config_data["model_cfgs"]["weight_path"] = os.path.join( + config_data["data_cfgs"]["test_path"], "model_Ep1.pth" + ) + train_and_evaluate(config_data) diff --git a/torchhydro/trainers/train_logger.py b/torchhydro/trainers/train_logger.py index 82c9193..36e50f0 100644 --- a/torchhydro/trainers/train_logger.py +++ b/torchhydro/trainers/train_logger.py @@ -1,7 +1,7 @@ """ Author: Wenyu Ouyang Date: 2021-12-31 11:08:29 -LastEditTime: 2024-09-18 15:40:10 +LastEditTime: 2024-10-29 16:07:08 LastEditors: Wenyu Ouyang Description: Training function for DL models FilePath: \torchhydro\torchhydro\trainers\train_logger.py @@ -18,6 +18,8 @@ import torch from torch.utils.tensorboard import SummaryWriter +from torchhydro.trainers.train_utils import get_lastest_logger_file_in_a_dir + def save_model(model, model_file, gpu_num=1): try: @@ -48,6 +50,22 @@ def __init__(self, model_filepath, params, opt): self.train_time = [] # log loss for each epoch self.epoch_loss = [] + # reload previous logs if continue_train is True and weight_path is not None + if ( + self.model_cfgs["continue_train"] + and self.model_cfgs["weight_path"] is not None + ): + the_logger_file = get_lastest_logger_file_in_a_dir(self.training_save_dir) + if the_logger_file is not None: + with open(the_logger_file, "r") as f: + logs = json.load(f) + start_epoch = self.training_cfgs["start_epoch"] + # read the logs before start_epoch and load them to session_params, train_time, epoch_loss + for log in logs["run"]: + if log["epoch"] < start_epoch: + self.session_params.append(log) + self.train_time.append(log["train_time"]) + self.epoch_loss.append(float(log["train_loss"])) def save_session_param( self, epoch, total_loss, n_iter_ep, valid_loss=None, valid_metrics=None diff --git a/torchhydro/trainers/train_utils.py b/torchhydro/trainers/train_utils.py index 4bc1594..570734b 100644 --- a/torchhydro/trainers/train_utils.py +++ b/torchhydro/trainers/train_utils.py @@ -1,7 +1,7 @@ """ Author: Wenyu Ouyang Date: 2024-04-08 18:16:26 -LastEditTime: 2024-09-18 10:22:54 +LastEditTime: 2024-10-29 15:47:51 LastEditors: Wenyu Ouyang Description: Some basic functions for training FilePath: \torchhydro\torchhydro\trainers\train_utils.py @@ -25,7 +25,11 @@ from torch.utils.data import DataLoader from hydroutils.hydro_stat import stat_error -from hydroutils.hydro_file import get_lastest_file_in_a_dir, unserialize_json +from hydroutils.hydro_file import ( + get_lastest_file_in_a_dir, + unserialize_json, + get_latest_file_in_a_lst, +) from torchhydro.models.crits import GaussianLoss @@ -576,6 +580,28 @@ def read_pth_from_model_loader(model_loader, model_pth_dir): return weight_path +def get_lastest_logger_file_in_a_dir(dir_path): + """Get the last logger file in a directory + + Parameters + ---------- + dir_path : str + the directory + + Returns + ------- + str + the path of the logger file + """ + pattern = r"^\d{1,2}_[A-Za-z]+_\d{6}_\d{2}(AM|PM)\.json$" + pth_files_lst = [ + os.path.join(dir_path, file) + for file in os.listdir(dir_path) + if re.match(pattern, file) + ] + return get_latest_file_in_a_lst(pth_files_lst) + + def cellstates_when_inference(seq_first, data_cfgs, pred): """get cell states when inference""" cs_out = (