Skip to content

Commit

Permalink
[MXNET-524] Broadcast like operator (apache#11820)
Browse files Browse the repository at this point in the history
* Registered the broadcast_like operator with GPU and CPU

Added appropriate shape inference

* Added python interface to ndarray and symbol

* Added python api documentation

* Fixed backward operation

* Added unit tests

* Fixed linting issues

* Added missing api doc
  • Loading branch information
ifeherva authored and szha committed Jul 20, 2018
1 parent a1a4c58 commit 3390095
Show file tree
Hide file tree
Showing 10 changed files with 164 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/api/python/ndarray/ndarray.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ The `ndarray` package provides several classes:
NDArray.broadcast_to
NDArray.broadcast_axes
NDArray.broadcast_like
NDArray.tile
NDArray.pad
```
Expand Down Expand Up @@ -395,6 +396,7 @@ The `ndarray` package provides several classes:
broadcast_to
broadcast_axes
broadcast_like
repeat
tile
pad
Expand Down
2 changes: 2 additions & 0 deletions docs/api/python/symbol/symbol.md
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ Composite multiple symbols into a new one by an operator.
Symbol.broadcast_to
Symbol.broadcast_axes
Symbol.broadcast_like
Symbol.tile
Symbol.pad
```
Expand Down Expand Up @@ -393,6 +394,7 @@ Composite multiple symbols into a new one by an operator.
broadcast_to
broadcast_axes
broadcast_like
repeat
tile
pad
Expand Down
38 changes: 38 additions & 0 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1723,6 +1723,44 @@ def broadcast_to(self, shape):
return op.broadcast_to(self, shape=tuple(shape))
# pylint: enable= undefined-variable

def broadcast_like(self, other):
    """Broadcast this array so that its shape matches the shape of ``other``.

    Only axes of size 1 may be broadcast, and the number of dimensions
    must stay the same. For instance, shape (2, 1) can be broadcast to
    (2, 3), whereas shape (2, 3) cannot be broadcast to (2, 3, 3).

    Parameters
    ----------
    other : NDArray
        Array whose shape is the desired output shape.

    Returns
    -------
    NDArray
        An NDArray with the desired shape that is not sharing data with this
        array, even if the new shape is the same as ``self.shape``.

    Examples
    --------
    >>> x = mx.nd.arange(0,3).reshape((1,3,1))
    >>> x.asnumpy()
    array([[[ 0.],
            [ 1.],
            [ 2.]]], dtype=float32)
    >>> y = x.broadcast_like(mx.nd.ones((2,3,3)))
    >>> y.asnumpy()
    array([[[ 0.,  0.,  0.],
            [ 1.,  1.,  1.],
            [ 2.,  2.,  2.]],
    <BLANKLINE>
           [[ 0.,  0.,  0.],
            [ 1.,  1.,  1.],
            [ 2.,  2.,  2.]]], dtype=float32)
    """
    # Delegates to broadcast_to using the shape read from ``other``.
    target_shape = other.shape
    return self.broadcast_to(target_shape)

def wait_to_read(self):
"""Waits until all previous write operations on the current array are finished.
Expand Down
8 changes: 8 additions & 0 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2014,6 +2014,14 @@ def broadcast_to(self, *args, **kwargs):
"""
return op.broadcast_to(self, *args, **kwargs)

def broadcast_like(self, *args, **kwargs):
    """Fluent-style shortcut for :py:func:`broadcast_like`.

    Takes the same arguments as :py:func:`broadcast_like`; this
    array is passed as the data input.
    """
    result = op.broadcast_like(self, *args, **kwargs)
    return result

def tile(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`tile`.
Expand Down
25 changes: 25 additions & 0 deletions src/operator/tensor/broadcast_reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,31 @@ inline bool BroadcastToShape(const nnvm::NodeAttrs& attrs,
return true;
}

/*!
 * \brief Shape inference for broadcast_like: the output takes the shape of
 *        the second input (rhs), provided the first input (lhs) can be
 *        broadcast to it.
 * \param attrs node attributes (unused here).
 * \param in_attrs input shapes; must contain exactly two entries {lhs, rhs}.
 * \param out_attrs output shapes; must contain exactly one entry, which is
 *        assigned the inferred shape.
 * \return false when either input has rank 0 (inference is deferred),
 *         true once the output shape has been assigned.
 */
inline bool BroadcastLikeShape(const nnvm::NodeAttrs& attrs,
                               std::vector<TShape> *in_attrs,
                               std::vector<TShape> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 2U);
  CHECK_EQ(out_attrs->size(), 1U);
  const TShape& lhs_shape = (*in_attrs)[0];
  const TShape& rhs_shape = (*in_attrs)[1];

  // Both operands must have a non-zero rank before anything can be inferred.
  // The original condition tested lhs twice, so a rank-0 rhs fell through to
  // the rank-equality CHECK below and aborted instead of deferring.
  if (lhs_shape.ndim() == 0 || rhs_shape.ndim() == 0) return false;

  CHECK_EQ(lhs_shape.ndim(), rhs_shape.ndim())
    << "Operand of shape " << lhs_shape << " cannot be broadcasted to " << rhs_shape;

  // Start from the target shape and patch axes whose target extent is 0.
  TShape oshape(rhs_shape);
  for (index_t i = 0; i < lhs_shape.ndim(); ++i) {
    if (rhs_shape[i] != 0) {
      // An axis is compatible when it already matches or has size 1 in lhs.
      CHECK(lhs_shape[i] == rhs_shape[i] || lhs_shape[i] == 1)
        << "Array cannot be broadcasted from " << lhs_shape << " to " << rhs_shape;
    } else {
      // rhs extent is 0 on this axis: fall back to the lhs extent.
      oshape[i] = lhs_shape[i];
    }
  }
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
  return true;
}

inline void BroadcastReduceShapeCompact(const TShape& big, const TShape& small,
TShape *new_big, TShape *new_small) {
index_t idim = std::max<index_t>(big.ndim(), MXNET_SPECIAL_MAX_NDIM);
Expand Down
39 changes: 39 additions & 0 deletions src/operator/tensor/broadcast_reduce_op_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,45 @@ NNVM_REGISTER_OP(_broadcast_backward)
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
});

// Registers the `broadcast_like` operator: two inputs (lhs, rhs), one output.
NNVM_REGISTER_OP(broadcast_like)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"lhs", "rhs"};
  })
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<nnvm::FGradient>("FGradient",
  [](const nnvm::NodePtr& n,
     const std::vector<nnvm::NodeEntry>& ograds) {
    if (CheckGradAllZero(ograds)) return MakeZeroGradNodes(n, ograds);
    // Gradient w.r.t. lhs: built from a _broadcast_backward node with keepdims=true.
    auto lhs = MakeNonlossGradNode("_broadcast_backward", n, ograds, {},
                                   {{"keepdims", "true"}});
    // Gradient w.r.t. rhs: a zeros_like node built on the rhs input.
    auto ng = MakeNode("zeros_like", n->attrs.name + "_rhs_backward",
                       {n->inputs[1]}, nullptr, &n);
    lhs.push_back(nnvm::NodeEntry{ng, 0, 0});
    return lhs;
  })
.add_argument("lhs", "NDArray-or-Symbol", "First input.")
.add_argument("rhs", "NDArray-or-Symbol", "Second input.")
.describe(R"code(Broadcasts lhs to have the same shape as rhs.
Broadcasting is a mechanism that allows NDArrays to perform arithmetic operations
with arrays of different shapes efficiently without creating multiple copies of arrays.
Also see, `Broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_ for more explanation.
Broadcasting is allowed on axes with size 1, such as from `(2,1,3,1)` to
`(2,8,3,9)`. Elements will be duplicated on the broadcasted axes.
For example::
broadcast_like([[1,2,3]], [[5,6,7],[7,8,9]]) = [[ 1., 2., 3.],
[ 1., 2., 3.]]
)code" ADD_FILELINE)
.set_attr<nnvm::FInferShape>("FInferShape", BroadcastLikeShape)
.set_attr<FCompute>("FCompute<cpu>", BroadcastCompute<cpu>);

NNVM_REGISTER_OP(norm)
MXNET_ADD_SPARSE_OP_ALIAS(norm)
.describe(R"code(Computes the norm on an NDArray.
Expand Down
3 changes: 3 additions & 0 deletions src/operator/tensor/broadcast_reduce_op_value.cu
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ NNVM_REGISTER_OP(broadcast_axis)
// GPU compute kernel for broadcast_to.
NNVM_REGISTER_OP(broadcast_to)
.set_attr<FCompute>("FCompute<gpu>", BroadcastCompute<gpu>);

// GPU compute kernel for broadcast_like; uses the same BroadcastCompute<gpu>
// as broadcast_to.
NNVM_REGISTER_OP(broadcast_like)
.set_attr<FCompute>("FCompute<gpu>", BroadcastCompute<gpu>);

// GPU compute kernel for _broadcast_backward: ReduceAxesCompute with the sum reducer.
NNVM_REGISTER_OP(_broadcast_backward)
.set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow::red::sum>);

Expand Down
24 changes: 22 additions & 2 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,12 +515,11 @@ def test_reduce_inner(numpy_reduce_func, nd_reduce_func, multi_axes):
def test_broadcast():
sample_num = 1000
def test_broadcast_to():
for i in range(sample_num):
for _ in range(sample_num):
ndim = np.random.randint(1, 6)
target_shape = np.random.randint(1, 11, size=ndim)
shape = target_shape.copy()
axis_flags = np.random.randint(0, 2, size=ndim)
axes = []
for (axis, flag) in enumerate(axis_flags):
if flag:
shape[axis] = 1
Expand All @@ -532,7 +531,28 @@ def test_broadcast_to():
assert (ndarray_ret.shape == target_shape).all()
err = np.square(ndarray_ret - numpy_ret).mean()
assert err < 1E-8

def test_broadcast_like():
    """Randomized check of ``broadcast_like`` against the source data.

    Each iteration draws a random target shape, collapses a random subset
    of its axes to size 1 to get the source shape, and verifies that the
    broadcast result has the target shape and matches the source values.
    """
    for _ in range(sample_num):
        rank = np.random.randint(1, 6)
        target_shape = np.random.randint(1, 11, size=rank)
        target = mx.nd.ones(shape=tuple(target_shape))
        src_shape = target_shape.copy()
        # Axes flagged with 1 are collapsed to size 1 in the source shape.
        collapse = np.random.randint(0, 2, size=rank)
        src_shape[collapse.astype(bool)] = 1
        dat = np.random.rand(*src_shape) - 0.5
        # The subtraction below relies on NumPy broadcasting dat up to the result.
        expected = dat
        result = mx.nd.array(dat).broadcast_like(target)
        if type(result) is mx.ndarray.NDArray:
            result = result.asnumpy()
        assert (result.shape == target_shape).all()
        mse = np.square(result - expected).mean()
        assert mse < 1E-8

test_broadcast_to()
test_broadcast_like()


@with_seed()
Expand Down
2 changes: 2 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2231,6 +2231,7 @@ def test_broadcast():
a = mx.symbol.Variable('a')
sym_bcast_axis = mx.symbol.broadcast_axis(a, axis=axis, size=size)
sym_bcast_to = mx.symbol.broadcast_to(a, shape=tuple(target_shape))
sym_bcast_like = mx.symbol.broadcast_like(a, sym_bcast_to)
def test_broadcasting_ele(sym_bcast):
dat_npy = np.random.rand(*shape)
groundtruth = dat_npy
Expand All @@ -2247,6 +2248,7 @@ def test_broadcasting_ele(sym_bcast):
assert_almost_equal(grad_nd.asnumpy(), grad_groundtruth, rtol=1e-4)
test_broadcasting_ele(sym_bcast_axis)
test_broadcasting_ele(sym_bcast_to)
test_broadcasting_ele(sym_bcast_like)


@with_seed()
Expand Down
25 changes: 23 additions & 2 deletions tests/python/unittest/test_sparse_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,12 +379,11 @@ def test_sparse_nd_broadcast():
sample_num = 1000
# TODO(haibin) test with more than 2 dimensions
def test_broadcast_to(stype):
for i in range(sample_num):
for _ in range(sample_num):
ndim = 2
target_shape = np.random.randint(1, 11, size=ndim)
shape = target_shape.copy()
axis_flags = np.random.randint(0, 2, size=ndim)
axes = []
for (axis, flag) in enumerate(axis_flags):
if flag:
shape[axis] = 1
Expand All @@ -397,9 +396,31 @@ def test_broadcast_to(stype):
assert (ndarray_ret.shape == target_shape).all()
err = np.square(ndarray_ret - numpy_ret).mean()
assert err < 1E-8

def test_broadcast_like(stype):
    """Randomized 2-D check of ``broadcast_like`` for storage type ``stype``.

    Each iteration draws a random target shape, collapses a random subset
    of its axes to size 1 to get the source shape, converts the source data
    to ``stype``, and verifies the broadcast result's shape and values.
    """
    for _ in range(sample_num):
        rank = 2
        target_shape = np.random.randint(1, 11, size=rank)
        target = mx.nd.ones(shape=tuple(target_shape))
        src_shape = target_shape.copy()
        # Axes flagged with 1 are collapsed to size 1 in the source shape.
        collapse = np.random.randint(0, 2, size=rank)
        src_shape[collapse.astype(bool)] = 1
        dat = np.random.rand(*src_shape) - 0.5
        # The subtraction below relies on NumPy broadcasting dat up to the result.
        expected = dat
        source = mx.nd.array(dat).tostype(stype)
        result = source.broadcast_like(target)
        if type(result) is mx.ndarray.NDArray:
            result = result.asnumpy()
        assert (result.shape == target_shape).all()
        mse = np.square(result - expected).mean()
        assert mse < 1E-8

stypes = ['csr', 'row_sparse']
for stype in stypes:
test_broadcast_to(stype)
test_broadcast_like(stype)


@with_seed()
Expand Down

0 comments on commit 3390095

Please sign in to comment.