From 55da81d5e1302e5f01c71d7902ee68d01005c7cd Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Mon, 9 Oct 2017 23:26:03 +0200 Subject: [PATCH] Add LogExpM1 transformation (#2601) * Add softplus transformation Softplus transformation (from non-negative to reals) might be more numerically stable (see Fig. 9 in Kucukelbir et al. 2017). * add test * Change name and implement a more numerically stable logexpm1 See https://github.com/tensorflow/tensorflow/blob/0b0d3c12ace80381f4a44365d30275a9a262609b/tensorflow/python/ops/distributions/util.py#L1009 for the derivation * change default transformation for PositiveContinuous * Revert "change default transformation for PositiveContinuous" This reverts commit 8bc036ce4ab87d9cd55d1cd00a01646da241e1a0. * name change --- pymc3/distributions/transforms.py | 27 +++++++++++++++++++++++---- pymc3/tests/test_transforms.py | 10 ++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/pymc3/distributions/transforms.py b/pymc3/distributions/transforms.py index 19254ac2b17..0a70dfea1ee 100644 --- a/pymc3/distributions/transforms.py +++ b/pymc3/distributions/transforms.py @@ -8,7 +8,7 @@ from .distribution import draw_values import numpy as np -__all__ = ['transform', 'stick_breaking', 'logodds', 'interval', +__all__ = ['transform', 'stick_breaking', 'logodds', 'interval', 'log_exp_m1', 'lowerbound', 'upperbound', 'log', 'sum_to_1', 't_stick_breaking'] @@ -105,12 +105,31 @@ def jacobian_det(self, x): log = Log() +class LogExpM1(ElemwiseTransform): + name = "log_exp_m1" + + def backward(self, x): + return tt.nnet.softplus(x) + + def forward(self, x): + """Inverse operation of softplus + y = Log(Exp(x) - 1) + = Log(1 - Exp(-x)) + x + """ + return tt.log(1.-tt.exp(-x)) + x + + def forward_val(self, x, point=None): + return self.forward(x) + + def jacobian_det(self, x): + return -tt.nnet.softplus(-x) + +log_exp_m1 = LogExpM1() + + class LogOdds(ElemwiseTransform): name = "logodds" - def __init__(self): - pass - def backward(self, x): return invlogit(x, 0.0) diff --git a/pymc3/tests/test_transforms.py b/pymc3/tests/test_transforms.py index fc19325c0f8..9c7b05d542e 100644 --- a/pymc3/tests/test_transforms.py +++ b/pymc3/tests/test_transforms.py @@ -104,6 +104,16 @@ def test_log(): close_to_logical(vals > 0, True, tol) +def test_log_exp_m1(): + check_transform_identity(tr.log_exp_m1, Rplusbig) + check_jacobian_det(tr.log_exp_m1, Rplusbig, elemwise=True) + check_jacobian_det(tr.log_exp_m1, Vector(Rplusbig, 2), + tt.dvector, [0, 0], elemwise=True) + + vals = get_values(tr.log_exp_m1) + close_to_logical(vals > 0, True, tol) + + def test_logodds(): check_transform_identity(tr.logodds, Unit) check_jacobian_det(tr.logodds, Unit, elemwise=True)