Skip to content

Commit

Permalink
Stop type instantiation during the compilation of hpp files
Browse files Browse the repository at this point in the history
  • Loading branch information
chudur-budur committed Oct 3, 2023
1 parent f62851c commit a604c3b
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 44 deletions.
1 change: 1 addition & 0 deletions numba_dpex/core/runtime/kernels/dispatch_vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,5 @@ class DispatchVectorBuilder
} // namespace dispatch_vector
} // namespace runtime
} // namespace ndpx

#endif
25 changes: 8 additions & 17 deletions numba_dpex/core/runtime/kernels/linear_sequences.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,15 @@
#include "dispatch_vector.hpp"
#include "type_utils.hpp"

// using dpctl::tensor::kernels::constructors::lin_space_step_fn_ptr_t;
// static lin_space_step_fn_ptr_t
// lin_space_step_dispatch_vector[ndpx::runtime::type_utils::num_types];
static ndpx::runtime::tensor::lin_space_step_opaque_ptr_t
lin_space_step_dispatch_vector[ndpx::runtime::type_utils::num_types];

void init_linear_sequences_dispatch_vectors(void)
{
using ndpx::runtime::dispatch_vector::DispatchVectorBuilder;
using ndpx::runtime::type_utils::num_types;
// using dpctl::tensor::kernels::constructors::LinSpaceAffineFactory;
using ndpx::runtime::tensor::LinSpaceStepFactory;

// DispatchVectorBuilder<lin_space_step_fn_ptr_t, LinSpaceStepFactory,
// num_types>
// dvb1;
// dvb1.populate_dispatch_vector(lin_space_step_dispatch_vector);

// DispatchVectorBuilder<lin_space_affine_fn_ptr_t, LinSpaceAffineFactory,
// num_types>
// dvb2;
// dvb2.populate_dispatch_vector(lin_space_affine_dispatch_vector);
ndpx::runtime::dispatch_vector::DispatchVectorBuilder<
ndpx::runtime::tensor::lin_space_step_opaque_ptr_t,
ndpx::runtime::tensor::LinSpaceStepFactory,
ndpx::runtime::type_utils::num_types>
dvb1;
dvb1.populate_dispatch_vector(lin_space_step_dispatch_vector);
}
62 changes: 43 additions & 19 deletions numba_dpex/core/runtime/kernels/linear_sequences.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "type_utils.hpp"
#include <CL/sycl.hpp>
#include <complex>
#include <exception>

namespace ndpx
{
Expand All @@ -13,8 +14,6 @@ namespace tensor
{

template <typename Ty> class linear_sequence_step_kernel;
template <typename Ty, typename wTy> class linear_sequence_affine_kernel;
template <typename Ty> class eye_kernel;

template <typename Ty> class LinearSequenceStepFunctor
{
Expand All @@ -32,7 +31,8 @@ template <typename Ty> class LinearSequenceStepFunctor
void operator()(sycl::id<1> wiid) const
{
auto i = wiid.get(0);
if (ndpxutils::is_complex<Ty>::value) {
bool _is_complex = ndpx::runtime::type_utils::is_complex<Ty>::value;
if (_is_complex) {
p[i] = Ty{start_v.real() + i * step_v.real(),
start_v.imag() + i * step_v.imag()};
}
Expand All @@ -42,21 +42,13 @@ template <typename Ty> class LinearSequenceStepFunctor
}
};

// typedef sycl::event (*lin_space_step_fn_ptr_t)(
// sycl::queue &,
// size_t, // num_elements
// const py::object &start,
// const py::object &step,
// char *, // dst_data_ptr
// const std::vector<sycl::event> &);

template <typename Ty>
sycl::event lin_space_step_impl(sycl::queue exec_q,
size_t nelems,
Ty start_v,
Ty step_v,
char *array_data,
const std::vector<sycl::event> &depends)
sycl::event lin_space_step_specialized(sycl::queue exec_q,
size_t nelems,
Ty start_v,
Ty step_v,
char *array_data,
const std::vector<sycl::event> &depends)
{
ndpx::runtime::type_utils::validate_type_for_device<Ty>(exec_q);
sycl::event lin_space_step_event = exec_q.submit([&](sycl::handler &cgh) {
Expand All @@ -69,18 +61,50 @@ sycl::event lin_space_step_impl(sycl::queue exec_q,
return lin_space_step_event;
}

extern void init_linear_sequences_dispatch_vectors(void);
template <typename Ty>
sycl::event lin_space_step_opaque(sycl::queue &exec_q,
size_t nelems,
void *start,
void *step,
char *array_data,
const std::vector<sycl::event> &depends)
{
Ty *start_v;
Ty *step_v;
try {
start_v = reinterpret_cast<Ty *>(start);
step_v = reinterpret_cast<Ty *>(step);
} catch (const std::exception &e) {
std::cerr << e.what() << std::endl;
}

auto lin_space_step_event = lin_space_step_specialized<Ty>(
exec_q, nelems, *start_v, *step_v, array_data, depends);

return lin_space_step_event;
}

template <typename fnT, typename Ty> struct LinSpaceStepFactory
{
fnT get()
{
fnT f = lin_space_step_impl<Ty>;
fnT f = lin_space_step_opaque<Ty>;
return f;
}
};

typedef sycl::event (*lin_space_step_opaque_ptr_t)(
sycl::queue &,
size_t, // num_elements
void *, // start_v
void *, // end_v
char *, // dst_data_ptr
const std::vector<sycl::event> &);

extern void init_linear_sequences_dispatch_vectors(void);

} // namespace tensor
} // namespace runtime
} // namespace ndpx

#endif
17 changes: 9 additions & 8 deletions numba_dpex/core/runtime/kernels/type_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ namespace runtime
namespace type_utils
{

template <class T> struct is_complex : public std::false_type
{
};

template <class T> struct is_complex<std::complex<T>> : public std::true_type
{
};

enum class typenum_t : int
{
BOOL = 0,
Expand All @@ -30,15 +38,8 @@ enum class typenum_t : int
CFLOAT,
CDOUBLE, // 13
};
constexpr int num_types = 14; // number of elements in typenum_t

template <class T> struct is_complex : public std::false_type
{
};

template <class T> struct is_complex<std::complex<T>> : public std::true_type
{
};
constexpr int num_types = 14; // number of elements in typenum_t

template <typename dstTy, typename srcTy> dstTy convert_impl(const srcTy &v)
{
Expand Down

0 comments on commit a604c3b

Please sign in to comment.