Skip to content

Commit

Permalink
Added affine sequence step functions
Browse files Browse the repository at this point in the history
  • Loading branch information
chudur-budur committed Oct 5, 2023
1 parent 07a5c8e commit 70f4ed9
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 23 deletions.
23 changes: 18 additions & 5 deletions numba_dpex/core/runtime/kernels/sequences.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,29 @@
#include "dispatch.hpp"
#include "types.hpp"

// Table of sequence_step entry points with `num_types` slots; filled in by
// init_sequence_dispatch_vectors().
static ndpx::runtime::kernel::tensor::sequence_step_ptr_t
    sequence_step_dispatch_vector[ndpx::runtime::kernel::types::num_types];

// Table of affine_sequence_step entry points with `num_types` slots; filled
// in by init_affine_sequence_dispatch_vectors().
static ndpx::runtime::kernel::tensor::affine_sequence_step_ptr_t
    affine_sequence_step_dispatch_vector
        [ndpx::runtime::kernel::types::num_types];

// Fills sequence_step_dispatch_vector with the function pointers produced by
// SequenceStepFactory.
void init_sequence_dispatch_vectors(void)
{
    ndpx::runtime::kernel::dispatch::DispatchVectorBuilder<
        ndpx::runtime::kernel::tensor::sequence_step_ptr_t,
        ndpx::runtime::kernel::tensor::SequenceStepFactory,
        ndpx::runtime::kernel::types::num_types>
        dvb;
    dvb.populate_dispatch_vector(sequence_step_dispatch_vector);
}

void init_affine_sequence_dispatch_vectors(void)
{
ndpx::runtime::kernel::dispatch::DispatchVectorBuilder<
ndpx::runtime::kernel::tensor::affine_sequence_step_ptr_t,
ndpx::runtime::kernel::tensor::AffineSequenceStepFactory,
ndpx::runtime::kernel::types::num_types>
dvb;
dvb.populate_dispatch_vector(affine_sequence_step_dispatch_vector);
}
161 changes: 143 additions & 18 deletions numba_dpex/core/runtime/kernels/sequences.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ namespace kernel
namespace tensor
{

// Forward-declared class templates used as kernel names in parallel_for.
template <typename Ty> class ndpx_sequence_step_kernel;
// Takes the weight type as a second parameter because the launcher names its
// kernels <Ty, double> and <Ty, float>.
template <typename Ty, typename wTy> class ndpx_affine_sequence_step_kernel;

template <typename Ty> class SequenceStepFunctor
{
Expand Down Expand Up @@ -44,18 +45,69 @@ template <typename Ty> class SequenceStepFunctor
}
};

template <typename Ty, typename wTy> class AffineSequenceStepFunctor
{
private:
Ty *p = nullptr;
Ty start_v;
Ty end_v;
size_t n;

public:
AffineSequenceStepFunctor(char *dst_p, Ty v0, Ty v1, size_t den)
: p(reinterpret_cast<Ty *>(dst_p)), start_v(v0), end_v(v1),
n((den == 0) ? 1 : den)
{
}

void operator()(sycl::id<1> wiid) const
{
auto i = wiid.get(0);
wTy wc = wTy(i) / n;
wTy w = wTy(n - i) / n;
if constexpr (ndpx::runtime::kernel::types::is_complex<Ty>::value) {
using reT = typename Ty::value_type;
auto _w = static_cast<reT>(w);
auto _wc = static_cast<reT>(wc);
auto re_comb = sycl::fma(start_v.real(), _w, reT(0));
re_comb =
sycl::fma(end_v.real(), _wc,
re_comb); // start_v.real() * _w + end_v.real() * _wc;
auto im_comb =
sycl::fma(start_v.imag(), _w,
reT(0)); // start_v.imag() * _w + end_v.imag() * _wc;
im_comb = sycl::fma(end_v.imag(), _wc, im_comb);
Ty affine_comb = Ty{re_comb, im_comb};
p[i] = affine_comb;
}
else if constexpr (std::is_floating_point<Ty>::value) {
Ty _w = static_cast<Ty>(w);
Ty _wc = static_cast<Ty>(wc);
auto affine_comb =
sycl::fma(start_v, _w, Ty(0)); // start_v * w + end_v * wc;
affine_comb = sycl::fma(end_v, _wc, affine_comb);
p[i] = affine_comb;
}
else {
auto affine_comb = start_v * w + end_v * wc;
p[i] = ndpx::runtime::kernel::types::convert_impl<
Ty, decltype(affine_comb)>(affine_comb);
}
}
};

template <typename Ty>
sycl::event sequence_step_specialized(sycl::queue exec_q,
size_t nelems,
Ty start_v,
Ty step_v,
char *array_data,
const std::vector<sycl::event> &depends)
sycl::event sequence_step_kernel(sycl::queue exec_q,
size_t nelems,
Ty start_v,
Ty step_v,
char *array_data,
const std::vector<sycl::event> &depends)
{
ndpx::runtime::kernel::types::validate_type_for_device<Ty>(exec_q);
sycl::event seq_step_event = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);
cgh.parallel_for<sequence_step_kernel<Ty>>(
cgh.parallel_for<ndpx_sequence_step_kernel<Ty>>(
sycl::range<1>{nelems},
SequenceStepFunctor<Ty>(array_data, start_v, step_v));
});
Expand All @@ -64,46 +116,119 @@ sycl::event sequence_step_specialized(sycl::queue exec_q,
}

/// Submits the kernel that fills `array_data` with `nelems` values of type Ty
/// interpolated between `start_v` and `end_v`.
///
/// The interpolation weights are computed in double when the queue's device
/// reports sycl::aspect::fp64, and in float otherwise.
///
/// @param exec_q           queue the command group is submitted to
/// @param nelems           number of elements; also the size of the launch
///                         range
/// @param start_v          value written at index 0
/// @param end_v            value the sequence runs towards
/// @param include_endpoint when true the weight denominator is nelems - 1,
///                         otherwise nelems
/// @param array_data       destination buffer
/// @param depends          events the kernel must wait for
/// @return the event returned by the queue submission
template <typename Ty>
sycl::event affine_sequence_step_kernel(sycl::queue &exec_q,
                                        size_t nelems,
                                        Ty start_v,
                                        Ty end_v,
                                        bool include_endpoint,
                                        char *array_data,
                                        const std::vector<sycl::event> &depends)
{
    ndpx::runtime::kernel::types::validate_type_for_device<Ty>(exec_q);
    const bool device_supports_doubles =
        exec_q.get_device().has(sycl::aspect::fp64);
    // Identical for both weight types, so compute it once. The functor itself
    // replaces a zero denominator with 1.
    const size_t den = include_endpoint ? nelems - 1 : nelems;

    sycl::event affine_seq_step_event = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);
        if (device_supports_doubles) {
            cgh.parallel_for<ndpx_affine_sequence_step_kernel<Ty, double>>(
                sycl::range<1>{nelems},
                AffineSequenceStepFunctor<Ty, double>(array_data, start_v,
                                                      end_v, den));
        }
        else {
            cgh.parallel_for<ndpx_affine_sequence_step_kernel<Ty, float>>(
                sycl::range<1>{nelems},
                AffineSequenceStepFunctor<Ty, float>(array_data, start_v,
                                                     end_v, den));
        }
    });

    return affine_seq_step_event;
}

/// Type-erased entry point for the sequence-step kernel: reads the start and
/// step values through the untyped pointers and forwards to
/// sequence_step_kernel<Ty>.
///
/// @param exec_q     queue the kernel is submitted to
/// @param nelems     number of elements to generate
/// @param start      pointer to a Ty holding the start value; dereferenced
/// @param step       pointer to a Ty holding the step value; dereferenced
/// @param array_data destination buffer
/// @param depends    events the kernel must wait for
/// @return the event returned by sequence_step_kernel<Ty>
template <typename Ty>
sycl::event sequence_step(sycl::queue &exec_q,
                          size_t nelems,
                          void *start,
                          void *step,
                          char *array_data,
                          const std::vector<sycl::event> &depends)
{
    // A pointer cast cannot throw, so no exception handling is needed here.
    const Ty start_v = *static_cast<const Ty *>(start);
    const Ty step_v = *static_cast<const Ty *>(step);

    return sequence_step_kernel<Ty>(exec_q, nelems, start_v, step_v,
                                    array_data, depends);
}

/// Type-erased entry point for the affine sequence kernel: reads the start
/// and end values through the untyped pointers and forwards to
/// affine_sequence_step_kernel<Ty>.
///
/// @param exec_q           queue the kernel is submitted to
/// @param nelems           number of elements to generate
/// @param start            pointer to a Ty holding the start value;
///                         dereferenced
/// @param end              pointer to a Ty holding the end value;
///                         dereferenced
/// @param include_endpoint forwarded unchanged to the kernel launcher
/// @param array_data       destination buffer
/// @param depends          events the kernel must wait for
/// @return the event returned by affine_sequence_step_kernel<Ty>
template <typename Ty>
sycl::event affine_sequence_step(sycl::queue &exec_q,
                                 size_t nelems,
                                 void *start,
                                 void *end,
                                 bool include_endpoint,
                                 char *array_data,
                                 const std::vector<sycl::event> &depends)
{
    // A pointer cast cannot throw, so no exception handling is needed here.
    const Ty start_v = *static_cast<const Ty *>(start);
    const Ty end_v = *static_cast<const Ty *>(end);

    return affine_sequence_step_kernel<Ty>(exec_q, nelems, start_v, end_v,
                                           include_endpoint, array_data,
                                           depends);
}

/// Factory whose get() returns the sequence_step<Ty> instantiation converted
/// to the function type fnT.
template <typename fnT, typename Ty> struct SequenceStepFactory
{
    fnT get()
    {
        fnT f = sequence_step<Ty>;
        return f;
    }
};

/// Factory whose get() returns the affine_sequence_step<Ty> instantiation
/// converted to the function type fnT.
template <typename fnT, typename Ty> struct AffineSequenceStepFactory
{
    fnT get()
    {
        return affine_sequence_step<Ty>;
    }
};

/// Function-pointer type matching the sequence_step<Ty> instantiations.
typedef sycl::event (*sequence_step_ptr_t)(sycl::queue &,
                                           size_t, // num_elements
                                           void *, // start_v
                                           void *, // step_v
                                           char *, // dst_data_ptr
                                           const std::vector<sycl::event> &);

/// Function-pointer type matching the affine_sequence_step<Ty>
/// instantiations.
using affine_sequence_step_ptr_t =
    sycl::event (*)(sycl::queue &,
                    size_t, // num_elements
                    void *, // start_v
                    void *, // end_v
                    bool,   // include_endpoint
                    char *, // dst_data_ptr
                    const std::vector<sycl::event> &);

// Initialisers that populate the sequence and affine-sequence dispatch
// tables.
// NOTE(review): these are declared inside the enclosing namespaces, while the
// definitions visible in sequences.cpp appear to sit at global scope --
// confirm both refer to the same functions.
extern void init_sequence_dispatch_vectors(void);
extern void init_affine_sequence_dispatch_vectors(void);

} // namespace tensor
} // namespace kernel
Expand Down

0 comments on commit 70f4ed9

Please sign in to comment.