API名称 | paddle.multigammaln |
---|---|
提交作者 | 汪昕(GreatV) |
提交时间 | 2023-09-13 |
版本号 | V1.0 |
依赖飞桨版本 | develop |
文件名 | 20230913_api_design_for_multigammaln.md |
multigammaln
函数返回多元 gamma 函数的对数,有时也称为广义 gamma 函数。对于
其中
对上式取对数:
为 Paddle 新增 paddle.multigammaln
API,提供多元 gamma 函数的对数计算功能。所有元素必须大于 (d - 1) / 2,否则将会产生未定义行为。
为 Paddle 新增 paddle.multigammaln
API,提供多元 gamma 函数的对数计算功能。
对飞桨框架目前不支持此功能,可用其他API组合实现的此功能,代码如下;
import paddle
import numpy as np
a = paddle.to_tensor(23.5)
d = paddle.to_tensor(10)
pi = paddle.to_tensor(np.pi, dtype="float32")
out = (
d * (d - 1) / 4 * paddle.log(pi)
+ paddle.lgamma(a - 0.5 * paddle.arange(0, d, dtype="float32")).sum()
)
print(out)
在 Scipy 中使用的 API 格式如下:
scipy.special.multigammaln(a, d)
其中,a
为 ndarray
类型,是多元 gamma 函数的变量,d
为 int
类型,是多元 gamma 函数积分空间的维度。
实现的伪代码如下:
import numpy as np
from scipy.special import gammaln as loggam
def multigammaln(a, d):
res = (d * (d - 1) * 0.25) * np.log(np.pi)
res += np.sum(loggam([(a - (j - 1.0) / 2) for j in range(1, d + 1)]), axis=0)
return res
在 jax 中使用的 API 格式如下:
jax.scipy.special.multigammaln(a, d)
其中,a
为 ndarray
类型,是多元 gamma 函数的变量,d
为 int
类型,是多元 gamma 函数积分空间的维度。
实现代码如下:
def multigammaln(a: ArrayLike, d: ArrayLike) -> Array:
d = core.concrete_or_error(int, d, "d argument of multigammaln")
a, d_ = promote_args_inexact("multigammaln", a, d)
constant = lax.mul(lax.mul(lax.mul(_lax_const(a, 0.25), d_),
lax.sub(d_, _lax_const(a, 1))),
lax.log(_lax_const(a, np.pi)))
b = lax.div(jnp.arange(d, dtype=d_.dtype), _lax_const(a, 2))
res = jnp.sum(gammaln(jnp.expand_dims(a, axis=-1) -
jnp.expand_dims(b, axis=tuple(range(a.ndim)))),
axis=-1)
return res + constant
在 Pytorch 中使用的 API 格式如下:
torch.special.multigammaln(input, p, *, out=None)
其中,input
为 Tensor
类型,是多元 gamma 函数的变量,p
为 int
类型,是多元 gamma 函数的积分空间的维度。
实现代码如下:
def multigammaln(a: TensorLikeType, p: int) -> TensorLikeType:
c = 0.25 * p * (p - 1) * math.log(math.pi)
b = 0.5 * torch.arange(start=(1 - p), end=1, step=1, dtype=a.dtype, device=a.device)
return torch.sum(torch.lgamma(a.unsqueeze(-1) + b), dim=-1) + c
from scipy.special import multigammaln
a = 23.5
d = 10
out = multigammaln(a, d)
import torch
a = torch.empty(2, 3).uniform_(1, 2)
torch.special.multigammaln(a, 2)
上述框架从使用体验来说,差异不大,都是直接调用 API 即可。内部实现上也是大同小异。因此,可参考 PyTorch 的实现,为 Paddle 新增 paddle.multigammaln
API。
API设计为 paddle.multigammaln(x, p, name=None)
。其中,x
为 Tensor
类型,是多元 gamma 函数的变量,p
为 int
类型,是多元 gamma 函数的积分空间的维度, name
默认值为 None
,具体用法请参见 Name,一般无需设置。paddle.multigammaln_(x, p, name=None)
为 inplace 版本。Tensor.multigammaln(p)
为 Tensor 的方法版本。Tensor.multigammaln_(p)
为 Tensor 的 方法 inplace 版本。
参考 PyTorch 采用现有 PYTHON API 组合实现,实现位置为 Paddle repo python/paddle/tensor/math.py
目录。并在 python/paddle/tensor/init.py 中,添加 multigammaln
& multigammaln_
API,以支持 paddle.Tensor.multigammaln
& paddle.Tensor.multigammaln_
的调用方式。
可考虑一下场景:
- 当
x
为空张量,输出为空张量,且输出张量形状正确; - 结果一致性,和 SciPy 以及 PyTorch 结果的数值的一致性,
paddle.multigammaln(x, p)
,scipy.special.multigammaln(a, d)
和torch.special.multigammaln(input, p, *, out=None)
结果是否一致; - 异常测试,对于
x < (p - 1) / 2
,应该有友好的报错信息及异常反馈,需要有相关测试Case验证。
本 API 主要参考 PyTorch 实现,难度适中,工期上能满足要求。
为独立新增API,对其他模块没有影响。