From 69db841f7d9e2dcaacacd9f1394c87afb48aedd4 Mon Sep 17 00:00:00 2001 From: ouyangwenyu Date: Sat, 28 Oct 2023 13:24:03 +0800 Subject: [PATCH] add explainer in config --- torchhydro/configs/config.py | 1 + torchhydro/trainers/deep_hydro.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/torchhydro/configs/config.py b/torchhydro/configs/config.py index b8e923d..20b894d 100644 --- a/torchhydro/configs/config.py +++ b/torchhydro/configs/config.py @@ -230,6 +230,7 @@ def default_config_file(): "metrics": ["NSE"], "fill_nan": "no", "test_epoch": 20, + "explainer": None, }, } diff --git a/torchhydro/trainers/deep_hydro.py b/torchhydro/trainers/deep_hydro.py index 7eec657..b97a218 100644 --- a/torchhydro/trainers/deep_hydro.py +++ b/torchhydro/trainers/deep_hydro.py @@ -1,7 +1,7 @@ """ Author: Wenyu Ouyang Date: 2021-12-31 11:08:29 -LastEditTime: 2023-10-28 11:08:37 +LastEditTime: 2023-10-28 13:22:06 LastEditors: Wenyu Ouyang Description: HydroDL model class FilePath: \torchhydro\torchhydro\trainers\deep_hydro.py @@ -361,7 +361,7 @@ def model_evaluate(self) -> Tuple[Dict, np.array, np.array]: ] # Finally, try to explain model behaviour using shap - is_shap = True + is_shap = self.cfgs["evaluation_cfgs"]["explainer"] == "shap" if is_shap: deep_explain_model_summary_plot(self.model, test_data) deep_explain_model_heatmap(self.model, test_data)