Skip to content

Commit

Permalink
Porting dispatch vector into numba-dpex, working on building generic …
Browse files Browse the repository at this point in the history
…pointer
  • Loading branch information
chudur-budur committed Oct 3, 2023
1 parent 5924b4f commit 6173e73
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 7 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ project(numba-dpex
VERSION ${NUMBA_DPEX_VERSION}
)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED True)
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH)
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)

if(IS_INSTALL)
install(DIRECTORY numba_dpex
DESTINATION ${CMAKE_INSTALL_PREFIX}
Expand Down
8 changes: 5 additions & 3 deletions numba_dpex/core/runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ message(STATUS "CMAKE_MODULE_PATH=" "${CMAKE_MODULE_PATH}")

# Add packages
find_package(Python 3.9 REQUIRED
COMPONENTS Interpreter Development.Module NumPy)
COMPONENTS Interpreter Development.Module)
find_package(Dpctl REQUIRED)
find_package(NumPy REQUIRED)

Expand All @@ -89,10 +89,12 @@ include_directories(${Python_INCLUDE_DIRS})
include_directories(${NumPy_INCLUDE_DIRS})
include_directories(${Numba_INCLUDE_DIRS})
include_directories(${Dpctl_INCLUDE_DIRS})
include_directories(.)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/kernels)

# Source files, *.c
file(GLOB SOURCES "*.c")
# file(GLOB SOURCES "*.c")
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS "*.c" "*.cpp")

# Link dpctl library path with -L
link_directories(${DPCTL_LIBRARY_PATH})
Expand Down
56 changes: 56 additions & 0 deletions numba_dpex/core/runtime/kernels/dispatch_vector.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#ifndef __DISPATCH_VECTOR_HPP__
#define __DISPATCH_VECTOR_HPP__

namespace ndpx
{
namespace runtime
{
namespace dispatch_vector
{

template <typename funcPtrT,
template <typename fnT, typename T>
typename factory,
int _num_types>
class DispatchVectorBuilder
{
private:
template <typename Ty> const funcPtrT func_per_type() const
{
funcPtrT f = factory<funcPtrT, Ty>{}.get();
return f;
}

public:
DispatchVectorBuilder() = default;
~DispatchVectorBuilder() = default;

void populate_dispatch_vector(funcPtrT vector[]) const
{
const auto fn_map_by_type = {func_per_type<bool>(),
func_per_type<int8_t>(),
func_per_type<uint8_t>(),
func_per_type<int16_t>(),
func_per_type<uint16_t>(),
func_per_type<int32_t>(),
func_per_type<uint32_t>(),
func_per_type<int64_t>(),
func_per_type<uint64_t>(),
func_per_type<sycl::half>(),
func_per_type<float>(),
func_per_type<double>(),
func_per_type<std::complex<float>>(),
func_per_type<std::complex<double>>()};
assert(fn_map_by_type.size() == _num_types);
int ty_id = 0;
for (auto &fn : fn_map_by_type) {
vector[ty_id] = fn;
++ty_id;
}
}
};

} // namespace dispatch_vector
} // namespace runtime
} // namespace ndpx
#endif
24 changes: 24 additions & 0 deletions numba_dpex/core/runtime/kernels/linear_sequences.cpp
Original file line number Diff line number Diff line change
@@ -1 +1,25 @@
#include "linear_sequences.hpp"
#include "dispatch_vector.hpp"
#include "type_utils.hpp"

// using dpctl::tensor::kernels::constructors::lin_space_step_fn_ptr_t;
// static lin_space_step_fn_ptr_t
// lin_space_step_dispatch_vector[ndpx::runtime::type_utils::num_types];

void init_linear_sequences_dispatch_vectors(void)
{
using ndpx::runtime::dispatch_vector::DispatchVectorBuilder;
using ndpx::runtime::type_utils::num_types;
// using dpctl::tensor::kernels::constructors::LinSpaceAffineFactory;
using ndpx::runtime::tensor::LinSpaceStepFactory;

// DispatchVectorBuilder<lin_space_step_fn_ptr_t, LinSpaceStepFactory,
// num_types>
// dvb1;
// dvb1.populate_dispatch_vector(lin_space_step_dispatch_vector);

// DispatchVectorBuilder<lin_space_affine_fn_ptr_t, LinSpaceAffineFactory,
// num_types>
// dvb2;
// dvb2.populate_dispatch_vector(lin_space_affine_dispatch_vector);
}
31 changes: 29 additions & 2 deletions numba_dpex/core/runtime/kernels/linear_sequences.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
#include <CL/sycl.hpp>
#include <complex>

namespace ndpxutils = ndpx::runtime::utils;
namespace ndpx
{
namespace runtime
{
namespace tensor
{

template <typename Ty> class linear_sequence_step_kernel;
template <typename Ty, typename wTy> class linear_sequence_affine_kernel;
Expand Down Expand Up @@ -37,6 +42,14 @@ template <typename Ty> class LinearSequenceStepFunctor
}
};

// typedef sycl::event (*lin_space_step_fn_ptr_t)(
// sycl::queue &,
// size_t, // num_elements
// const py::object &start,
// const py::object &step,
// char *, // dst_data_ptr
// const std::vector<sycl::event> &);

template <typename Ty>
sycl::event lin_space_step_impl(sycl::queue exec_q,
size_t nelems,
Expand All @@ -45,7 +58,7 @@ sycl::event lin_space_step_impl(sycl::queue exec_q,
char *array_data,
const std::vector<sycl::event> &depends)
{
ndpxutils::validate_type_for_device<Ty>(exec_q);
ndpx::runtime::type_utils::validate_type_for_device<Ty>(exec_q);
sycl::event lin_space_step_event = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);
cgh.parallel_for<linear_sequence_step_kernel<Ty>>(
Expand All @@ -56,4 +69,18 @@ sycl::event lin_space_step_impl(sycl::queue exec_q,
return lin_space_step_event;
}

extern void init_linear_sequences_dispatch_vectors(void);

template <typename fnT, typename Ty> struct LinSpaceStepFactory
{
fnT get()
{
fnT f = lin_space_step_impl<Ty>;
return f;
}
};

} // namespace tensor
} // namespace runtime
} // namespace ndpx
#endif
23 changes: 21 additions & 2 deletions numba_dpex/core/runtime/kernels/type_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,28 @@ namespace ndpx
{
namespace runtime
{
namespace utils
namespace type_utils
{

enum class typenum_t : int
{
BOOL = 0,
INT8, // 1
UINT8,
INT16,
UINT16,
INT32, // 5
UINT32,
INT64,
UINT64,
HALF,
FLOAT, // 10
DOUBLE,
CFLOAT,
CDOUBLE, // 13
};
constexpr int num_types = 14; // number of elements in typenum_t

template <class T> struct is_complex : public std::false_type
{
};
Expand Down Expand Up @@ -101,7 +120,7 @@ auto vec_cast(const sycl::vec<srcT, N> &s)
}
}

} // namespace utils
} // namespace type_utils
} // namespace runtime
} // namespace ndpx

Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,7 @@ def to_cmake_format(version: str):
),
"-DIS_INSTALL:BOOL={0:s}".format("TRUE" if is_install else "FALSE"),
"-DIS_DEVELOP:BOOL={0:s}".format("TRUE" if is_develop else "FALSE"),
"-DCMAKE_C_COMPILER=icx",
"-DCMAKE_CXX_COMPILER=icpx",
],
)

0 comments on commit 6173e73

Please sign in to comment.