diff --git a/numba_dpex/core/kernel_interface/arg_pack_unpacker.py b/numba_dpex/core/kernel_interface/arg_pack_unpacker.py index cc77f92332..0ca6395bb6 100644 --- a/numba_dpex/core/kernel_interface/arg_pack_unpacker.py +++ b/numba_dpex/core/kernel_interface/arg_pack_unpacker.py @@ -193,9 +193,9 @@ def _unpack_argument(self, ty, val, access_specifier): elif ty == types.boolean: return ctypes.c_uint8(int(val)) elif ty == types.complex64: - raise UnsupportedKernelArgumentError(ty, val, self._pyfunc_name) + return [ctypes.c_float(val.real), ctypes.c_float(val.imag)] elif ty == types.complex128: - raise UnsupportedKernelArgumentError(ty, val, self._pyfunc_name) + return [ctypes.c_double(val.real), ctypes.c_double(val.imag)] else: raise UnsupportedKernelArgumentError(ty, val, self._pyfunc_name) diff --git a/numba_dpex/tests/kernel_tests/test_complex_array.py b/numba_dpex/tests/kernel_tests/test_complex_array.py new file mode 100644 index 0000000000..3fba0eca60 --- /dev/null +++ b/numba_dpex/tests/kernel_tests/test_complex_array.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +import dpnp +import numpy +import pytest + +import numba_dpex as dpex + +N = 1024 + + +@dpex.kernel +def kernel_scalar(a, b, c): + i = dpex.get_global_id(0) + b[i] = a[i] * c + + +@dpex.kernel +def kernel_array(a, b, c): + i = dpex.get_global_id(0) + b[i] = a[i] * c[i] + + +list_of_dtypes = [ + dpnp.complex64, + dpnp.complex128, +] + +list_of_usm_types = ["shared", "device", "host"] + + +@pytest.fixture(params=list_of_dtypes) +def input_arrays(request): + a = dpnp.ones(N, dtype=request.param) + c = dpnp.zeros(N, dtype=request.param) + b = dpnp.empty_like(a) + return a, b, c + + +def test_numeric_kernel_arg_complex_scalar(input_arrays): + """Tests passing complex type scalar and dpnp arrays to a kernel function. + + Args: + input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel. + """ + s = 2 + 1j + a, b, _ = input_arrays + + kernel_scalar[dpex.Range(N)](a, b, s) + + nb = dpnp.asnumpy(b) + nexpected = numpy.full_like(nb, fill_value=2 + 1j) + + assert numpy.allclose(nb, nexpected) + + +def test_numeric_kernel_arg_complex_array(input_arrays): + """Tests passing complex type dpnp arrays to a kernel function. + + Args: + input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel. + """ + + a, b, c = input_arrays + + kernel_array[dpex.Range(N)](a, b, c) + + nb = dpnp.asnumpy(b) + nexpected = numpy.full_like(nb, fill_value=0 + 0j) + + assert numpy.allclose(nb, nexpected) diff --git a/numba_dpex/tests/kernel_tests/test_scalar_arg_types.py b/numba_dpex/tests/kernel_tests/test_scalar_arg_types.py index 6fbc80a062..5d6ffda255 100644 --- a/numba_dpex/tests/kernel_tests/test_scalar_arg_types.py +++ b/numba_dpex/tests/kernel_tests/test_scalar_arg_types.py @@ -31,6 +31,8 @@ def kernel_with_bool_arg(a, b, test): dpnp.int64, dpnp.float32, dpnp.float64, + dpnp.complex64, + dpnp.complex128, ] list_of_usm_types = ["shared", "device", "host"] @@ -43,8 +45,8 @@ def input_arrays(request): return a, b -def test_numeric_kernel_arg_types(input_arrays): - """Tests passing float and int type scalar arguments to a kernel function. +def test_numeric_kernel_arg_types1(input_arrays): + """Tests passing float, int and complex type dpnp arrays to a kernel function. Args: input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel.