Skip to content

Commit

Permalink
Integer indexing "wrap" mode now default
Browse files Browse the repository at this point in the history
- For performance reasons, "wrap" now clips positive indices and wraps negative indices
  • Loading branch information
ndgrigorian committed Mar 20, 2023
1 parent 4e06ba9 commit a1078c7
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 87 deletions.
23 changes: 10 additions & 13 deletions dpctl/tensor/_indexing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,17 @@


def _get_indexing_mode(name):
modes = {"default": 0, "clip": 1, "wrap": 2}
modes = {"wrap": 0, "clip": 1}
try:
return modes[name]
except KeyError:
raise ValueError(
"`mode` must be `default`, `clip`, or `wrap`."
"Got `{}`.".format(name)
"`mode` must be `wrap` or `clip`." "Got `{}`.".format(name)
)


def take(x, indices, /, *, axis=None, mode="default"):
"""take(x, indices, axis=None, mode="default")
def take(x, indices, /, *, axis=None, mode="wrap"):
"""take(x, indices, axis=None, mode="wrap")
Takes elements from array along a given axis.
Expand All @@ -53,11 +52,10 @@ def take(x, indices, /, *, axis=None, mode="default"):
Default: `None`.
mode:
How out-of-bounds indices will be handled.
"default" - clamps indices to (-n <= i < n), then wraps
"wrap" - clamps indices to (-n <= i < n), then wraps
negative indices.
"clip" - clips indices to (0 <= i < n)
"wrap" - wraps both negative and positive indices.
Default: `"default"`.
Default: `"wrap"`.
Returns:
out: usm_ndarray
Expand Down Expand Up @@ -122,8 +120,8 @@ def take(x, indices, /, *, axis=None, mode="default"):
return res


def put(x, indices, vals, /, *, axis=None, mode="default"):
"""put(x, indices, vals, axis=None, mode="default")
def put(x, indices, vals, /, *, axis=None, mode="wrap"):
"""put(x, indices, vals, axis=None, mode="wrap")
Puts values of an array into another array
along a given axis.
Expand All @@ -142,11 +140,10 @@ def put(x, indices, vals, /, *, axis=None, mode="default"):
Default: `None`.
mode:
How out-of-bounds indices will be handled.
"default" - clamps indices to (-n <= i < n), then wraps
"wrap" - clamps indices to (-n <= i < n), then wraps
negative indices.
"clip" - clips indices to (0 <= i < n)
"wrap" - wraps both negative and positive indices.
Default: `"default"`.
Default: `"wrap"`.
"""
if not isinstance(x, dpt.usm_ndarray):
raise TypeError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ namespace py = pybind11;
template <typename ProjectorT, typename Ty, typename indT> class take_kernel;
template <typename ProjectorT, typename Ty, typename indT> class put_kernel;

class FancyIndex
class WrapIndex
{
public:
FancyIndex() = default;
WrapIndex() = default;

void operator()(py::ssize_t max_item, py::ssize_t &ind) const
{
Expand All @@ -73,20 +73,6 @@ class ClipIndex
}
};

class WrapIndex
{
public:
WrapIndex() = default;

void operator()(py::ssize_t max_item, py::ssize_t &ind) const
{
max_item = std::max<py::ssize_t>(max_item, 1);
ind = (ind < 0) ? (ind + max_item * ((-ind / max_item) + 1)) % max_item
: ind % max_item;
return;
}
};

template <typename ProjectorT, typename T, typename indT> class TakeFunctor
{
private:
Expand Down Expand Up @@ -361,22 +347,6 @@ sycl::event put_impl(sycl::queue q,
return put_ev;
}

template <typename fnT, typename T, typename indT> struct TakeFancyFactory
{
fnT get()
{
if constexpr (std::is_integral<indT>::value &&
!std::is_same<indT, bool>::value) {
fnT fn = take_impl<FancyIndex, T, indT>;
return fn;
}
else {
fnT fn = nullptr;
return fn;
}
}
};

template <typename fnT, typename T, typename indT> struct TakeWrapFactory
{
fnT get()
Expand Down Expand Up @@ -409,22 +379,6 @@ template <typename fnT, typename T, typename indT> struct TakeClipFactory
}
};

template <typename fnT, typename T, typename indT> struct PutFancyFactory
{
fnT get()
{
if constexpr (std::is_integral<indT>::value &&
!std::is_same<indT, bool>::value) {
fnT fn = put_impl<FancyIndex, T, indT>;
return fn;
}
else {
fnT fn = nullptr;
return fn;
}
}
};

template <typename fnT, typename T, typename indT> struct PutWrapFactory
{
fnT get()
Expand Down
14 changes: 2 additions & 12 deletions dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,9 @@

#include "integer_advanced_indexing.hpp"

#define INDEXING_MODES 3
#define FANCY_MODE 0
#define INDEXING_MODES 2
#define WRAP_MODE 0
#define CLIP_MODE 1
#define WRAP_MODE 2

namespace dpctl
{
Expand Down Expand Up @@ -884,11 +883,6 @@ void init_advanced_indexing_dispatch_tables(void)
{
using namespace dpctl::tensor::detail;

using dpctl::tensor::kernels::indexing::TakeFancyFactory;
DispatchTableBuilder<take_fn_ptr_t, TakeFancyFactory, num_types>
dtb_takefancy;
dtb_takefancy.populate_dispatch_table(take_dispatch_table[FANCY_MODE]);

using dpctl::tensor::kernels::indexing::TakeClipFactory;
DispatchTableBuilder<take_fn_ptr_t, TakeClipFactory, num_types>
dtb_takeclip;
Expand All @@ -899,10 +893,6 @@ void init_advanced_indexing_dispatch_tables(void)
dtb_takewrap;
dtb_takewrap.populate_dispatch_table(take_dispatch_table[WRAP_MODE]);

using dpctl::tensor::kernels::indexing::PutFancyFactory;
DispatchTableBuilder<put_fn_ptr_t, PutFancyFactory, num_types> dtb_putfancy;
dtb_putfancy.populate_dispatch_table(put_dispatch_table[FANCY_MODE]);

using dpctl::tensor::kernels::indexing::PutClipFactory;
DispatchTableBuilder<put_fn_ptr_t, PutClipFactory, num_types> dtb_putclip;
dtb_putclip.populate_dispatch_table(put_dispatch_table[CLIP_MODE]);
Expand Down
20 changes: 6 additions & 14 deletions dpctl/tests/test_usm_ndarray_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,27 +898,19 @@ def test_integer_indexing_modes():
x = dpt.arange(5, sycl_queue=q)
x_np = dpt.asnumpy(x)

ind = dpt.asarray([-6, -3, 0, 2, 6], dtype=np.intp, sycl_queue=q)
ind_np = dpt.asnumpy(ind)
# wrapping negative indices
ind = dpt.asarray([-4, -3, 0, 2, 4], dtype=np.intp, sycl_queue=q)

# wrapping
res = dpt.take(x, ind, mode="wrap")
expected_arr = np.take(x_np, ind_np, mode="wrap")
expected_arr = np.take(x_np, dpt.asnumpy(ind), mode="raise")

assert (dpt.asnumpy(res) == expected_arr).all()

# clipping to 0 (disabling negative indices)
res = dpt.take(x, ind, mode="clip")
expected_arr = np.take(x_np, ind_np, mode="clip")

assert (dpt.asnumpy(res) == expected_arr).all()

# clipping to -n<=i<n,
# where n is the axis length
ind = dpt.asarray([-4, -3, 0, 2, 4], dtype=np.intp, sycl_queue=q)
ind = dpt.asarray([-6, -3, 0, 2, 6], dtype=np.intp, sycl_queue=q)

res = dpt.take(x, ind, mode="default")
expected_arr = np.take(dpt.asnumpy(x), dpt.asnumpy(ind), mode="raise")
res = dpt.take(x, ind, mode="clip")
expected_arr = np.take(x_np, dpt.asnumpy(ind), mode="clip")

assert (dpt.asnumpy(res) == expected_arr).all()

Expand Down

0 comments on commit a1078c7

Please sign in to comment.