Skip to content

Commit

Permalink
Allow for change of name and location of sycl_complex.hpp
Browse files Browse the repository at this point in the history
Introduced private header to load SYCL's experimental complex header
from the right location. The header and implementations respond to
USE_SYCL_FOR_COMPLEX_TYPES preprocessor variable. If set, sycl::ext::oneapi::experimental
namespace functions are to be used. Otherwise std:: namespace functions will
be used instead for complex types.

USE_SYCL_FOR_COMPLEX_TYPES is being set in tensor/CMakeLists.txt
  • Loading branch information
oleksandr-pavlyk committed Nov 18, 2023
1 parent 6d3be5d commit 0f76890
Show file tree
Hide file tree
Showing 29 changed files with 228 additions and 69 deletions.
4 changes: 2 additions & 2 deletions dpctl/tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,9 @@ foreach(_src_fn ${_no_fast_math_sources})
)
endforeach()
if (UNIX)
set(_compiler_definitions "USE_STD_ABS_FOR_COMPLEX_TYPES;USE_STD_SQRT_FOR_COMPLEX_TYPES;SYCL_EXT_ONEAPI_COMPLEX")
set(_compiler_definitions "USE_SYCL_FOR_COMPLEX_TYPES")
else()
set(_compiler_definitions "SYCL_EXT_ONEAPI_COMPLEX")
set(_compiler_definitions "USE_SYCL_FOR_COMPLEX_TYPES")
endif()

foreach(_src_fn ${_elementwise_sources})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@
#include <cstddef>
#include <cstdint>
#include <limits>
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
#include <sycl/sycl.hpp>
#include <type_traits>

#include "kernels/elementwise_functions/common.hpp"
#include "sycl_complex.hpp"

#include "utils/offset_utils.hpp"
#include "utils/type_dispatch.hpp"
Expand All @@ -50,7 +50,6 @@ namespace abs

namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;
namespace exprm_ns = sycl::ext::oneapi::experimental;

using dpctl::tensor::type_utils::is_complex;

Expand Down Expand Up @@ -121,7 +120,7 @@ template <typename argT, typename resT> struct AbsFunctor
return q_nan;
}
else {
#ifdef USE_STD_ABS_FOR_COMPLEX_TYPES
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
return exprm_ns::abs(exprm_ns::complex<realT>(z));
#else
return std::hypot(std::real(z), std::imag(z));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
#include <sycl/sycl.hpp>
#include <type_traits>

#include "kernels/elementwise_functions/common.hpp"
#include "sycl_complex.hpp"

#include "utils/offset_utils.hpp"
#include "utils/type_dispatch.hpp"
Expand All @@ -48,7 +48,6 @@ namespace acos

namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;
namespace exprm_ns = sycl::ext::oneapi::experimental;

using dpctl::tensor::type_utils::is_complex;

Expand Down Expand Up @@ -105,6 +104,7 @@ template <typename argT, typename resT> struct AcosFunctor
constexpr realT r_eps =
realT(1) / std::numeric_limits<realT>::epsilon();
if (std::abs(x) > r_eps || std::abs(y) > r_eps) {
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
using sycl_complexT = exprm_ns::complex<realT>;
sycl_complexT log_in =
exprm_ns::log(exprm_ns::complex<realT>(in));
Expand All @@ -115,11 +115,24 @@ template <typename argT, typename resT> struct AcosFunctor

realT ry = wx + std::log(realT(2));
return resT{rx, (std::signbit(y)) ? ry : -ry};
#else
resT log_in = std::log(in);
const realT wx = std::real(log_in);
const realT wy = std::imag(log_in);
const realT rx = std::abs(wy);

realT ry = wx + std::log(realT(2));
return resT{rx, (std::signbit(y)) ? ry : -ry};
#endif
}

/* ordinary cases */
#if USE_SYCL_FOR_COMPLEX_TYPES
return exprm_ns::acos(
exprm_ns::complex<realT>(in)); // std::acos(in);
#else
return std::acos(in);
#endif
}
else {
static_assert(std::is_floating_point_v<argT> ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
#include <sycl/sycl.hpp>
#include <type_traits>

#include "kernels/elementwise_functions/common.hpp"
#include "sycl_complex.hpp"

#include "utils/offset_utils.hpp"
#include "utils/type_dispatch.hpp"
Expand All @@ -48,7 +48,6 @@ namespace acosh

namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;
namespace exprm_ns = sycl::ext::oneapi::experimental;

using dpctl::tensor::type_utils::is_complex;

Expand Down Expand Up @@ -112,18 +111,28 @@ template <typename argT, typename resT> struct AcoshFunctor
* For large x or y including acos(+-Inf + I*+-Inf)
*/
if (std::abs(x) > r_eps || std::abs(y) > r_eps) {
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
using sycl_complexT = typename exprm_ns::complex<realT>;
const sycl_complexT log_in = exprm_ns::log(sycl_complexT(in));
const realT wx = log_in.real();
const realT wy = log_in.imag();
#else
const resT log_in = std::log(in);
const realT wx = std::real(log_in);
const realT wy = std::imag(log_in);
#endif
const realT rx = std::abs(wy);
realT ry = wx + std::log(realT(2));
acos_in = resT{rx, (std::signbit(y)) ? ry : -ry};
}
else {
/* ordinary cases */
#if USE_SYCL_FOR_COMPLEX_TYPES
acos_in = exprm_ns::acos(
exprm_ns::complex<realT>(in)); // std::acos(in);
#else
acos_in = std::acos(in);
#endif
}

/* Now we calculate acosh(z) */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
#pragma once
#include <cstddef>
#include <cstdint>
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
#include <sycl/sycl.hpp>
#include <type_traits>

#include "sycl_complex.hpp"
#include "utils/offset_utils.hpp"
#include "utils/type_dispatch.hpp"
#include "utils/type_utils.hpp"
Expand All @@ -50,7 +50,6 @@ namespace add
namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;
namespace tu_ns = dpctl::tensor::type_utils;
namespace exprm_ns = sycl::ext::oneapi::experimental;

template <typename argT1, typename argT2, typename resT> struct AddFunctor
{
Expand All @@ -65,24 +64,36 @@ template <typename argT1, typename argT2, typename resT> struct AddFunctor
if constexpr (tu_ns::is_complex<argT1>::value &&
tu_ns::is_complex<argT2>::value)
{
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
using rT1 = typename argT1::value_type;
using rT2 = typename argT2::value_type;

return exprm_ns::complex<rT1>(in1) + exprm_ns::complex<rT2>(in2);
#else
return in1 + in2;
#endif
}
else if constexpr (tu_ns::is_complex<argT1>::value &&
!tu_ns::is_complex<argT2>::value)
{
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
using rT1 = typename argT1::value_type;

return exprm_ns::complex<rT1>(in1) + in2;
#else
return in1 + in2;
#endif
}
else if constexpr (!tu_ns::is_complex<argT1>::value &&
tu_ns::is_complex<argT2>::value)
{
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
using rT2 = typename argT2::value_type;

return in1 + exprm_ns::complex<rT2>(in2);
#else
return in1 + in2;
#endif
}
else {
return in1 + in2;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
#include <sycl/sycl.hpp>
#include <type_traits>

#include "kernels/elementwise_functions/common.hpp"
#include "sycl_complex.hpp"

#include "utils/offset_utils.hpp"
#include "utils/type_dispatch.hpp"
Expand All @@ -48,7 +48,6 @@ namespace asin

namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;
namespace exprm_ns = sycl::ext::oneapi::experimental;

using dpctl::tensor::type_utils::is_complex;

Expand Down Expand Up @@ -119,26 +118,45 @@ template <typename argT, typename resT> struct AsinFunctor
constexpr realT r_eps =
realT(1) / std::numeric_limits<realT>::epsilon();
if (std::abs(x) > r_eps || std::abs(y) > r_eps) {
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
using sycl_complexT = exprm_ns::complex<realT>;
const sycl_complexT z{x, y};
realT wx, wy;
if (!std::signbit(x)) {
auto log_z = exprm_ns::log(z);
const auto log_z = exprm_ns::log(z);
wx = log_z.real() + std::log(realT(2));
wy = log_z.imag();
}
else {
auto log_mz = exprm_ns::log(-z);
const auto log_mz = exprm_ns::log(-z);
wx = log_mz.real() + std::log(realT(2));
wy = log_mz.imag();
}
#else
const resT z{x, y};
realT wx, wy;
if (!std::signbit(x)) {
const auto log_z = std::log(z);
wx = std::real(log_z) + std::log(realT(2));
wy = std::imag(log_z);
}
else {
const auto log_mz = std::log(-z);
wx = std::real(log_mz) + std::log(realT(2));
wy = std::imag(log_mz);
}
#endif
const realT asinh_re = std::copysign(wx, x);
const realT asinh_im = std::copysign(wy, y);
return resT{asinh_im, asinh_re};
}
/* ordinary cases */
#if USE_SYCL_FOR_COMPLEX_TYPES
return exprm_ns::asin(
exprm_ns::complex<realT>(in)); // std::asin(in);
#else
return std::asin(in);
#endif
}
else {
static_assert(std::is_floating_point_v<argT> ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
#include <sycl/sycl.hpp>
#include <type_traits>

#include "kernels/elementwise_functions/common.hpp"
#include "sycl_complex.hpp"

#include "utils/offset_utils.hpp"
#include "utils/type_dispatch.hpp"
Expand All @@ -48,7 +48,6 @@ namespace asinh

namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;
namespace exprm_ns = sycl::ext::oneapi::experimental;

using dpctl::tensor::type_utils::is_complex;

Expand Down Expand Up @@ -108,20 +107,30 @@ template <typename argT, typename resT> struct AsinhFunctor
realT(1) / std::numeric_limits<realT>::epsilon();

if (std::abs(x) > r_eps || std::abs(y) > r_eps) {
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
using sycl_complexT = exprm_ns::complex<realT>;
sycl_complexT log_in = (std::signbit(x))
? exprm_ns::log(sycl_complexT(-in))
: exprm_ns::log(sycl_complexT(in));
realT wx = log_in.real() + std::log(realT(2));
realT wy = log_in.imag();
#else
auto log_in = std::log(std::signbit(x) ? -in : in);
realT wx = std::real(log_in) + std::log(realT(2));
realT wy = std::imag(log_in);
#endif
const realT res_re = std::copysign(wx, x);
const realT res_im = std::copysign(wy, y);
return resT{res_re, res_im};
}

/* ordinary cases */
#if USE_SYCL_FOR_COMPLEX_TYPES
return exprm_ns::asinh(
exprm_ns::complex<realT>(in)); // std::asinh(in);
#else
return std::asinh(in);
#endif
}
else {
static_assert(std::is_floating_point_v<argT> ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@
#include <complex>
#include <cstddef>
#include <cstdint>
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
#include <sycl/sycl.hpp>
#include <type_traits>

#include "kernels/elementwise_functions/common.hpp"
#include "sycl_complex.hpp"

#include "utils/offset_utils.hpp"
#include "utils/type_dispatch.hpp"
Expand All @@ -49,7 +49,6 @@ namespace atan

namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;
namespace exprm_ns = sycl::ext::oneapi::experimental;

using dpctl::tensor::type_utils::is_complex;

Expand Down Expand Up @@ -128,8 +127,12 @@ template <typename argT, typename resT> struct AtanFunctor
return resT{atanh_im, atanh_re};
}
/* ordinary cases */
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
return exprm_ns::atan(
exprm_ns::complex<realT>(in)); // std::atan(in);
#else
return std::atan(in);
#endif
}
else {
static_assert(std::is_floating_point_v<argT> ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@
#include <complex>
#include <cstddef>
#include <cstdint>
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
#include <sycl/sycl.hpp>
#include <type_traits>

#include "kernels/elementwise_functions/common.hpp"
#include "sycl_complex.hpp"

#include "utils/offset_utils.hpp"
#include "utils/type_dispatch.hpp"
Expand All @@ -49,7 +49,6 @@ namespace atanh

namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;
namespace exprm_ns = sycl::ext::oneapi::experimental;

using dpctl::tensor::type_utils::is_complex;

Expand Down Expand Up @@ -121,8 +120,12 @@ template <typename argT, typename resT> struct AtanhFunctor
return resT{res_re, res_im};
}
/* ordinary cases */
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
return exprm_ns::atanh(
exprm_ns::complex<realT>(in)); // std::atanh(in);
#else
return std::atanh(in);
#endif
}
else {
static_assert(std::is_floating_point_v<argT> ||
Expand Down
Loading

0 comments on commit 0f76890

Please sign in to comment.