Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PaddlePaddle Hackathon 3 No.16】为 Paddle 新增 API paddle.take #44741

Merged
merged 29 commits into from
Aug 30, 2022
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
982d01e
add paddle.take api
S-HuaBomb Jul 15, 2022
69b0a3e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Jul 15, 2022
b07c062
fix paddle.take
S-HuaBomb Jul 29, 2022
09d2836
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Jul 29, 2022
c8482f6
remove from pip import main
S-HuaBomb Jul 29, 2022
0665e50
test index out of range error
S-HuaBomb Aug 4, 2022
c5a9e16
test index out of range error and fix conflict
S-HuaBomb Aug 4, 2022
10b41c4
fix Examples
S-HuaBomb Aug 5, 2022
6852760
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Aug 5, 2022
9649b87
fix Examples
S-HuaBomb Aug 5, 2022
ec1cfd7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Aug 5, 2022
6806a8f
add param mode to take api
S-HuaBomb Aug 22, 2022
27b6943
fix conflict ad merge
S-HuaBomb Aug 22, 2022
5d32c52
add example code
S-HuaBomb Aug 22, 2022
b35d831
fix test using np.testing.assert_allclose
S-HuaBomb Aug 23, 2022
cc2f4f4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Aug 23, 2022
c4161f2
add annotation
S-HuaBomb Aug 23, 2022
aaee858
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Aug 23, 2022
ca2604f
fix typo
S-HuaBomb Aug 23, 2022
5979d5f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Aug 23, 2022
7b3fc1d
fix 嵌套列表
S-HuaBomb Aug 24, 2022
668964d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Aug 24, 2022
64b688a
fix Tensor,
S-HuaBomb Aug 24, 2022
cdd1080
fix docs warning
S-HuaBomb Aug 25, 2022
eca0483
fix conflict
S-HuaBomb Aug 25, 2022
4ca5c41
fix raise bug
S-HuaBomb Aug 27, 2022
9fb6896
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Aug 27, 2022
7fd6c85
add test case for negative index out of range error
S-HuaBomb Aug 29, 2022
046ff44
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Aug 29, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@
from .tensor.math import heaviside # noqa: F401
from .tensor.math import frac # noqa: F401
from .tensor.math import sgn # noqa: F401
from .tensor.math import take # noqa: F401

from .tensor.random import bernoulli # noqa: F401
from .tensor.random import poisson # noqa: F401
Expand Down Expand Up @@ -654,4 +655,5 @@
'heaviside',
'tril_indices',
'sgn',
'take',
]
220 changes: 220 additions & 0 deletions python/paddle/fluid/tests/unittests/test_take.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard


class TestTakeAPI(unittest.TestCase):

def set_mode(self):
self.mode = 'raise'

def set_dtype(self):
self.input_dtype = 'float64'
self.index_dtype = 'int64'

def set_input(self):
self.input_shape = [3, 4]
self.index_shape = [2, 3]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-4, 2).reshape(self.index_shape).astype(
self.index_dtype)

def setUp(self):
self.set_mode()
self.set_dtype()
self.set_input()
self.place = fluid.CUDAPlace(
0) if core.is_compiled_with_cuda() else fluid.CPUPlace()

def test_static_graph(self):
paddle.enable_static()
startup_program = Program()
train_program = Program()
with program_guard(startup_program, train_program):
x = fluid.data(name='input',
dtype=self.input_dtype,
shape=self.input_shape)
index = fluid.data(name='index',
dtype=self.index_dtype,
shape=self.index_shape)
out = paddle.take(x, index, mode=self.mode)

exe = fluid.Executor(self.place)
st_result = exe.run(fluid.default_main_program(),
feed={
'input': self.input_np,
'index': self.index_np
},
fetch_list=out)
np.testing.assert_allclose(
st_result[0],
np.take(self.input_np, self.index_np, mode=self.mode))

def test_dygraph(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.input_np)
index = paddle.to_tensor(self.index_np)
dy_result = paddle.take(x, index, mode=self.mode)
np.testing.assert_allclose(
np.take(self.input_np, self.index_np, mode=self.mode),
dy_result.numpy())


class TestTakeInt32(TestTakeAPI):
"""Test take API with data type int32"""

def set_dtype(self):
self.input_dtype = 'int32'
self.index_dtype = 'int64'


class TestTakeInt64(TestTakeAPI):
"""Test take API with data type int64"""

def set_dtype(self):
self.input_dtype = 'int64'
self.index_dtype = 'int64'


class TestTakeFloat32(TestTakeAPI):
"""Test take API with data type float32"""

def set_dtype(self):
self.input_dtype = 'float32'
self.index_dtype = 'int64'


class TestTakeTypeError(TestTakeAPI):
"""Test take Type Error"""

def test_static_type_error(self):
"""Argument 'index' must be Tensor"""
paddle.enable_static()
with program_guard(Program()):
x = fluid.data(name='input',
dtype=self.input_dtype,
shape=self.input_shape)
self.assertRaises(TypeError, paddle.take, x, self.index_np,
self.mode)

def test_dygraph_type_error(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.input_np)
self.assertRaises(TypeError, paddle.take, x, self.index_np, self.mode)

def test_static_dtype_error(self):
"""Data type of argument 'index' must be in [paddle.int32, paddle.int64]"""
paddle.enable_static()
with program_guard(Program()):
x = fluid.data(name='input',
dtype='float64',
shape=self.input_shape)
index = fluid.data(name='index',
dtype='float32',
shape=self.index_shape)
self.assertRaises(TypeError, paddle.take, x, index, self.mode)

def test_dygraph_dtype_error(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.input_np)
index = paddle.to_tensor(self.index_np, dtype='float32')
self.assertRaises(TypeError, paddle.take, x, index, self.mode)


class TestTakeModeRaise(unittest.TestCase):
"""Test take index out of range error"""

def set_mode(self):
self.mode = 'raise'

def set_dtype(self):
self.input_dtype = 'float64'
self.index_dtype = 'int64'

def set_input(self):
self.input_shape = [3, 4]
self.index_shape = [5, 8]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-20, 20).reshape(self.index_shape).astype(
self.index_dtype) # Both ends of the index are out of bounds

def setUp(self):
self.set_mode()
self.set_dtype()
self.set_input()
self.place = fluid.CUDAPlace(
0) if core.is_compiled_with_cuda() else fluid.CPUPlace()

def test_static_index_error(self):
"""When the index is out of range,
an error is reported directly through `paddle.index_select`"""
paddle.enable_static()
with program_guard(Program()):
x = fluid.data(name='input',
dtype=self.input_dtype,
shape=self.input_shape)
index = fluid.data(name='index',
dtype=self.index_dtype,
shape=self.index_shape)
self.assertRaises(ValueError, paddle.index_select, x, index)

def test_dygraph_index_error(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.input_np)
index = paddle.to_tensor(self.index_np, dtype=self.index_dtype)
self.assertRaises(ValueError, paddle.index_select, x, index)


class TestTakeModeWrap(TestTakeAPI):
"""Test take index out of range mode"""

def set_mode(self):
self.mode = 'wrap'

def set_input(self):
self.input_shape = [3, 4]
self.index_shape = [5, 8]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-20, 20).reshape(self.index_shape).astype(
self.index_dtype) # Both ends of the index are out of bounds


class TestTakeModeClip(TestTakeAPI):
"""Test take index out of range mode"""

def set_mode(self):
self.mode = 'clip'

def set_input(self):
self.input_shape = [3, 4]
self.index_shape = [5, 8]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-20, 20).reshape(self.index_shape).astype(
self.index_dtype) # Both ends of the index are out of bounds


if __name__ == "__main__":
unittest.main()
8 changes: 5 additions & 3 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@
from .math import heaviside # noqa: F401
from .math import frac # noqa: F401
from .math import sgn # noqa: F401
from .math import take # noqa: F401

from .random import multinomial # noqa: F401
from .random import standard_normal # noqa: F401
Expand Down Expand Up @@ -280,8 +281,8 @@

from .einsum import einsum # noqa: F401

#this list used in math_op_patch.py for _binary_creator_
tensor_method_func = [ #noqa
# this list used in math_op_patch.py for _binary_creator_
tensor_method_func = [ # noqa
'matmul',
'dot',
'cov',
Expand Down Expand Up @@ -505,11 +506,12 @@
'put_along_axis_',
'exponential_',
'heaviside',
'take',
'bucketize',
'sgn',
]

#this list used in math_op_patch.py for magic_method bind
# this list used in math_op_patch.py for magic_method bind
magic_method_func = [
('__and__', 'bitwise_and'),
('__or__', 'bitwise_or'),
Expand Down
106 changes: 105 additions & 1 deletion python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4756,7 +4756,6 @@ def frac(x, name=None):
type="trunc", inputs=inputs, attrs=attrs, outputs={"Out": y})
return _elementwise_op(LayerHelper(op_type, **locals()))


def sgn(x, name=None):
"""
For complex tensor, this API returns a new tensor whose elements have the same angles as the corresponding
Expand Down Expand Up @@ -4797,3 +4796,108 @@ def sgn(x, name=None):
return paddle.as_complex(output)
else:
return paddle.sign(x)

def take(x, index, mode='raise', name=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the name of parameter needs to be consistent with rfc, input in rfc while x here, and mode is not in rfc.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jeff41404 根据之前的修改意见 PaddlePaddle/community#186 (review) 更新过RFC:PaddlePaddle/community#217
参数的名字按照新的RFC内容进行修改的。

@S-HuaBomb 请先修改完RFC的评审意见吧。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rfc is still old now, should update and merge rfc first

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the modified RFC PaddlePaddle/community#217 with instructions added

"""
Returns a new tensor with the elements of input tensor x at the given index.
The input tensor is treated as if it were viewed as a 1-D tensor.
The result takes the same shape as the index.

Args:
x (Tensor): An N-D Tensor, its data type should be int32, int64, float32, float64.
index (Tensor): An N-D Tensor, its data type should be int32, int64.
mode (str, optional): Specifies how out-of-bounds index will behave.
the candicates are ``'raise'`` | ``'wrap'`` | ``'clip'``.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里用 , 分隔即可,下面的内容按照中文文档那边的意见统一改成列表吧~

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx, done.

If :attr:`mode` is ``'raise'``, raise an error (default);
If :attr:`mode` is ``'wrap'``, wrap around;
If :attr:`mode` is ``'clip'``, clip to the range.
``'clip'`` mode means that all indices that are too large are replaced by the index that
addresses the last element. Note that this disables indexing with negative numbers.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Returns:
Tensor: Tensor with the same shape as index, the data type is the same with input.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tensor 后使用 ,,以避免解析出 Return Type

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx, done.


Examples:
.. code-block:: python

import paddle

x_int = paddle.arange(0, 12).reshape([3, 4])
x_float = x_int.astype(paddle.float64)

idx_pos = paddle.arange(4, 10).reshape([2, 3]) # positive index
idx_neg = paddle.arange(-2, 4).reshape([2, 3]) # negative index
idx_err = paddle.arange(-2, 13).reshape([3, 5]) # index out of range

paddle.take(x_int, idx_pos)
# Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
# [[4, 5, 6],
# [7, 8, 9]])

paddle.take(x_int, idx_neg)
# Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
# [[10, 11, 0 ],
# [1 , 2 , 3 ]])

paddle.take(x_float, idx_pos)
# Tensor(shape=[2, 3], dtype=float64, place=Place(cpu), stop_gradient=True,
# [[4., 5., 6.],
# [7., 8., 9.]])

x_int.take(idx_pos)
# Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
# [[4, 5, 6],
# [7, 8, 9]])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

示例可增加一个negative index和float类型的input

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


paddle.take(x_int, idx_err, mode='wrap')
# Tensor(shape=[3, 5], dtype=int32, place=Place(cpu), stop_gradient=True,
# [[10, 11, 0 , 1 , 2 ],
# [3 , 4 , 5 , 6 , 7 ],
# [8 , 9 , 10, 11, 0 ]])

paddle.take(x_int, idx_err, mode='clip')
# Tensor(shape=[3, 5], dtype=int32, place=Place(cpu), stop_gradient=True,
# [[0 , 0 , 0 , 1 , 2 ],
# [3 , 4 , 5 , 6 , 7 ],
# [8 , 9 , 10, 11, 11]])

"""
if mode not in ['raise', 'wrap', 'clip']:
raise ValueError(
"'mode' in 'take' should be 'raise', 'wrap', 'clip', but received {}.".format(mode))

if paddle.in_dynamic_mode():
if not isinstance(index, (paddle.Tensor, Variable)):
raise TypeError(
"The type of 'index' must be Tensor, but got {}".format(type(index)))
if index.dtype not in [paddle.int32, paddle.int64]:
raise TypeError(
"The data type of 'index' must be one of ['int32', 'int64'], but got {}".format(
index.dtype))

else:
check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'take')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

index索引越界时需要报错

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


input_1d = x.flatten()
index_1d = index.flatten()
max_index = input_1d.shape[-1]

if mode == 'raise':
# This processing enables 'take' to handle negative indexes within the correct range.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以补充下注释,negative indexes可以enable,但越界的索引会在下面的index_select报错

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

THX,Done

# Negative indexes can be enabled,
# but out-of-range indexes will report an error in the following paddle.index_select
index_1d = paddle.where(index_1d < 0, index_1d % max_index, index_1d)
elif mode == 'wrap':
# The out of range indices are constrained by taking the remainder.
index_1d = paddle.where(index_1d < 0,
index_1d % max_index, index_1d)
index_1d = paddle.where(index_1d >= max_index,
index_1d % max_index, index_1d)
elif mode == 'clip':
# 'clip' mode disables indexing with negative numbers.
index_1d = clip(index_1d, 0, max_index - 1)

out = input_1d.index_select(index_1d).reshape(index.shape)

return out