Skip to content

Commit

Permalink
[ Hackathon 3rd No.2 ] add paddle.iinfo (#45321)
Browse files Browse the repository at this point in the history
  • Loading branch information
OccupyMars2025 authored Sep 8, 2022
1 parent a642365 commit 40a0a46
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 2 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/eager/grad_node_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ class GradNodeBase {
// Gradient Hooks
// Customer may register a list of hooks which will be called in order during
// backward
// Each entry consists one pair of
// Each entry consists of one pair of
// <hook_id, <out_rank, std::shared_ptr<TensorHook>>>
std::map<int64_t,
std::tuple<
Expand Down
62 changes: 62 additions & 0 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License. */
#include <map>
#include <memory>
#include <mutex> // NOLINT // for call_once
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
Expand Down Expand Up @@ -346,6 +347,52 @@ bool IsCompiledWithDIST() {
#endif
}

struct iinfo {
int64_t min, max;
int bits;
std::string dtype;

explicit iinfo(const framework::proto::VarType::Type &type) {
switch (type) {
case framework::proto::VarType::INT16:
min = std::numeric_limits<int16_t>::min();
max = std::numeric_limits<int16_t>::max();
bits = 16;
dtype = "int16";
break;
case framework::proto::VarType::INT32:
min = std::numeric_limits<int32_t>::min();
max = std::numeric_limits<int32_t>::max();
bits = 32;
dtype = "int32";
break;
case framework::proto::VarType::INT64:
min = std::numeric_limits<int64_t>::min();
max = std::numeric_limits<int64_t>::max();
bits = 64;
dtype = "int64";
break;
case framework::proto::VarType::INT8:
min = std::numeric_limits<int8_t>::min();
max = std::numeric_limits<int8_t>::max();
bits = 8;
dtype = "int8";
break;
case framework::proto::VarType::UINT8:
min = std::numeric_limits<uint8_t>::min();
max = std::numeric_limits<uint8_t>::max();
bits = 8;
dtype = "uint8";
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"the argument of paddle.iinfo can only be paddle.int8, "
"paddle.int16, paddle.int32, paddle.int64, or paddle.uint8"));
break;
}
}
};

static PyObject *GetPythonAttribute(PyObject *obj, const char *attr_name) {
// NOTE(zjl): PyObject_GetAttrString would return nullptr when attr_name
// is not inside obj, but it would also set the error flag of Python.
Expand Down Expand Up @@ -555,6 +602,21 @@ PYBIND11_MODULE(core_noavx, m) {

BindException(&m);

py::class_<iinfo>(m, "iinfo")
.def(py::init<const framework::proto::VarType::Type &>())
.def_readonly("min", &iinfo::min)
.def_readonly("max", &iinfo::max)
.def_readonly("bits", &iinfo::bits)
.def_readonly("dtype", &iinfo::dtype)
.def("__repr__", [](const iinfo &a) {
std::ostringstream oss;
oss << "paddle.iinfo(min=" << a.min;
oss << ", max=" << a.max;
oss << ", bits=" << a.bits;
oss << ", dtype=" << a.dtype << ")";
return oss.str();
});

m.def("set_num_threads", &platform::SetNumThreads);

m.def("disable_signal_handler", &DisableSignalHandler);
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from .fluid.dataset import * # noqa: F401
from .fluid.lazy_init import LazyGuard # noqa: F401

from .framework.dtype import iinfo # noqa: F401
from .framework.dtype import dtype as dtype # noqa: F401
from .framework.dtype import uint8 # noqa: F401
from .framework.dtype import int8 # noqa: F401
Expand Down Expand Up @@ -386,6 +387,7 @@
disable_static()

__all__ = [ # noqa
'iinfo',
'dtype',
'uint8',
'int8',
Expand Down
45 changes: 45 additions & 0 deletions python/paddle/fluid/tests/unittests/test_iinfo_and_finfo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import unittest
import numpy as np


class TestIInfoAndFInfoAPI(unittest.TestCase):

def test_invalid_input(self):
for dtype in [
paddle.float16, paddle.float32, paddle.float64, paddle.bfloat16,
paddle.complex64, paddle.complex128, paddle.bool
]:
with self.assertRaises(ValueError):
_ = paddle.iinfo(dtype)

def test_iinfo(self):
for paddle_dtype, np_dtype in [(paddle.int64, np.int64),
(paddle.int32, np.int32),
(paddle.int16, np.int16),
(paddle.int8, np.int8),
(paddle.uint8, np.uint8)]:
xinfo = paddle.iinfo(paddle_dtype)
xninfo = np.iinfo(np_dtype)
self.assertEqual(xinfo.bits, xninfo.bits)
self.assertEqual(xinfo.max, xninfo.max)
self.assertEqual(xinfo.min, xninfo.min)
self.assertEqual(xinfo.dtype, xninfo.dtype)


if __name__ == '__main__':
unittest.main()
36 changes: 35 additions & 1 deletion python/paddle/framework/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from ..fluid.core import VarDesc
from ..fluid.core import iinfo as core_iinfo

dtype = VarDesc.VarType
dtype.__qualname__ = "dtype"
Expand All @@ -34,4 +35,37 @@

bool = VarDesc.VarType.BOOL

__all__ = []

def iinfo(dtype):
"""
paddle.iinfo is a function that returns an object that represents the numerical properties of
an integer paddle.dtype.
This is similar to `numpy.iinfo <https://numpy.org/doc/stable/reference/generated/numpy.iinfo.html#numpy-iinfo>`_.
Args:
dtype(paddle.dtype): One of paddle.uint8, paddle.int8, paddle.int16, paddle.int32, and paddle.int64.
Returns:
An iinfo object, which has the following 4 attributes:
- min: int, The smallest representable integer number.
- max: int, The largest representable integer number.
- bits: int, The number of bits occupied by the type.
- dtype: str, The string name of the argument dtype.
Examples:
.. code-block:: python
import paddle
iinfo_uint8 = paddle.iinfo(paddle.uint8)
print(iinfo_uint8)
# paddle.iinfo(min=0, max=255, bits=8, dtype=uint8)
print(iinfo_uint8.min) # 0
print(iinfo_uint8.max) # 255
print(iinfo_uint8.bits) # 8
print(iinfo_uint8.dtype) # uint8
"""
return core_iinfo(dtype)

0 comments on commit 40a0a46

Please sign in to comment.