Skip to content

Commit

Permalink
Merge pull request #1454 from IntelPython/compile-for-cuda
Browse files Browse the repository at this point in the history
Enable compiling for cuda
  • Loading branch information
oleksandr-pavlyk authored Oct 25, 2023
2 parents 71b85ab + 986dc6f commit 479a969
Show file tree
Hide file tree
Showing 26 changed files with 1,732 additions and 31 deletions.
21 changes: 21 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,30 @@ option(DPCTL_GENERATE_COVERAGE
"Build dpctl with coverage instrumentation"
OFF
)
option(DPCTL_TARGET_CUDA
"Build DPCTL to target CUDA devices"
OFF
)

find_package(IntelSYCL REQUIRED PATHS ${CMAKE_SOURCE_DIR}/cmake NO_DEFAULT_PATH)

set(_dpctl_sycl_targets)
if ("x${DPCTL_SYCL_TARGETS}" STREQUAL "x")
if(DPCTL_TARGET_CUDA)
set(_dpctl_sycl_targets "nvptx64-nvidia-cuda,spir64-unknown-unknown")
else()
if(DEFINED ENV{DPCTL_TARGET_CUDA})
set(_dpctl_sycl_targets "nvptx64-nvidia-cuda,spir64-unknown-unknown")
endif()
endif()
else()
set(_dpctl_sycl_targets ${DPCTL_SYCL_TARGETS})
endif()

if(_dpctl_sycl_targets)
message(STATUS "Compiling for -fsycl-targets=${_dpctl_sycl_targets}")
endif()

add_subdirectory(libsyclinterface)

file(GLOB _dpctl_capi_headers dpctl/apis/include/*.h*)
Expand Down
15 changes: 14 additions & 1 deletion dpctl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,20 @@ function(build_dpctl_ext _trgt _src _dest)
add_custom_target(${_cythonize_trgt} DEPENDS ${_src})
Python_add_library(${_trgt} MODULE WITH_SOABI ${_generated_src})
if (BUILD_DPCTL_EXT_SYCL)
add_sycl_to_target(TARGET ${_trgt} SOURCES ${_generated_src})
add_sycl_to_target(TARGET ${_trgt} SOURCES ${_generated_src})
if(_dpctl_sycl_targets)
# make fat binary
target_compile_options(
${_trgt}
PRIVATE
-fsycl-targets=${_dpctl_sycl_targets}
)
target_link_options(
${_trgt}
PRIVATE
-fsycl-targets=${_dpctl_sycl_targets}
)
endif()
endif()
target_include_directories(${_trgt} PRIVATE ${NumPy_INCLUDE_DIR} ${DPCTL_INCLUDE_DIR})
add_dependencies(${_trgt} _build_time_create_dpctl_include_copy ${_cythonize_trgt})
Expand Down
14 changes: 14 additions & 0 deletions dpctl/tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,20 @@ set(_tensor_impl_sources
set(python_module_name _tensor_impl)
pybind11_add_module(${python_module_name} MODULE ${_tensor_impl_sources})
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_impl_sources})
if(_dpctl_sycl_targets)
# make fat binary
target_compile_options(
${python_module_name}
PRIVATE
-fsycl-targets=${_dpctl_sycl_targets}
)
target_link_options(
${python_module_name}
PRIVATE
-fsycl-targets=${_dpctl_sycl_targets}
)
endif()

set(_clang_prefix "")
if (WIN32)
set(_clang_prefix "/clang:")
Expand Down
8 changes: 8 additions & 0 deletions dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,16 @@
bitwise_or,
bitwise_right_shift,
bitwise_xor,
cbrt,
ceil,
conj,
copysign,
cos,
cosh,
divide,
equal,
exp,
exp2,
expm1,
floor,
floor_divide,
Expand Down Expand Up @@ -149,6 +152,7 @@
real,
remainder,
round,
rsqrt,
sign,
signbit,
sin,
Expand Down Expand Up @@ -314,4 +318,8 @@
"argmax",
"argmin",
"prod",
"cbrt",
"exp2",
"copysign",
"rsqrt",
]
113 changes: 113 additions & 0 deletions dpctl/tensor/_elementwise_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1761,3 +1761,116 @@
hypot = BinaryElementwiseFunc(
"hypot", ti._hypot_result_type, ti._hypot, _hypot_docstring_
)


# U37: ==== CBRT (x)
_cbrt_docstring_ = """
cbrt(x, out=None, order='K')
Computes positive cube-root for each element `x_i` for input array `x`.
Args:
x (usm_ndarray):
Input array, expected to have a real floating-point data type.
out ({None, usm_ndarray}, optional):
Output array to populate.
Array have the correct shape and the expected data type.
order ("C","F","A","K", optional):
Memory layout of the newly output array, if parameter `out` is `None`.
Default: "K".
Returns:
usm_narray:
An array containing the element-wise positive cube-root.
The data type of the returned array is determined by
the Type Promotion Rules.
"""

cbrt = UnaryElementwiseFunc(
"cbrt", ti._cbrt_result_type, ti._cbrt, _cbrt_docstring_
)


# U38: ==== EXP2 (x)
_exp2_docstring_ = """
exp2(x, out=None, order='K')
Computes the base-2 exponential for each element `x_i` for input array `x`.
Args:
x (usm_ndarray):
Input array, expected to have a floating-point data type.
out ({None, usm_ndarray}, optional):
Output array to populate.
Array have the correct shape and the expected data type.
order ("C","F","A","K", optional):
Memory layout of the newly output array, if parameter `out` is `None`.
Default: "K".
Returns:
usm_narray:
An array containing the element-wise base-2 exponentials.
The data type of the returned array is determined by
the Type Promotion Rules.
"""

exp2 = UnaryElementwiseFunc(
"exp2", ti._exp2_result_type, ti._exp2, _exp2_docstring_
)


# B25: ==== COPYSIGN (x1, x2)
_copysign_docstring_ = """
copysign(x1, x2, out=None, order='K')
Composes a floating-point value with the magnitude of `x1_i` and the sign of
`x2_i` for each element of input arrays `x1` and `x2`.
Args:
x1 (usm_ndarray):
First input array, expected to have a real floating-point data type.
x2 (usm_ndarray):
Second input array, also expected to have a real floating-point data
type.
out ({None, usm_ndarray}, optional):
Output array to populate.
Array have the correct shape and the expected data type.
order ("C","F","A","K", optional):
Memory layout of the newly output array, if parameter `out` is `None`.
Default: "K".
Returns:
usm_narray:
An array containing the element-wise results. The data type
of the returned array is determined by the Type Promotion Rules.
"""
copysign = BinaryElementwiseFunc(
"copysign",
ti._copysign_result_type,
ti._copysign,
_copysign_docstring_,
)


# U39: ==== RSQRT (x)
_rsqrt_docstring_ = """
rsqrt(x, out=None, order='K')
Computes the reciprocal square-root for each element `x_i` for input array `x`.
Args:
x (usm_ndarray):
Input array, expected to have a real floating-point data type.
out ({None, usm_ndarray}, optional):
Output array to populate.
Array have the correct shape and the expected data type.
order ("C","F","A","K", optional):
Memory layout of the newly output array, if parameter `out` is `None`.
Default: "K".
Returns:
usm_narray:
An array containing the element-wise reciprocal square-root.
The data type of the returned array is determined by
the Type Promotion Rules.
"""

rsqrt = UnaryElementwiseFunc(
"rsqrt", ti._rsqrt_result_type, ti._rsqrt, _rsqrt_docstring_
)
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,12 @@ template <typename argT, typename resT> struct AcosFunctor
constexpr realT r_eps =
realT(1) / std::numeric_limits<realT>::epsilon();
if (std::abs(x) > r_eps || std::abs(y) > r_eps) {
argT log_in = std::log(in);
using sycl_complexT = exprm_ns::complex<realT>;
sycl_complexT log_in =
exprm_ns::log(exprm_ns::complex<realT>(in));

const realT wx = std::real(log_in);
const realT wy = std::imag(log_in);
const realT wx = log_in.real();
const realT wy = log_in.imag();
const realT rx = std::abs(wy);

realT ry = wx + std::log(realT(2));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ namespace acosh

namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;
namespace cmplx_ns = sycl::ext::oneapi::experimental;
namespace exprm_ns = sycl::ext::oneapi::experimental;

using dpctl::tensor::type_utils::is_complex;

Expand Down Expand Up @@ -112,16 +112,18 @@ template <typename argT, typename resT> struct AcoshFunctor
* For large x or y including acos(+-Inf + I*+-Inf)
*/
if (std::abs(x) > r_eps || std::abs(y) > r_eps) {
const realT wx = std::real(std::log(in));
const realT wy = std::imag(std::log(in));
using sycl_complexT = typename exprm_ns::complex<realT>;
const sycl_complexT log_in = exprm_ns::log(sycl_complexT(in));
const realT wx = log_in.real();
const realT wy = log_in.imag();
const realT rx = std::abs(wy);
realT ry = wx + std::log(realT(2));
acos_in = resT{rx, (std::signbit(y)) ? ry : -ry};
}
else {
/* ordinary cases */
acos_in = cmplx_ns::acos(
cmplx_ns::complex<realT>(in)); // std::acos(in);
acos_in = exprm_ns::acos(
exprm_ns::complex<realT>(in)); // std::acos(in);
}

/* Now we calculate acosh(z) */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,17 +119,18 @@ template <typename argT, typename resT> struct AsinFunctor
constexpr realT r_eps =
realT(1) / std::numeric_limits<realT>::epsilon();
if (std::abs(x) > r_eps || std::abs(y) > r_eps) {
const resT z = {x, y};
using sycl_complexT = exprm_ns::complex<realT>;
const sycl_complexT z{x, y};
realT wx, wy;
if (!std::signbit(x)) {
auto log_z = std::log(z);
wx = std::real(log_z) + std::log(realT(2));
wy = std::imag(log_z);
auto log_z = exprm_ns::log(z);
wx = log_z.real() + std::log(realT(2));
wy = log_z.imag();
}
else {
auto log_mz = std::log(-z);
wx = std::real(log_mz) + std::log(realT(2));
wy = std::imag(log_mz);
auto log_mz = exprm_ns::log(-z);
wx = log_mz.real() + std::log(realT(2));
wy = log_mz.imag();
}
const realT asinh_re = std::copysign(wx, x);
const realT asinh_im = std::copysign(wy, y);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,12 @@ template <typename argT, typename resT> struct AsinhFunctor
realT(1) / std::numeric_limits<realT>::epsilon();

if (std::abs(x) > r_eps || std::abs(y) > r_eps) {
resT log_in = (std::signbit(x)) ? std::log(-in) : std::log(in);
realT wx = std::real(log_in) + std::log(realT(2));
realT wy = std::imag(log_in);
using sycl_complexT = exprm_ns::complex<realT>;
sycl_complexT log_in = (std::signbit(x))
? exprm_ns::log(sycl_complexT(-in))
: exprm_ns::log(sycl_complexT(in));
realT wx = log_in.real() + std::log(realT(2));
realT wy = log_in.imag();
const realT res_re = std::copysign(wx, x);
const realT res_im = std::copysign(wy, y);
return resT{res_re, res_im};
Expand Down
Loading

0 comments on commit 479a969

Please sign in to comment.