diff --git a/torchhydro/configs/config.py b/torchhydro/configs/config.py index b8e923d..20b894d 100644 --- a/torchhydro/configs/config.py +++ b/torchhydro/configs/config.py @@ -230,6 +230,7 @@ def default_config_file(): "metrics": ["NSE"], "fill_nan": "no", "test_epoch": 20, + "explainer": None, }, } diff --git a/torchhydro/trainers/deep_hydro.py b/torchhydro/trainers/deep_hydro.py index 7eec657..b97a218 100644 --- a/torchhydro/trainers/deep_hydro.py +++ b/torchhydro/trainers/deep_hydro.py @@ -1,7 +1,7 @@ """ Author: Wenyu Ouyang Date: 2021-12-31 11:08:29 -LastEditTime: 2023-10-28 11:08:37 +LastEditTime: 2023-10-28 13:22:06 LastEditors: Wenyu Ouyang Description: HydroDL model class FilePath: \torchhydro\torchhydro\trainers\deep_hydro.py @@ -361,7 +361,7 @@ def model_evaluate(self) -> Tuple[Dict, np.array, np.array]: ] # Finally, try to explain model behaviour using shap - is_shap = True + is_shap = self.cfgs["evaluation_cfgs"]["explainer"] == "shap" if is_shap: deep_explain_model_summary_plot(self.model, test_data) deep_explain_model_heatmap(self.model, test_data)