Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.9】add multigammaln api -part #57599

Merged
merged 17 commits into from
Nov 17, 2023
4 changes: 4 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,8 @@
from .tensor.math import square_ # noqa: F401
from .tensor.math import stanh # noqa: F401
from .tensor.math import sum # noqa: F401
from .tensor.math import multigammaln # noqa: F401
from .tensor.math import multigammaln_ # noqa: F401
from .tensor.math import nan_to_num # noqa: F401
from .tensor.math import nan_to_num_ # noqa: F401
from .tensor.math import nansum # noqa: F401
Expand Down Expand Up @@ -812,6 +814,8 @@
'renorm_',
'take_along_axis',
'put_along_axis',
'multigammaln',
'multigammaln_',
'nan_to_num',
'nan_to_num_',
'heaviside',
Expand Down
4 changes: 4 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@
from .math import square # noqa: F401
from .math import stanh # noqa: F401
from .math import sum # noqa: F401
from .math import multigammaln # noqa: F401
from .math import multigammaln_ # noqa: F401
from .math import nan_to_num # noqa: F401
from .math import nan_to_num_ # noqa: F401
from .math import nansum # noqa: F401
Expand Down Expand Up @@ -448,6 +450,8 @@
'square',
'stanh',
'sum',
'multigammaln',
'multigammaln_',
'nan_to_num',
'nan_to_num_',
'nansum',
Expand Down
40 changes: 40 additions & 0 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -5052,6 +5052,46 @@ def lgamma_(x, name=None):
return _C_ops.lgamma_(x)


def multigammaln(x, p, name=None):
    """
    This function computes the log of multivariate gamma, also sometimes called the generalized gamma.

    Args:
        x (Tensor): Input Tensor. Must be one of the following types: float16, float32, float64, uint16.
            All elements must be greater than ``(p - 1) / 2``, otherwise the result is undefined.
        p (int): The dimension of the space of integration. Must be greater than or equal to 1.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        out (Tensor): The values of the log multivariate gamma at the given tensor x.

    Raises:
        ValueError: If ``p`` is less than 1.

    Examples:
        .. code-block:: python

            >>> import paddle

            >>> x = paddle.to_tensor([2.5, 3.5, 4, 6.5, 7.8, 10.23, 34.25])
            >>> p = 2
            >>> out = paddle.multigammaln(x, p)
            >>> out
            Tensor(shape=[7], dtype=float64, place=Place(cpu), stop_gradient=True,
            [0.85704780  , 2.46648574  , 3.56509781  , 11.02241898 , 15.84497738 ,
             26.09257938 , 170.68316451])
    """
    if p < 1:
        raise ValueError(
            "The dimension of the space of integration p must be greater "
            f"than or equal to 1, but received p = {p}."
        )
    # Closed form:  log Gamma_p(x) = (p*(p-1)/4) * log(pi)
    #                              + sum_{j=0}^{p-1} lgamma(x - j/2)
    # c is the scalar additive constant (p*(p-1)/4) * log(pi).
    c = 0.25 * p * (p - 1) * np.log(np.pi)
    # b holds the offsets -(p-1)/2, ..., -1/2, 0 which are broadcast
    # against x along a new trailing axis.
    b = 0.5 * paddle.arange(start=(1 - p), end=1, step=1, dtype=x.dtype)
    return paddle.sum(paddle.lgamma(x.unsqueeze(-1) + b), axis=-1) + c


@inplace_apis_in_dygraph_only
def multigammaln_(x, p, name=None):
    r"""
    Inplace version of ``multigammaln`` API, the output Tensor will be inplaced with input ``x``.
    Please refer to :ref:`api_paddle_multigammaln`.
    """
    # Delegate to the Tensor's in-place method; the decorator restricts
    # this API to dygraph mode.
    result = x.multigammaln_(p, name=name)
    return result


def neg(x, name=None):
"""
This function computes the negative of the Tensor elementwisely.
Expand Down
68 changes: 68 additions & 0 deletions test/legacy_test/test_multigammaln.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from scipy import special

import paddle


def ref_multigammaln(x, p):
    """Ground-truth values from SciPy's multivariate log-gamma function."""
    expected = special.multigammaln(x, p)
    return expected


class TestMultigammalnAPI(unittest.TestCase):
    """Checks paddle.multigammaln against SciPy in static and dygraph modes."""

    def setUp(self):
        np.random.seed(1024)
        # Samples are shifted into (1, 2) so every element stays inside
        # the function's valid domain for p = 2.
        self.x = np.random.rand(10, 20).astype('float32') + 1.0
        self.p = 2
        self.init_input()
        if paddle.is_compiled_with_cuda():
            self.place = paddle.CUDAPlace(0)
        else:
            self.place = paddle.CPUPlace()

    def init_input(self):
        # Hook for subclasses to override the generated inputs.
        pass

    def test_static_api(self):
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.static.data('x', self.x.shape, dtype=self.x.dtype)
            out = paddle.multigammaln(x, self.p)
            exe = paddle.static.Executor(self.place)
            (res,) = exe.run(feed={'x': self.x}, fetch_list=[out])
            expected = ref_multigammaln(self.x, self.p)
            np.testing.assert_allclose(expected, res, rtol=1e-6, atol=1e-6)

    def test_dygraph_api(self):
        paddle.disable_static(self.place)
        tensor_x = paddle.to_tensor(self.x)
        got = paddle.multigammaln(tensor_x, self.p)
        expected = ref_multigammaln(self.x, self.p)
        np.testing.assert_allclose(expected, got.numpy(), rtol=1e-6, atol=1e-6)
        paddle.enable_static()


if __name__ == '__main__':
    # Run under static graph mode by default; test_dygraph_api switches
    # modes explicitly where needed.
    paddle.enable_static()
    unittest.main()
GreatV marked this conversation as resolved.
Show resolved Hide resolved