Skip to content

Commit

Permalink
[Hackathon No.18] 为 Paddle 新增 frexp API (#46401)
Browse files Browse the repository at this point in the history
* 之前的pr合并了大量错误代码,重新提交一份

* 之前的pr合并了大量错误代码,重新提交一份

* 修正格式问题

* 改回原来的格式

* 按照要求修改

* 按照要求修改格式

* 修复注释的问题

* 更新格式

* 测试自动格式化

* 修正英文注释

* fix docs build error

* pre-commit

* for docs build

* for docs build

* 修复mantissa计算错误的bug

* 修复误判exponent可能存在负数,导致计算量增加的情况

Co-authored-by: Ligoml <[email protected]>
  • Loading branch information
Zheng-Bicheng and Ligoml authored Sep 29, 2022
1 parent 9a1855f commit 1e2af54
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 1 deletion.
3 changes: 2 additions & 1 deletion python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@
from .tensor.math import frac # noqa: F401
from .tensor.math import sgn # noqa: F401
from .tensor.math import take # noqa: F401
from .tensor.math import frexp # noqa: F401

from .tensor.random import bernoulli # noqa: F401
from .tensor.random import poisson # noqa: F401
Expand Down Expand Up @@ -386,7 +387,6 @@
os.environ.setdefault('runtime_include_dir', runtime_include_dir)

disable_static()

__all__ = [ # noqa
'iinfo',
'dtype',
Expand Down Expand Up @@ -667,4 +667,5 @@
'sgn',
'triu_indices',
'take',
'frexp',
]
94 changes: 94 additions & 0 deletions python/paddle/fluid/tests/unittests/test_frexp_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.fluid


class TestFrexpAPI(unittest.TestCase):

def setUp(self):
np.random.seed(1024)
self.rtol = 1e-5
self.atol = 1e-8
self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
self.set_input()

def set_input(self):
self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float32')

# 静态图单测
def test_static_api(self):
# 开启静态图模式
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
input_data = paddle.fluid.data('X', self.x_np.shape,
self.x_np.dtype)
out = paddle.frexp(input_data)
# 计算静态图结果
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out])

out_ref = np.frexp(self.x_np)
# 对比静态图与 numpy 实现函数计算结果是否相同
for n, p in zip(out_ref, res):
np.testing.assert_allclose(n, p, rtol=self.rtol, atol=self.atol)

# 动态图单测
def test_dygraph_api(self):
# 关闭静态图模式
paddle.disable_static(self.place)
input_num = paddle.to_tensor(self.x_np)
# 测试动态图 tensor.frexp 和 paddle.tensor.math.frexp 计算结果
out1 = np.frexp(self.x_np)
out2 = paddle.frexp(input_num)
np.testing.assert_allclose(out1, out2, rtol=1e-05)

out1 = np.frexp(self.x_np)
out2 = input_num.frexp()
np.testing.assert_allclose(out1, out2, rtol=1e-05)
paddle.enable_static()


class TestSplitsFloat32Case1(TestFrexpAPI):
"""
Test num_or_sections which is an integer and data type is float32.
"""

def set_input(self):
self.x_np = np.random.uniform(-1, 1, [4, 5, 2]).astype('float32')


class TestSplitsFloat64Case1(TestFrexpAPI):
"""
Test num_or_sections which is an integer and data type is float64.
"""

def set_input(self):
self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float64')


class TestSplitsFloat64Case2(TestFrexpAPI):
"""
Test num_or_sections which is an integer and data type is float64.
"""

def set_input(self):
self.x_np = np.random.uniform(-1, 1, [4, 5, 2]).astype('float64')


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@
from .math import frac # noqa: F401
from .math import sgn # noqa: F401
from .math import take # noqa: F401
from .math import frexp # noqa: F401

from .random import multinomial # noqa: F401
from .random import standard_normal # noqa: F401
Expand Down Expand Up @@ -517,6 +518,7 @@
'take',
'bucketize',
'sgn',
'frexp',
]

# this list used in math_op_patch.py for magic_method bind
Expand Down
49 changes: 49 additions & 0 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -5108,3 +5108,52 @@ def take(x, index, mode='raise', name=None):
out = input_1d.index_select(index_1d).reshape(index.shape)

return out


def frexp(x, name=None):
"""
The function used to decompose a floating point number into mantissa and exponent.
Args:
x (Tensor): The input tensor, it's data type should be float32, float64.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Returns:
- mantissa (Tensor), A mantissa Tensor. The shape and data type of mantissa tensor and exponential tensor are
the same as those of input.
- exponent (Tensor), A exponent Tensor. The shape and data type of mantissa tensor and exponential tensor are
the same as those of input.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([[1, 2, 3, 4]], dtype="float32")
print(paddle.tensor.math.frexp(x))
# (Tensor(shape=[1, 4], dtype=float32, place=Place(cpu), stop_gradient=True,[[0.50000000, 0.50000000, 0.75000000, 0.50000000]]),
# Tensor(shape=[1, 4], dtype=float32, place=Place(cpu), stop_gradient=True,[[1., 2., 2., 3.]]))
"""
if x.dtype not in [paddle.float32, paddle.float64]:
raise TypeError(
"The data type of input must be one of ['float32', 'float64'], but got {}"
.format(x.dtype))
input_x = paddle.abs(x)
exponent = paddle.floor(paddle.log2(input_x))
exponent = paddle.where(paddle.isinf(exponent),
paddle.full_like(exponent, 0), exponent)

# 0填充
mantissa = paddle.divide(input_x, 2**exponent)
# 计算exponent
exponent = paddle.where((mantissa >= 1),
paddle.add(exponent, paddle.ones_like(exponent)),
exponent)
mantissa = paddle.where((mantissa >= 1),
paddle.divide(mantissa,
2**paddle.ones_like(exponent)),
mantissa)

mantissa = paddle.where((x < 0), mantissa * -1, mantissa)
return mantissa, exponent

0 comments on commit 1e2af54

Please sign in to comment.