Skip to content

Commit

Permalink
Fixed SequenceStepFunctor
Browse files Browse the repository at this point in the history
  • Loading branch information
chudur-budur committed Oct 4, 2023
1 parent ee71575 commit db14130
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 187 deletions.
60 changes: 60 additions & 0 deletions numba_dpex/core/runtime/kernels/dispatch.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#ifndef __DISPATCH_HPP__
#define __DISPATCH_HPP__

namespace ndpx
{
namespace runtime
{
namespace kernel
{
namespace dispatch
{

template <typename funcPtrT,
template <typename fnT, typename T>
typename factory,
int _num_types>
class DispatchVectorBuilder
{
private:
template <typename Ty> const funcPtrT func_per_type() const
{
funcPtrT f = factory<funcPtrT, Ty>{}.get();
return f;
}

public:
DispatchVectorBuilder() = default;
~DispatchVectorBuilder() = default;

void populate_dispatch_vector(funcPtrT vector[]) const
{
const auto fn_map_by_type = {func_per_type<bool>(),
func_per_type<int8_t>(),
func_per_type<uint8_t>(),
func_per_type<int16_t>(),
func_per_type<uint16_t>(),
func_per_type<int32_t>(),
func_per_type<uint32_t>(),
func_per_type<int64_t>(),
func_per_type<uint64_t>(),
func_per_type<sycl::half>(),
func_per_type<float>(),
func_per_type<double>(),
func_per_type<std::complex<float>>(),
func_per_type<std::complex<double>>()};
assert(fn_map_by_type.size() == _num_types);
int ty_id = 0;
for (auto &fn : fn_map_by_type) {
vector[ty_id] = fn;
++ty_id;
}
}
};

} // namespace dispatch
} // namespace kernel
} // namespace runtime
} // namespace ndpx

#endif
57 changes: 0 additions & 57 deletions numba_dpex/core/runtime/kernels/dispatch_vector.hpp

This file was deleted.

16 changes: 0 additions & 16 deletions numba_dpex/core/runtime/kernels/linear_sequences.cpp

This file was deleted.

110 changes: 0 additions & 110 deletions numba_dpex/core/runtime/kernels/linear_sequences.hpp

This file was deleted.

17 changes: 17 additions & 0 deletions numba_dpex/core/runtime/kernels/sequences.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include "sequences.hpp"
#include "dispatch.hpp"
#include "types.hpp"

static ndpx::runtime::kernel::tensor::sequence_step_opaque_ptr_t
sequence_step_dispatch_vector[ndpx::runtime::kernel::types::num_types];

void init_sequence_dispatch_vectors(void)
{

ndpx::runtime::kernel::dispatch::DispatchVectorBuilder<
ndpx::runtime::kernel::tensor::sequence_step_opaque_ptr_t,
ndpx::runtime::kernel::tensor::SequenceStepFactory,
ndpx::runtime::kernel::types::num_types>
dvb1;
dvb1.populate_dispatch_vector(sequence_step_dispatch_vector);
}
113 changes: 113 additions & 0 deletions numba_dpex/core/runtime/kernels/sequences.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#ifndef __SEQUENCES_HPP__
#define __SEQUENCES_HPP__

#include "types.hpp"
#include <CL/sycl.hpp>
#include <complex>
#include <exception>
#include <iostream>

namespace ndpx
{
namespace runtime
{
namespace kernel
{
namespace tensor
{

template <typename Ty> class sequence_step_kernel;

template <typename Ty> class SequenceStepFunctor
{
private:
Ty *p = nullptr;
Ty start_v;
Ty step_v;

public:
SequenceStepFunctor(char *dst_p, Ty v0, Ty dv)
: p(reinterpret_cast<Ty *>(dst_p)), start_v(v0), step_v(dv)
{
}

void operator()(sycl::id<1> wiid) const
{
auto i = wiid.get(0);
if constexpr (ndpx::runtime::kernel::types::is_complex<Ty>::value) {
p[i] = Ty{start_v.real() + i * step_v.real(),
start_v.imag() + i * step_v.imag()};
}
else {
p[i] = start_v + i * step_v;
}
}
};

template <typename Ty>
sycl::event sequence_step_specialized(sycl::queue exec_q,
size_t nelems,
Ty start_v,
Ty step_v,
char *array_data,
const std::vector<sycl::event> &depends)
{
ndpx::runtime::kernel::types::validate_type_for_device<Ty>(exec_q);
sycl::event seq_step_event = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);
cgh.parallel_for<sequence_step_kernel<Ty>>(
sycl::range<1>{nelems},
SequenceStepFunctor<Ty>(array_data, start_v, step_v));
});

return seq_step_event;
}

template <typename Ty>
sycl::event sequence_step_opaque(sycl::queue &exec_q,
size_t nelems,
void *start,
void *step,
char *array_data,
const std::vector<sycl::event> &depends)
{
Ty *start_v;
Ty *step_v;
try {
start_v = reinterpret_cast<Ty *>(start);
step_v = reinterpret_cast<Ty *>(step);
} catch (const std::exception &e) {
std::cerr << e.what() << std::endl;
}

auto sequence_step_event = sequence_step_specialized<Ty>(
exec_q, nelems, *start_v, *step_v, array_data, depends);

return sequence_step_event;
}

template <typename fnT, typename Ty> struct SequenceStepFactory
{
fnT get()
{
fnT f = sequence_step_opaque<Ty>;
return f;
}
};

typedef sycl::event (*sequence_step_opaque_ptr_t)(
sycl::queue &,
size_t, // num_elements
void *, // start_v
void *, // end_v
char *, // dst_data_ptr
const std::vector<sycl::event> &);

extern void init_sequence_dispatch_vectors(void);

} // namespace tensor
} // namespace kernel
} // namespace runtime
} // namespace ndpx

#endif
Loading

0 comments on commit db14130

Please sign in to comment.