Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SymbolicDistribution and Censored distributions #5169

Merged
merged 4 commits into from
Jan 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
- With `pm.Data(..., mutable=True/False)`, or by using `pm.MutableData` vs. `pm.ConstantData` one can now create `TensorConstant` data variables. They can be more performant and compatible in situations where a variable doesn't need to be changed via `pm.set_data()`. See [#5295](https://github.com/pymc-devs/pymc/pull/5295).
- New named dimensions can be introduced to the model via `pm.Data(..., dims=...)`. For mutable data variables (see above) the lengths of these dimensions are symbolic, so they can be re-sized via `pm.set_data()`.
- `pm.Data` now passes additional kwargs to `aesara.shared`/`at.as_tensor`. [#5098](https://github.com/pymc-devs/pymc/pull/5098).
- Univariate censored distributions are now available via `pm.Censored`. [#5169](https://github.com/pymc-devs/pymc/pull/5169)
- ...


Expand Down
5 changes: 3 additions & 2 deletions docs/source/api/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ Distributions

distributions/continuous
distributions/discrete
distributions/logprob
distributions/multivariate
distributions/mixture
distributions/simulator
distributions/timeseries
distributions/censored
distributions/simulator
distributions/transforms
distributions/logprob
distributions/utilities
9 changes: 9 additions & 0 deletions docs/source/api/distributions/censored.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
********
Censored
********

.. currentmodule:: pymc
.. autosummary::
:toctree: generated

Censored
1 change: 1 addition & 0 deletions docs/source/api/distributions/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Distribution utilities
:toctree: generated/

Distribution
SymbolicDistribution
Discrete
Continuous
NoDistribution
Expand Down
2 changes: 2 additions & 0 deletions pymc/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)

from pymc.distributions.bound import Bound
from pymc.distributions.censored import Censored
from pymc.distributions.continuous import (
AsymmetricLaplace,
Beta,
Expand Down Expand Up @@ -187,6 +188,7 @@
"Rice",
"Moyal",
"Simulator",
"Censored",
"CAR",
"PolyaGamma",
"logpt",
Expand Down
60 changes: 25 additions & 35 deletions pymc/distributions/bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pymc.distributions.logprob import logp
from pymc.distributions.shape_utils import to_tuple
from pymc.model import modelcontext
from pymc.util import check_dist_not_registered

__all__ = ["Bound"]

Expand Down Expand Up @@ -144,8 +145,9 @@ class Bound:

Parameters
----------
distribution: pymc distribution
Distribution to be transformed into a bounded distribution.
dist: PyMC unnamed distribution
Distribution to be transformed into a bounded distribution created via the
`.dist()` API.
lower: float or array like, optional
Lower bound of the distribution.
upper: float or array like, optional
Expand All @@ -156,15 +158,15 @@ class Bound:
.. code-block:: python

with pm.Model():
normal_dist = Normal.dist(mu=0.0, sigma=1.0, initval=-0.5)
negative_normal = pm.Bound(normal_dist, upper=0.0)
normal_dist = pm.Normal.dist(mu=0.0, sigma=1.0)
negative_normal = pm.Bound("negative_normal", normal_dist, upper=0.0)

"""

def __new__(
cls,
name,
distribution,
dist,
lower=None,
upper=None,
size=None,
Expand All @@ -174,7 +176,7 @@ def __new__(
**kwargs,
):

cls._argument_checks(distribution, **kwargs)
cls._argument_checks(dist, **kwargs)

if dims is not None:
model = modelcontext(None)
Expand All @@ -185,12 +187,12 @@ def __new__(
raise ValueError("Given dims do not exist in model coordinates.")

lower, upper, initval = cls._set_values(lower, upper, size, shape, initval)
distribution.tag.ignore_logprob = True
dist.tag.ignore_logprob = True

if isinstance(distribution.owner.op, Continuous):
if isinstance(dist.owner.op, Continuous):
res = _ContinuousBounded(
name,
[distribution, lower, upper],
[dist, lower, upper],
initval=floatX(initval),
size=size,
shape=shape,
Expand All @@ -199,7 +201,7 @@ def __new__(
else:
res = _DiscreteBounded(
name,
[distribution, lower, upper],
[dist, lower, upper],
initval=intX(initval),
size=size,
shape=shape,
Expand All @@ -210,28 +212,28 @@ def __new__(
@classmethod
def dist(
cls,
distribution,
dist,
lower=None,
upper=None,
size=None,
shape=None,
**kwargs,
):

cls._argument_checks(distribution, **kwargs)
cls._argument_checks(dist, **kwargs)
lower, upper, initval = cls._set_values(lower, upper, size, shape, initval=None)
distribution.tag.ignore_logprob = True
if isinstance(distribution.owner.op, Continuous):
dist.tag.ignore_logprob = True
if isinstance(dist.owner.op, Continuous):
res = _ContinuousBounded.dist(
[distribution, lower, upper],
[dist, lower, upper],
size=size,
shape=shape,
**kwargs,
)
res.tag.test_value = floatX(initval)
else:
res = _DiscreteBounded.dist(
[distribution, lower, upper],
[dist, lower, upper],
size=size,
shape=shape,
**kwargs,
Expand All @@ -240,7 +242,7 @@ def dist(
return res

@classmethod
def _argument_checks(cls, distribution, **kwargs):
def _argument_checks(cls, dist, **kwargs):
if "observed" in kwargs:
raise ValueError(
"Observed Bound distributions are not supported. "
Expand All @@ -249,34 +251,22 @@ def _argument_checks(cls, distribution, **kwargs):
"with the cumulative probability function."
)

if not isinstance(distribution, TensorVariable):
if not isinstance(dist, TensorVariable):
raise ValueError(
"Passing a distribution class to `Bound` is no longer supported.\n"
"Please pass the output of a distribution instantiated via the "
"`.dist()` API such as:\n"
'`pm.Bound("bound", pm.Normal.dist(0, 1), lower=0)`'
)

try:
model = modelcontext(None)
except TypeError:
pass
else:
if distribution in model.basic_RVs:
raise ValueError(
f"The distribution passed into `Bound` was already registered "
f"in the current model.\nYou should pass an unregistered "
f"(unnamed) distribution created via the `.dist()` API, such as:\n"
f'`pm.Bound("bound", pm.Normal.dist(0, 1), lower=0)`'
)

if distribution.owner.op.ndim_supp != 0:
check_dist_not_registered(dist)

if dist.owner.op.ndim_supp != 0:
raise NotImplementedError("Bounding of MultiVariate RVs is not yet supported.")

if not isinstance(distribution.owner.op, (Discrete, Continuous)):
if not isinstance(dist.owner.op, (Discrete, Continuous)):
raise ValueError(
f"`distribution` {distribution} must be a Discrete or Continuous"
" distribution subclass"
f"`distribution` {dist} must be a Discrete or Continuous" " distribution subclass"
)

@classmethod
Expand Down
146 changes: 146 additions & 0 deletions pymc/distributions/censored.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright 2020 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import aesara.tensor as at
import numpy as np

from aesara.scalar import Clip
from aesara.tensor import TensorVariable
from aesara.tensor.random.op import RandomVariable

from pymc.distributions.distribution import SymbolicDistribution, _get_moment
from pymc.util import check_dist_not_registered


class Censored(SymbolicDistribution):
    r"""
    Censored distribution

    The pdf of a censored distribution is

    .. math::

        \begin{cases}
            0 & \text{for } x < lower, \\
            \text{CDF}(lower, dist) & \text{for } x = lower, \\
            \text{PDF}(x, dist) & \text{for } lower < x < upper, \\
            1-\text{CDF}(upper, dist) & \text{for } x = upper, \\
            0 & \text{for } x > upper,
        \end{cases}


    Parameters
    ----------
    dist: PyMC unnamed distribution
        PyMC distribution created via the `.dist()` API, which will be censored. This
        distribution must be univariate and have a logcdf method implemented.
    lower: float or None
        Lower (left) censoring point. If `None`, the distribution will not be left censored.
    upper: float or None
        Upper (right) censoring point. If `None`, the distribution will not be right censored.


    Examples
    --------
    .. code-block:: python

        with pm.Model():
            normal_dist = pm.Normal.dist(mu=0.0, sigma=1.0)
            censored_normal = pm.Censored("censored_normal", normal_dist, lower=-1, upper=1)
    """

    @classmethod
    def dist(cls, dist, lower, upper, **kwargs):
        """Validate the base distribution and build the unnamed censored variable.

        Raises
        ------
        ValueError
            If `dist` is not a `TensorVariable` produced by a `RandomVariable` op.
        NotImplementedError
            If the base distribution has `ndim_supp > 0`.
        """
        # A TensorVariable without an owner (e.g. a constant) has no `.op` to inspect;
        # reject it with the same ValueError instead of failing with an AttributeError.
        is_random_variable = (
            isinstance(dist, TensorVariable)
            and dist.owner is not None
            and isinstance(dist.owner.op, RandomVariable)
        )
        if not is_random_variable:
            raise ValueError(
                f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}"
            )
        if dist.owner.op.ndim_supp > 0:
            raise NotImplementedError(
                "Censoring of multivariate distributions has not been implemented yet"
            )
        check_dist_not_registered(dist)
        return super().dist([dist, lower, upper], **kwargs)

    @classmethod
    def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None):
        """Return `dist` clipped to `[lower, upper]`, optionally resized or reseeded.

        A bound given as `None` is replaced by an infinite constant of the matching sign.
        """
        if lower is None:
            lower = at.constant(-np.inf)
        if upper is None:
            upper = at.constant(np.inf)

        # Censoring is achieved by clipping the base distribution between lower and upper
        rv_out = at.clip(dist, lower, upper)

        # Reference nodes to facilitate identification in other classmethods, without
        # worrying about possible dimshuffles
        rv_out.tag.dist = dist
        rv_out.tag.lower = lower
        rv_out.tag.upper = upper

        if size is not None:
            rv_out = cls.change_size(rv_out, size)
        if rngs is not None:
            rv_out = cls.change_rngs(rv_out, rngs)

        return rv_out

    @classmethod
    def ndim_supp(cls, *dist_params):
        """Return the support dimensionality, which is always 0 for this class."""
        return 0

    @classmethod
    def change_size(cls, rv, new_size):
        """Rebuild the censored variable around a copy of the base RV with `new_size`."""
        dist_node = rv.tag.dist.owner
        lower = rv.tag.lower
        upper = rv.tag.upper
        # The previous size input is discarded; everything else is reused as is
        rng, _old_size, dtype, *dist_params = dist_node.inputs
        new_dist = dist_node.op.make_node(rng, new_size, dtype, *dist_params).default_output()
        return cls.rv_op(new_dist, lower, upper)

    @classmethod
    def change_rngs(cls, rv, new_rngs):
        """Rebuild the censored variable around a copy of the base RV using the new rng.

        `new_rngs` must contain exactly one element.
        """
        (new_rng,) = new_rngs
        dist_node = rv.tag.dist.owner
        lower = rv.tag.lower
        upper = rv.tag.upper
        # The previous rng input is discarded; everything else is reused as is
        _old_rng, size, dtype, *dist_params = dist_node.inputs
        new_dist = dist_node.op.make_node(new_rng, size, dtype, *dist_params).default_output()
        return cls.rv_op(new_dist, lower, upper)

    @classmethod
    def graph_rvs(cls, rv):
        """Return a one-element tuple with the base distribution stored on `rv.tag`."""
        return (rv.tag.dist,)


@_get_moment.register(Clip)
def get_moment_censored(op, rv, dist, lower, upper):
    """Build a symbolic moment for a censored variable from its censoring bounds.

    The value is chosen from `lower` and `upper` only and then broadcast to the
    shape of `dist` with `at.full_like`; `op` and `rv` are not used.
    """
    # Branch taken when lower == -inf: 0 if `at.isinf(upper)` holds, else one below upper
    when_lower_is_neg_inf = at.switch(at.isinf(upper), 0, upper - 1)

    # Branch taken otherwise: one above lower if upper == +inf, else the midpoint of the bounds
    when_lower_is_set = at.switch(
        at.eq(upper, np.inf),
        lower + 1,
        (lower + upper) / 2,
    )

    point = at.switch(
        at.eq(lower, -np.inf),
        when_lower_is_neg_inf,
        when_lower_is_set,
    )
    return at.full_like(dist, point)
Loading