Skip to content

Commit

Permalink
Raise when number of dims does not match var.ndim
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jun 25, 2024
1 parent be770a6 commit fb11f01
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 15 deletions.
11 changes: 2 additions & 9 deletions pymc/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,13 +1185,6 @@ def logp(value, p):
)


class _OrderedLogistic(Categorical):
r"""
Underlying class for ordered logistic distributions.
See docs for the OrderedLogistic wrapper class for more details on how to use it in models.
"""


class OrderedLogistic:
R"""Ordered Logistic distribution.
Expand Down Expand Up @@ -1263,7 +1256,7 @@ class OrderedLogistic:
def __new__(cls, name, eta, cutpoints, compute_p=True, **kwargs):
p = cls.compute_p(eta, cutpoints)
if compute_p:
p = pm.Deterministic(f"{name}_probs", p, dims=kwargs.get("dims"))
p = pm.Deterministic(f"{name}_probs", p)
out_rv = Categorical(name, p=p, **kwargs)
return out_rv

Expand Down Expand Up @@ -1367,7 +1360,7 @@ class OrderedProbit:
def __new__(cls, name, eta, cutpoints, sigma=1, compute_p=True, **kwargs):
p = cls.compute_p(eta, cutpoints, sigma)
if compute_p:
p = pm.Deterministic(f"{name}_probs", p, dims=kwargs.get("dims"))
p = pm.Deterministic(f"{name}_probs", p)
out_rv = Categorical(name, p=p, **kwargs)
return out_rv

Expand Down
5 changes: 5 additions & 0 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1532,6 +1532,11 @@ def add_named_variable(self, var, dims: tuple[str | None, ...] | None = None):
raise ValueError(f"Dimension {dim} is not specified in `coords`.")
if any(var.name == dim for dim in dims if dim is not None):
raise ValueError(f"Variable `{var.name}` has the same name as its dimension label.")
# This check implicitly states that only vars with this attribute can have dims
if var.ndim != len(dims):
raise ValueError(
f"{var} has {var.ndim} dims but {len(dims)} dim labels were provided."
)
self.named_vars_to_dims[var.name] = dims

self.named_vars[var.name] = var
Expand Down
16 changes: 11 additions & 5 deletions tests/distributions/test_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,19 +897,25 @@ def test_shape_inputs(self, eta, cutpoints, expected):
assert p_shape == expected

def test_compute_p(self):
with pm.Model() as m:
pm.OrderedLogistic("ol_p", cutpoints=np.array([-2, 0, 2]), eta=0)
pm.OrderedLogistic("ol_no_p", cutpoints=np.array([-2, 0, 2]), eta=0, compute_p=False)
with pm.Model(coords={"test_dim": [0]}) as m:
pm.OrderedLogistic("ol_p", cutpoints=np.array([-2, 0, 2]), eta=0, dims="test_dim")
pm.OrderedLogistic(
"ol_no_p", cutpoints=np.array([-2, 0, 2]), eta=0, compute_p=False, dims="test_dim"
)
assert len(m.deterministics) == 1

x = pm.OrderedLogistic.dist(cutpoints=np.array([-2, 0, 2]), eta=0)
assert isinstance(x, TensorVariable)

# Test it works with auto-imputation
with pm.Model() as m:
with pm.Model(coords={"test_dim": [0, 1, 2]}) as m:
with pytest.warns(ImputationWarning):
pm.OrderedLogistic(
"ol", cutpoints=np.array([-2, 0, 2]), eta=0, observed=[0, np.nan, 1]
"ol",
cutpoints=np.array([[-2, 0, 2]]),
eta=0,
observed=[0, np.nan, 1],
dims=["test_dim"],
)
assert len(m.deterministics) == 2 # One from the auto-imputation, the other from compute_p

Expand Down
15 changes: 14 additions & 1 deletion tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,7 +890,20 @@ def test_add_named_variable_checks_dim_name(self):
rv2.name = "yumyum"
pmodel.add_named_variable(rv2, dims=("nomnom", None))

def test_dims_type_check(self):
def test_add_named_variable_checks_number_of_dims(self):
match = "dim labels were provided"
with pm.Model(coords={"bad": range(6)}) as m:
with pytest.raises(ValueError, match=match):
m.add_named_variable(pt.random.normal(size=(6, 6, 6), name="a"), dims=("bad",))

with pytest.raises(ValueError, match=match):
m.add_named_variable(pt.random.normal(size=(6, 6, 6), name="b"), dims="bad")

# For variables without ndim we can't check
m.add_named_variable(pytensor.as_symbolic(None, name="c"), dims=("bad",))
assert m.named_vars_to_dims == {"c": ("bad",)}

def test_rv_dims_type_check(self):
with pm.Model(coords={"a": range(5)}) as m:
with pytest.raises(TypeError, match="Dims must be string"):
x = pm.Normal("x", shape=(10, 5), dims=(None, "a"))
Expand Down

0 comments on commit fb11f01

Please sign in to comment.