Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reuse dpctl.tensor.where for dpnp.where #1380

Merged
merged 4 commits into from
Jun 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 0 additions & 51 deletions dpnp/backend/include/dpnp_iface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1683,57 +1683,6 @@ INP_DLLEXPORT void dpnp_var_c(void* array,
size_t naxis,
size_t ddof);

/**
 * @ingroup BACKEND_API
 * @brief Implementation of where function
 *
 * Element-wise selection: result[i] = condition[i] ? input1[i] : input2[i],
 * with broadcasting of the three inputs to the result shape.
 *
 * @param [in] q_ref Reference to SYCL queue.
 * @param [out] result_out Output array.
 * @param [in] result_size Size of output array.
 * @param [in] result_ndim Number of output array dimensions.
 * @param [in] result_shape Shape of output array.
 * @param [in] result_strides Strides of output array.
 * @param [in] condition_in Condition array.
 * @param [in] condition_size Size of condition array.
 * @param [in] condition_ndim Number of condition array dimensions.
 * @param [in] condition_shape Shape of condition array.
 * @param [in] condition_strides Strides of condition array.
 * @param [in] input1_in First input array.
 * @param [in] input1_size Size of first input array.
 * @param [in] input1_ndim Number of first input array dimensions.
 * @param [in] input1_shape Shape of first input array.
 * @param [in] input1_strides Strides of first input array.
 * @param [in] input2_in Second input array.
 * @param [in] input2_size Size of second input array.
 * @param [in] input2_ndim Number of second input array dimensions.
 * @param [in] input2_shape Shape of second input array.
 * @param [in] input2_strides Strides of second input array.
 * @param [in] dep_event_vec_ref Reference to vector of SYCL events.
 * @return Reference to the SYCL event associated with the kernel execution,
 *         or nullptr when the call completes synchronously (or on empty input).
 */
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
INP_DLLEXPORT DPCTLSyclEventRef dpnp_where_c(DPCTLSyclQueueRef q_ref,
                                             void* result_out,
                                             const size_t result_size,
                                             const size_t result_ndim,
                                             const shape_elem_type* result_shape,
                                             const shape_elem_type* result_strides,
                                             const void* condition_in,
                                             const size_t condition_size,
                                             const size_t condition_ndim,
                                             const shape_elem_type* condition_shape,
                                             const shape_elem_type* condition_strides,
                                             const void* input1_in,
                                             const size_t input1_size,
                                             const size_t input1_ndim,
                                             const shape_elem_type* input1_shape,
                                             const shape_elem_type* input1_strides,
                                             const void* input2_in,
                                             const size_t input2_size,
                                             const size_t input2_ndim,
                                             const shape_elem_type* input2_shape,
                                             const shape_elem_type* input2_strides,
                                             const DPCTLEventVectorRef dep_event_vec_ref);

/**
* @ingroup BACKEND_API
* @brief Implementation of invert function
Expand Down
1 change: 0 additions & 1 deletion dpnp/backend/include/dpnp_iface_fptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,6 @@ enum class DPNPFuncName : size_t
DPNP_FN_VANDER_EXT, /**< Used in numpy.vander() impl, requires extra parameters */
DPNP_FN_VAR, /**< Used in numpy.var() impl */
DPNP_FN_VAR_EXT, /**< Used in numpy.var() impl, requires extra parameters */
DPNP_FN_WHERE_EXT, /**< Used in numpy.where() impl, requires extra parameters */
DPNP_FN_ZEROS, /**< Used in numpy.zeros() impl */
DPNP_FN_ZEROS_LIKE, /**< Used in numpy.zeros_like() impl */
DPNP_FN_LAST, /**< The latest element of the enumeration */
Expand Down
255 changes: 0 additions & 255 deletions dpnp/backend/kernels/dpnp_krnl_searching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

#include <dpnp_iface.hpp>
#include "dpnp_fptr.hpp"
#include "dpnp_iterator.hpp"
#include "dpnpc_memory_adapter.hpp"
#include "queue_sycl.hpp"

Expand Down Expand Up @@ -140,258 +139,6 @@ DPCTLSyclEventRef (*dpnp_argmin_ext_c)(DPCTLSyclQueueRef,
size_t,
const DPCTLEventVectorRef) = dpnp_argmin_c<_DataType, _idx_DataType>;


/* Forward declarations of the SYCL kernel name tags — one per execution
 * path taken by dpnp_where_c below (broadcast, strided, contiguous). */
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
class dpnp_where_c_broadcast_kernel;

template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
class dpnp_where_c_strides_kernel;

template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
class dpnp_where_c_kernel;

/**
 * Element-wise selection: result[i] = condition[i] ? input1[i] : input2[i].
 *
 * Dispatches to one of three SYCL kernels:
 *  - broadcast path: the three input shapes differ, so each input is walked
 *    through a DPNPC_id iterator broadcast to the result shape;
 *  - strided path: shapes match but at least one array's strides differ from
 *    the C-contiguous strides of its shape;
 *  - contiguous path: flat indexing across all arrays.
 *
 * Returns a copied event reference on the contiguous path; the broadcast and
 * strided paths wait for completion and return nullptr.
 */
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
DPCTLSyclEventRef dpnp_where_c(DPCTLSyclQueueRef q_ref,
                               void* result_out,
                               const size_t result_size,
                               const size_t result_ndim,
                               const shape_elem_type* result_shape,
                               const shape_elem_type* result_strides,
                               const void* condition_in,
                               const size_t condition_size,
                               const size_t condition_ndim,
                               const shape_elem_type* condition_shape,
                               const shape_elem_type* condition_strides,
                               const void* input1_in,
                               const size_t input1_size,
                               const size_t input1_ndim,
                               const shape_elem_type* input1_shape,
                               const shape_elem_type* input1_strides,
                               const void* input2_in,
                               const size_t input2_size,
                               const size_t input2_ndim,
                               const shape_elem_type* input2_shape,
                               const shape_elem_type* input2_strides,
                               const DPCTLEventVectorRef dep_event_vec_ref)
{
    /* avoid warning unused variable: dependency events are not consumed here */
    (void)dep_event_vec_ref;

    DPCTLSyclEventRef event_ref = nullptr;

    // Nothing to do when any operand is empty.
    if (!condition_size || !input1_size || !input2_size)
    {
        return event_ref;
    }

    sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));

    bool* condition_data = static_cast<bool*>(const_cast<void*>(condition_in));
    _DataType_input1* input1_data = static_cast<_DataType_input1*>(const_cast<void*>(input1_in));
    _DataType_input2* input2_data = static_cast<_DataType_input2*>(const_cast<void*>(input2_in));
    _DataType_output* result = static_cast<_DataType_output*>(result_out);

    // Broadcasting is needed whenever any two of the three input shapes differ.
    bool use_broadcasting = !array_equal(input1_shape, input1_ndim, input2_shape, input2_ndim);
    use_broadcasting = use_broadcasting || !array_equal(condition_shape, condition_ndim, input1_shape, input1_ndim);
    use_broadcasting = use_broadcasting || !array_equal(condition_shape, condition_ndim, input2_shape, input2_ndim);

    // The strided path is needed when any array's strides differ from the
    // C-contiguous strides implied by its shape (computed per input below).
    shape_elem_type* condition_shape_offsets = new shape_elem_type[condition_ndim];

    get_shape_offsets_inkernel(condition_shape, condition_ndim, condition_shape_offsets);
    bool use_strides = !array_equal(condition_strides, condition_ndim, condition_shape_offsets, condition_ndim);
    delete[] condition_shape_offsets;

    shape_elem_type* input1_shape_offsets = new shape_elem_type[input1_ndim];

    get_shape_offsets_inkernel(input1_shape, input1_ndim, input1_shape_offsets);
    use_strides = use_strides || !array_equal(input1_strides, input1_ndim, input1_shape_offsets, input1_ndim);
    delete[] input1_shape_offsets;

    shape_elem_type* input2_shape_offsets = new shape_elem_type[input2_ndim];

    get_shape_offsets_inkernel(input2_shape, input2_ndim, input2_shape_offsets);
    use_strides = use_strides || !array_equal(input2_strides, input2_ndim, input2_shape_offsets, input2_ndim);
    delete[] input2_shape_offsets;

    sycl::event event;
    sycl::range<1> gws(result_size); // one work-item per output element

    if (use_broadcasting)
    {
        // DPNPC_id iterators are placement-new'd into USM memory so they are
        // usable from inside the kernel.
        DPNPC_id<bool>* condition_it;
        const size_t condition_it_it_size_in_bytes = sizeof(DPNPC_id<bool>);
        condition_it = reinterpret_cast<DPNPC_id<bool>*>(dpnp_memory_alloc_c(q_ref, condition_it_it_size_in_bytes));
        new (condition_it) DPNPC_id<bool>(q_ref, condition_data, condition_shape, condition_strides, condition_ndim);

        condition_it->broadcast_to_shape(result_shape, result_ndim);

        DPNPC_id<_DataType_input1>* input1_it;
        const size_t input1_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input1>);
        input1_it = reinterpret_cast<DPNPC_id<_DataType_input1>*>(dpnp_memory_alloc_c(q_ref, input1_it_size_in_bytes));
        new (input1_it) DPNPC_id<_DataType_input1>(q_ref, input1_data, input1_shape, input1_strides, input1_ndim);

        input1_it->broadcast_to_shape(result_shape, result_ndim);

        DPNPC_id<_DataType_input2>* input2_it;
        const size_t input2_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input2>);
        input2_it = reinterpret_cast<DPNPC_id<_DataType_input2>*>(dpnp_memory_alloc_c(q_ref, input2_it_size_in_bytes));
        new (input2_it) DPNPC_id<_DataType_input2>(q_ref, input2_data, input2_shape, input2_strides, input2_ndim);

        input2_it->broadcast_to_shape(result_shape, result_ndim);

        auto kernel_parallel_for_func = [=](sycl::id<1> global_id) {
            const size_t i = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */
            {
                const bool condition = (*condition_it)[i];
                const _DataType_output input1_elem = (*input1_it)[i];
                const _DataType_output input2_elem = (*input2_it)[i];
                result[i] = (condition) ? input1_elem : input2_elem;
            }
        };
        auto kernel_func = [&](sycl::handler& cgh) {
            cgh.parallel_for<class dpnp_where_c_broadcast_kernel<_DataType_output, _DataType_input1, _DataType_input2>>(
                gws, kernel_parallel_for_func);
        };

        // Synchronous: wait here, so no event is returned for this path.
        q.submit(kernel_func).wait();

        // Destroy the placement-new'd iterators.
        // NOTE(review): the USM buffers obtained from dpnp_memory_alloc_c are
        // destructed but never freed in this function — possible leak; confirm
        // whether dpnp_memory_free_c was expected here.
        condition_it->~DPNPC_id();
        input1_it->~DPNPC_id();
        input2_it->~DPNPC_id();

        return event_ref; // nullptr: kernel already completed
    }
    else if (use_strides)
    {
        // The strided kernel indexes all four arrays with result_ndim entries,
        // so every operand must have the same rank as the result.
        if ((result_ndim != condition_ndim) || (result_ndim != input1_ndim) || (result_ndim != input2_ndim))
        {
            throw std::runtime_error("Result ndim=" + std::to_string(result_ndim) +
                                     " mismatches with either condition ndim=" + std::to_string(condition_ndim) +
                                     " or input1 ndim=" + std::to_string(input1_ndim) +
                                     " or input2 ndim=" + std::to_string(input2_ndim));
        }

        /* memory transfer optimization, use USM-host for temporary speeds up transfer to device */
        using usm_host_allocatorT = sycl::usm_allocator<shape_elem_type, sycl::usm::alloc::host>;

        size_t strides_size = 4 * result_ndim; // result + condition + input1 + input2
        shape_elem_type* dev_strides_data = sycl::malloc_device<shape_elem_type>(strides_size, q);

        /* create host temporary for packed strides managed by shared pointer */
        auto strides_host_packed =
            std::vector<shape_elem_type, usm_host_allocatorT>(strides_size, usm_host_allocatorT(q));

        /* packed vector is concatenation of result_strides, condition_strides, input1_strides and input2_strides */
        std::copy(result_strides, result_strides + result_ndim, strides_host_packed.begin());
        std::copy(condition_strides, condition_strides + result_ndim, strides_host_packed.begin() + result_ndim);
        std::copy(input1_strides, input1_strides + result_ndim, strides_host_packed.begin() + 2 * result_ndim);
        std::copy(input2_strides, input2_strides + result_ndim, strides_host_packed.begin() + 3 * result_ndim);

        auto copy_strides_ev =
            q.copy<shape_elem_type>(strides_host_packed.data(), dev_strides_data, strides_host_packed.size());

        auto kernel_parallel_for_func = [=](sycl::id<1> global_id) {
            const size_t output_id = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */
            {
                // Unpack the four stride vectors from the packed device buffer.
                const shape_elem_type* result_strides_data = &dev_strides_data[0];
                const shape_elem_type* condition_strides_data = &dev_strides_data[result_ndim];
                const shape_elem_type* input1_strides_data = &dev_strides_data[2 * result_ndim];
                const shape_elem_type* input2_strides_data = &dev_strides_data[3 * result_ndim];

                size_t condition_id = 0;
                size_t input1_id = 0;
                size_t input2_id = 0;

                // Translate the flat output index into per-axis coordinates and
                // accumulate each operand's strided offset.
                for (size_t i = 0; i < result_ndim; ++i)
                {
                    const size_t output_xyz_id =
                        get_xyz_id_by_id_inkernel(output_id, result_strides_data, result_ndim, i);
                    condition_id += output_xyz_id * condition_strides_data[i];
                    input1_id += output_xyz_id * input1_strides_data[i];
                    input2_id += output_xyz_id * input2_strides_data[i];
                }

                const bool condition = condition_data[condition_id];
                const _DataType_output input1_elem = input1_data[input1_id];
                const _DataType_output input2_elem = input2_data[input2_id];
                result[output_id] = (condition) ? input1_elem : input2_elem;
            }
        };
        auto kernel_func = [&](sycl::handler& cgh) {
            cgh.depends_on(copy_strides_ev); // strides must be on-device before the kernel runs
            cgh.parallel_for<class dpnp_where_c_strides_kernel<_DataType_output, _DataType_input1, _DataType_input2>>(
                gws, kernel_parallel_for_func);
        };

        // Synchronous: wait here, so no event is returned for this path.
        q.submit(kernel_func).wait();

        sycl::free(dev_strides_data, q);
        return event_ref; // nullptr: kernel already completed
    }
    else
    {
        // Contiguous path: all arrays share the same flat index.
        auto kernel_parallel_for_func = [=](sycl::id<1> global_id) {
            const size_t i = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */

            const bool condition = condition_data[i];
            const _DataType_output input1_elem = input1_data[i];
            const _DataType_output input2_elem = input2_data[i];
            result[i] = (condition) ? input1_elem : input2_elem;
        };
        auto kernel_func = [&](sycl::handler& cgh) {
            cgh.parallel_for<class dpnp_where_c_kernel<_DataType_output, _DataType_input1, _DataType_input2>>(
                gws, kernel_parallel_for_func);
        };
        event = q.submit(kernel_func);
    }

    // `event` is a local; DPCTLEvent_Copy presumably produces a new reference
    // that outlives this frame and is owned by the caller — confirm against
    // the dpctl C API.
    event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
    return DPCTLEvent_Copy(event_ref);

    return event_ref; // NOTE(review): unreachable — the preceding return always exits first.
}

/* Function-pointer alias instantiating dpnp_where_c for the "_EXT" (extra
 * parameters) flavor of the backend API; registered into the function map
 * by func_map_searching_2arg_3type_core below. */
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
DPCTLSyclEventRef (*dpnp_where_ext_c)(DPCTLSyclQueueRef,
                                      void*,
                                      const size_t,
                                      const size_t,
                                      const shape_elem_type*,
                                      const shape_elem_type*,
                                      const void*,
                                      const size_t,
                                      const size_t,
                                      const shape_elem_type*,
                                      const shape_elem_type*,
                                      const void*,
                                      const size_t,
                                      const size_t,
                                      const shape_elem_type*,
                                      const shape_elem_type*,
                                      const void*,
                                      const size_t,
                                      const size_t,
                                      const shape_elem_type*,
                                      const shape_elem_type*,
                                      const DPCTLEventVectorRef) = dpnp_where_c<_DataType_output, _DataType_input1, _DataType_input2>;

/* Registers DPNP_FN_WHERE_EXT entries in the function map for a fixed first
 * input type FT1 against every second input type in FTs... (fold expression).
 * The result type for each (FT1, FTs) pair is derived via populate_func_types. */
template <DPNPFuncType FT1, DPNPFuncType... FTs>
static void func_map_searching_2arg_3type_core(func_map_t& fmap)
{
    ((fmap[DPNPFuncName::DPNP_FN_WHERE_EXT][FT1][FTs] =
          {populate_func_types<FT1, FTs>(),
           (void*)dpnp_where_ext_c<func_type_map_t::find_type<populate_func_types<FT1, FTs>()>,
                                   func_type_map_t::find_type<FT1>,
                                   func_type_map_t::find_type<FTs>>}),
     ...);
}

/* Expands the full FTs x FTs type matrix: for each first-input type in FTs,
 * invokes func_map_searching_2arg_3type_core against all of FTs again. */
template <DPNPFuncType... FTs>
static void func_map_searching_2arg_3type_helper(func_map_t& fmap)
{
    ((func_map_searching_2arg_3type_core<FTs, FTs...>(fmap)), ...);
}

void func_map_init_searching(func_map_t& fmap)
{
fmap[DPNPFuncName::DPNP_FN_ARGMAX][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_argmax_default_c<int32_t, int32_t>};
Expand Down Expand Up @@ -430,7 +177,5 @@ void func_map_init_searching(func_map_t& fmap)
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_DBL][eft_INT] = {eft_INT, (void*)dpnp_argmin_ext_c<double, int32_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_DBL][eft_LNG] = {eft_LNG, (void*)dpnp_argmin_ext_c<double, int64_t>};

func_map_searching_2arg_3type_helper<eft_BLN, eft_INT, eft_LNG, eft_FLT, eft_DBL, eft_C64, eft_C128>(fmap);

return;
}
2 changes: 0 additions & 2 deletions dpnp/dpnp_algo/dpnp_algo.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_VANDER_EXT
DPNP_FN_VAR
DPNP_FN_VAR_EXT
DPNP_FN_WHERE_EXT
DPNP_FN_ZEROS
DPNP_FN_ZEROS_LIKE

Expand Down Expand Up @@ -577,7 +576,6 @@ Searching functions
"""
cpdef dpnp_descriptor dpnp_argmax(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_argmin(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_where(dpnp_descriptor cond_obj, dpnp_descriptor x_obj, dpnp_descriptor y_obj)

"""
Trigonometric functions
Expand Down
Loading