Skip to content

Commit

Permalink
fix: add nan handling in diagnostics (#1359)
Browse files Browse the repository at this point in the history
* add NaN and Inf check for diagnostics, adapt test.

* feedback
  • Loading branch information
janfb authored Jan 13, 2025
1 parent e1305b9 commit 76c1e1b
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 3 deletions.
5 changes: 5 additions & 0 deletions sbi/diagnostics/lc2st.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from torch import Tensor
from tqdm import tqdm

from sbi.utils.diagnostics_utils import remove_nans_and_infs_in_x


class LC2ST:
def __init__(
Expand Down Expand Up @@ -83,6 +85,9 @@ def __init__(
[2] : https://github.com/sbi-dev/sbi/blob/main/sbi/utils/metrics.py
"""

# check inputs
thetas, xs = remove_nans_and_infs_in_x(thetas, xs)

assert thetas.shape[0] == xs.shape[0] == posterior_samples.shape[0], (
"Number of samples must match"
)
Expand Down
8 changes: 7 additions & 1 deletion sbi/diagnostics/sbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
from sbi.inference import DirectPosterior
from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.posteriors.vi_posterior import VIPosterior
from sbi.utils.diagnostics_utils import get_posterior_samples_on_batch
from sbi.utils.diagnostics_utils import (
get_posterior_samples_on_batch,
remove_nans_and_infs_in_x,
)
from sbi.utils.metrics import c2st


Expand Down Expand Up @@ -54,6 +57,9 @@ def run_sbc(
ranks: ranks of the ground truth parameters under the inferred
dap_samples: samples from the data averaged posterior.
"""

thetas, xs = remove_nans_and_infs_in_x(thetas, xs)

num_sbc_samples = thetas.shape[0]

if num_sbc_samples < 100:
Expand Down
8 changes: 7 additions & 1 deletion sbi/diagnostics/tarp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
from torch import Tensor

from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.utils.diagnostics_utils import get_posterior_samples_on_batch
from sbi.utils.diagnostics_utils import (
get_posterior_samples_on_batch,
remove_nans_and_infs_in_x,
)
from sbi.utils.metrics import l2


Expand Down Expand Up @@ -61,6 +64,9 @@ def run_tarp(
ecp: Expected coverage probability (``ecp``), see equation 4 of the paper
alpha: credibility values, see equation 2 of the paper
"""

thetas, xs = remove_nans_and_infs_in_x(thetas, xs)

num_tarp_samples, dim_theta = thetas.shape

posterior_samples = get_posterior_samples_on_batch(
Expand Down
24 changes: 24 additions & 0 deletions sbi/utils/diagnostics_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from typing import Tuple

import torch
from joblib import Parallel, delayed
Expand All @@ -9,6 +10,7 @@
from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior
from sbi.inference.posteriors.vi_posterior import VIPosterior
from sbi.sbi_types import Shape
from sbi.utils import handle_invalid_x


def get_posterior_samples_on_batch(
Expand Down Expand Up @@ -91,3 +93,25 @@ def sample_fun(
posterior_samples.shape[:2]
}."""
return posterior_samples


def remove_nans_and_infs_in_x(thetas: Tensor, xs: Tensor) -> Tuple[Tensor, Tensor]:
"""Remove NaNs and Infs entries in x from both the theta and x.
Args:
thetas: Tensor of shape (num_samples, dim_parameters).
xs: Tensor of shape (num_samples, dim_observations).
Returns:
Tuple of filtered thetas and xs, both of shape (num_valid_samples, ...).
"""
is_valid_x, num_nans, num_infs = handle_invalid_x(xs, exclude_invalid_x=True)
if num_nans > 0 or num_infs > 0:
warnings.warn(
f"Found {num_nans} NaNs and {num_infs} Infs in the data. "
f"These will be ignored below. Beware that only {is_valid_x.sum()} "
f"/ {len(xs)} samples are left.",
stacklevel=2,
)

return thetas[is_valid_x], xs[is_valid_x]
2 changes: 1 addition & 1 deletion sbi/utils/sbiutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def nle_nre_apt_msg_on_invalid_x(

if num_nans + num_infs > 0:
if exclude_invalid_x:
logging.warn(
logging.warning(
f"Found {num_nans} NaN simulations and {num_infs} Inf simulations."
f"These will be discarded from training due to "
f"`exclude_invalid_x=True`. Please be aware that this gives "
Expand Down
8 changes: 8 additions & 0 deletions tests/inference_with_NaN_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch.distributions import MultivariateNormal

from sbi import utils as utils
from sbi.diagnostics import run_sbc
from sbi.inference import (
NPE_A,
NPE_C,
Expand Down Expand Up @@ -113,6 +114,13 @@ def linear_gaussian_nan(
# Compute the c2st and assert it is near chance level of 0.5.
check_c2st(samples, target_samples, alg=f"{method}")

# run sbc
num_sbc_samples = 100
thetas = prior.sample((num_sbc_samples,))
xs = simulator(thetas)
ranks, daps = run_sbc(thetas, xs, posterior, num_posterior_samples=1000)
assert torch.isfinite(ranks).all()


@pytest.mark.slow
def test_inference_with_restriction_estimator():
Expand Down

0 comments on commit 76c1e1b

Please sign in to comment.