Skip to content

Commit

Permalink
Fix 'SYCL feature test compile failed' problem
Browse files Browse the repository at this point in the history
Add _get_xfail_test() to _helper.py

_is_type() --> _match_type()

Fix docstring

Set CMAKE_C##_COMPILER in setup.py

Revert setup.py
  • Loading branch information
chudur-budur committed Dec 20, 2023
1 parent d386e3c commit 32e283f
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 31 deletions.
3 changes: 1 addition & 2 deletions conda-recipe/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ export CXX=icpx

# new llvm-spirv location
# starting from dpcpp_impl_linux-64=2022.0.0=intel_3610
PATH=$CONDA_PREFIX/bin-llvm:$PATH
export PATH
export PATH=$CONDA_PREFIX/bin-llvm:$PATH

SKBUILD_ARGS=(-G Ninja -- -DCMAKE_VERBOSE_MAKEFILE:BOOL=ON)

Expand Down
4 changes: 4 additions & 0 deletions numba_dpex/core/runtime/_dpexrt_python.c
Original file line number Diff line number Diff line change
Expand Up @@ -1499,6 +1499,10 @@ static PyObject *build_c_helpers_dict(void)
&DPEXRT_nrt_acquire_meminfo_and_schedule_release);
_declpointer("DPEXRT_build_or_get_kernel", &DPEXRT_build_or_get_kernel);
_declpointer("DPEXRT_kernel_cache_size", &DPEXRT_kernel_cache_size);
_declpointer("NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_interval",
&NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_interval);
_declpointer("NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_affine_interval",
&NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_affine_interval);

#undef _declpointer
return dct;
Expand Down
22 changes: 11 additions & 11 deletions numba_dpex/dpnp_iface/array_interval_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)


def _is_type(args, numba_type, basic_type=None):
def _match_type(args, numba_type, basic_type=None):
"""Checks if one of the paramteres in `args` is any of the
`numba.core.types.scalars.*` type.
Expand Down Expand Up @@ -75,11 +75,11 @@ def _parse_dtype_from_range(range):
Returns:
numba.core.types.scalars.*: Infered `dtype` for the output tensor.
"""
if _is_type(range, Complex):
if _match_type(range, Complex):
return numba.from_dtype(dpnp.complex_)
elif _is_type(range, Float, float):
elif _match_type(range, Float, float):
return numba.from_dtype(dpnp.float)
elif _is_type(range, Integer, int):
elif _match_type(range, Integer, int):
return numba.from_dtype(dpnp.int)
else:
msg = (
Expand Down Expand Up @@ -365,10 +365,10 @@ def impl_dpnp_arange(
Args:
ty_context (numba.core.typing.context.Context): The typing context
for the codegen.
ty_start (numba.core.types.scalars.Integer): Numba type for the start
of the interval.
ty_stop (numba.core.types.scalars.Integer): Numba type for the end
of the interval.
ty_start (numba.core.types.scalars.*): Numba type for the start of the
interval.
ty_stop (numba.core.types.scalars.*): Numba type for the end of the
interval.
ty_step (numba.core.types.scalars.Integer): Numba type for the step
of the interval.
ty_dtype (numba.core.types.functions.NumberClass): Numba type for
Expand Down Expand Up @@ -606,10 +606,10 @@ def ol_dpnp_arange(
)

if is_nonelike(stop):
start = 0 if _is_type([start], Integer, int) else 0.0
stop = 1 if _is_type([start], Integer, int) else 1.0
start = 0 if _match_type([start], Integer, int) else 0.0
stop = 1 if _match_type([start], Integer, int) else 1.0
if is_nonelike(step):
step = 1 if _is_type([start], Integer, int) else 1.0
step = 1 if _match_type([start], Integer, int) else 1.0

_dtype = (
_parse_dtype(dtype)
Expand Down
16 changes: 16 additions & 0 deletions numba_dpex/tests/_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,19 @@ def skip_if_dtype_not_supported(dt, q_or_dev):
pytest.skip(
f"{dev.name} does not support half precision floating point type"
)


def get_xfail_test(param, reason):
"""Generate an X-fail test `pytest` parameter.
Args:
param (list): A `list` of valid parameters.
reason (str): A `str` describing the reason for failure.
Returns:
pytest.param: A `pytest.param` parameter.
"""
return pytest.param(
param,
marks=pytest.mark.xfail(reason=reason),
)
19 changes: 1 addition & 18 deletions numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_arange.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,7 @@
import pytest

from numba_dpex import dpjit
from numba_dpex.tests._helper import get_all_dtypes


def get_xfail_test(param, reason):
"""Generate an X-fail test `pytest` parameter.
Args:
param (list): A `list` of valid parameters.
reason (str): A `str` describing the reason for failure.
Returns:
pytest.param: A `pytest.param` parameter.
"""
return pytest.param(
param,
marks=pytest.mark.xfail(reason=reason),
)

from numba_dpex.tests._helper import get_all_dtypes, get_xfail_test

# Get all dtypes, except bool, float16 and complex
dtypes = get_all_dtypes(
Expand Down

0 comments on commit 32e283f

Please sign in to comment.