From 76c1e1b18cce9adbc24417dc85acf131745f96fa Mon Sep 17 00:00:00 2001 From: Jan Date: Mon, 13 Jan 2025 18:16:47 +0100 Subject: [PATCH] fix: add nan handling in diagnostics (#1359) * add NaN and Inf check for diagnostics, adapt test. * feedback --- sbi/diagnostics/lc2st.py | 5 +++++ sbi/diagnostics/sbc.py | 8 +++++++- sbi/diagnostics/tarp.py | 8 +++++++- sbi/utils/diagnostics_utils.py | 24 ++++++++++++++++++++++ sbi/utils/sbiutils.py | 2 +- tests/inference_with_NaN_simulator_test.py | 8 ++++++++ 6 files changed, 52 insertions(+), 3 deletions(-) diff --git a/sbi/diagnostics/lc2st.py b/sbi/diagnostics/lc2st.py index a89f65ba2..795ca1d18 100644 --- a/sbi/diagnostics/lc2st.py +++ b/sbi/diagnostics/lc2st.py @@ -12,6 +12,8 @@ from torch import Tensor from tqdm import tqdm +from sbi.utils.diagnostics_utils import remove_nans_and_infs_in_x + class LC2ST: def __init__( @@ -83,6 +85,9 @@ def __init__( [2] : https://github.com/sbi-dev/sbi/blob/main/sbi/utils/metrics.py """ + # check inputs + thetas, xs = remove_nans_and_infs_in_x(thetas, xs) + assert thetas.shape[0] == xs.shape[0] == posterior_samples.shape[0], ( "Number of samples must match" ) diff --git a/sbi/diagnostics/sbc.py b/sbi/diagnostics/sbc.py index 0017893c7..28c16f4ec 100644 --- a/sbi/diagnostics/sbc.py +++ b/sbi/diagnostics/sbc.py @@ -13,7 +13,10 @@ from sbi.inference import DirectPosterior from sbi.inference.posteriors.base_posterior import NeuralPosterior from sbi.inference.posteriors.vi_posterior import VIPosterior -from sbi.utils.diagnostics_utils import get_posterior_samples_on_batch +from sbi.utils.diagnostics_utils import ( + get_posterior_samples_on_batch, + remove_nans_and_infs_in_x, +) from sbi.utils.metrics import c2st @@ -54,6 +57,9 @@ def run_sbc( ranks: ranks of the ground truth parameters under the inferred dap_samples: samples from the data averaged posterior. """ + + thetas, xs = remove_nans_and_infs_in_x(thetas, xs) + num_sbc_samples = thetas.shape[0] if num_sbc_samples < 100: diff --git a/sbi/diagnostics/tarp.py b/sbi/diagnostics/tarp.py index 9cffb4c4c..4064db304 100644 --- a/sbi/diagnostics/tarp.py +++ b/sbi/diagnostics/tarp.py @@ -13,7 +13,10 @@ from torch import Tensor from sbi.inference.posteriors.base_posterior import NeuralPosterior -from sbi.utils.diagnostics_utils import get_posterior_samples_on_batch +from sbi.utils.diagnostics_utils import ( + get_posterior_samples_on_batch, + remove_nans_and_infs_in_x, +) from sbi.utils.metrics import l2 @@ -61,6 +64,9 @@ def run_tarp( ecp: Expected coverage probability (``ecp``), see equation 4 of the paper alpha: credibility values, see equation 2 of the paper """ + + thetas, xs = remove_nans_and_infs_in_x(thetas, xs) + num_tarp_samples, dim_theta = thetas.shape posterior_samples = get_posterior_samples_on_batch( diff --git a/sbi/utils/diagnostics_utils.py b/sbi/utils/diagnostics_utils.py index d53cd536c..aee2aa09b 100644 --- a/sbi/utils/diagnostics_utils.py +++ b/sbi/utils/diagnostics_utils.py @@ -1,4 +1,5 @@ import warnings +from typing import Tuple import torch from joblib import Parallel, delayed @@ -9,6 +10,7 @@ from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior from sbi.inference.posteriors.vi_posterior import VIPosterior from sbi.sbi_types import Shape +from sbi.utils import handle_invalid_x def get_posterior_samples_on_batch( @@ -91,3 +93,25 @@ def sample_fun( posterior_samples.shape[:2] }.""" return posterior_samples + + +def remove_nans_and_infs_in_x(thetas: Tensor, xs: Tensor) -> Tuple[Tensor, Tensor]: + """Remove NaNs and Infs entries in x from both the theta and x. + + Args: + thetas: Tensor of shape (num_samples, dim_parameters). + xs: Tensor of shape (num_samples, dim_observations). + + Returns: + Tuple of filtered thetas and xs, both of shape (num_valid_samples, ...). + """ + is_valid_x, num_nans, num_infs = handle_invalid_x(xs, exclude_invalid_x=True) + if num_nans > 0 or num_infs > 0: + warnings.warn( + f"Found {num_nans} NaNs and {num_infs} Infs in the data. " + f"These will be ignored below. Beware that only {is_valid_x.sum()} " + f"/ {len(xs)} samples are left.", + stacklevel=2, + ) + + return thetas[is_valid_x], xs[is_valid_x] diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index 75d83dbbd..4da4c47bf 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -387,7 +387,7 @@ def nle_nre_apt_msg_on_invalid_x( if num_nans + num_infs > 0: if exclude_invalid_x: - logging.warn( + logging.warning( f"Found {num_nans} NaN simulations and {num_infs} Inf simulations." f"These will be discarded from training due to " f"`exclude_invalid_x=True`. Please be aware that this gives " diff --git a/tests/inference_with_NaN_simulator_test.py b/tests/inference_with_NaN_simulator_test.py index c2e1636cf..35d5c424e 100644 --- a/tests/inference_with_NaN_simulator_test.py +++ b/tests/inference_with_NaN_simulator_test.py @@ -7,6 +7,7 @@ from torch.distributions import MultivariateNormal from sbi import utils as utils +from sbi.diagnostics import run_sbc from sbi.inference import ( NPE_A, NPE_C, @@ -113,6 +114,13 @@ def linear_gaussian_nan( # Compute the c2st and assert it is near chance level of 0.5. check_c2st(samples, target_samples, alg=f"{method}") + # run sbc + num_sbc_samples = 100 + thetas = prior.sample((num_sbc_samples,)) + xs = simulator(thetas) + ranks, daps = run_sbc(thetas, xs, posterior, num_posterior_samples=1000) + assert torch.isfinite(ranks).all() + @pytest.mark.slow def test_inference_with_restriction_estimator():