Skip to content

Commit

Permalink
Enabled complex number in dpjit.
Browse files Browse the repository at this point in the history
  • Loading branch information
mingjie-intel committed May 10, 2023
1 parent 0e7f98f commit 28d7234
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 4 deletions.
33 changes: 29 additions & 4 deletions numba_dpex/core/passes/parfor_lowering_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def _submit_parfor_kernel(
if isinstance(arg_type, DpnpNdArray):
# FIXME: Remove magic constants
num_flattened_args += 5 + (2 * arg_type.ndim)
elif arg_type == types.complex64 or arg_type == types.complex128:
num_flattened_args += 2
else:
num_flattened_args += 1

Expand Down Expand Up @@ -97,10 +99,33 @@ def _submit_parfor_kernel(
# FIXME: Get rid of magic constants
kernel_arg_num += 5 + (2 * argtype.ndim)
else:
ir_builder.build_arg(
llvm_val, argtype, args_list, args_ty_list, kernel_arg_num
)
kernel_arg_num += 1
if argtype == types.complex64:
ir_builder.build_complex_arg(
llvm_val,
types.float32,
args_list,
args_ty_list,
kernel_arg_num,
)
kernel_arg_num += 2
elif argtype == types.complex128:
ir_builder.build_complex_arg(
llvm_val,
types.float64,
args_list,
args_ty_list,
kernel_arg_num,
)
kernel_arg_num += 2
else:
ir_builder.build_arg(
llvm_val,
argtype,
args_list,
args_ty_list,
kernel_arg_num,
)
kernel_arg_num += 1

# Create a global range over which to submit the kernel based on the
# loop_ranges of the parfor
Expand Down
23 changes: 23 additions & 0 deletions numba_dpex/core/utils/kernel_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,29 @@ def build_arg(self, val, ty, arg_list, args_ty_list, arg_num):
numba_type_to_dpctl_typenum(self.context, ty), kernel_arg_ty_dst
)

def build_complex_arg(self, val, ty, arg_list, args_ty_list, arg_num):
"""Creates a list of LLVM Values for an unpacked complex kernel
argument.
"""
self._build_array_attr_arg(
array_val=val,
array_attr_pos=0,
array_attr_ty=ty,
arg_list=arg_list,
args_ty_list=args_ty_list,
arg_num=arg_num,
)
arg_num += 1
self._build_array_attr_arg(
array_val=val,
array_attr_pos=1,
array_attr_ty=ty,
arg_list=arg_list,
args_ty_list=args_ty_list,
arg_num=arg_num,
)
arg_num += 1

def build_array_arg(
self, array_val, array_rank, arg_list, args_ty_list, arg_num
):
Expand Down
91 changes: 91 additions & 0 deletions numba_dpex/tests/dpjit_tests/test_dpjit_complex_arg_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

import dpnp
import numba as nb
import numpy
import pytest

import numba_dpex as dpex

N = 1024


@dpex.dpjit
def prange_arg(a, b, c):
for i in nb.prange(a.shape[0]):
b[i] = a[i] * c


@dpex.dpjit
def prange_array(a, b, c):
for i in nb.prange(a.shape[0]):
b[i] = a[i] * c[i]


list_of_dtypes = [
dpnp.complex64,
dpnp.complex128,
]

list_of_usm_types = ["shared", "device", "host"]


@pytest.fixture(params=list_of_dtypes)
def input_arrays(request):
a = dpnp.ones(N, dtype=request.param)
c = dpnp.zeros(N, dtype=request.param)
b = dpnp.empty_like(a)
return a, b, c


def test_dpjit_scalar_arg_types(input_arrays):
"""Tests passing float and complex type dpnp arrays to a dpjit prange function.
Args:
input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel.
"""
s = 2
a, b, _ = input_arrays

prange_arg(a, b, s)

nb = dpnp.asnumpy(b)
nexpected = numpy.full_like(nb, fill_value=2)

assert numpy.allclose(nb, nexpected)


def test_dpjit_arg_complex_scalar(input_arrays):
"""Tests passing complex type scalar and dpnp arrays to a dpjit prange function.
Args:
input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel.
"""
s = 2 + 1j
a, b, _ = input_arrays

prange_arg(a, b, s)

nb = dpnp.asnumpy(b)
nexpected = numpy.full_like(nb, fill_value=2 + 1j)

assert numpy.allclose(nb, nexpected)


def test_dpjit_arg_complex_array(input_arrays):
"""Tests passing complex type dpnp arrays to a dpjit prange function.
Args:
input_arrays (dpnp.ndarray): Array arguments to be passed to a kernel.
"""

a, b, c = input_arrays

prange_array(a, b, c)

nb = dpnp.asnumpy(b)
nexpected = numpy.full_like(nb, fill_value=0 + 0j)

assert numpy.allclose(nb, nexpected)

0 comments on commit 28d7234

Please sign in to comment.