Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add experimental kernel decorator to numba jit_registry #1205

Merged
merged 1 commit into from
Nov 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions numba_dpex/experimental/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import inspect

from numba.core import sigutils
from numba.core.target_extension import jit_registry, target_registry

from .kernel_dispatcher import KernelDispatcher

Expand Down Expand Up @@ -78,3 +79,6 @@ def _specialized_kernel_dispatcher(pyfunc):
"the return type as void explicitly."
)
return _kernel_dispatcher(func)


jit_registry[target_registry["dpex_kernel"]] = kernel
11 changes: 6 additions & 5 deletions numba_dpex/tests/experimental/test_exec_queue_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dpctl
import dpnp
import pytest
from numba.core import config

import numba_dpex.experimental as exp_dpex
from numba_dpex import Range
Expand Down Expand Up @@ -34,14 +35,16 @@ def test_successful_execution_queue_inference():
c = dpnp.zeros_like(a, sycl_queue=q)
r = Range(100)

# FIXME: This test fails unexpectedly if the NUMBA_CAPTURED_ERRORS is set
# to "new_style".
# Refer: https://github.com/IntelPython/numba-dpex/issues/1195
current_captured_error_style = config.CAPTURED_ERRORS
config.CAPTURED_ERRORS = "new_style"

try:
exp_dpex.call_kernel(add, r, a, b, c)
except:
pytest.fail("Unexpected error when calling kernel")

config.CAPTURED_ERRORS = current_captured_error_style

assert c[0] == b[0] + a[0]


Expand All @@ -59,8 +62,6 @@ def test_execution_queue_inference_error():
c = dpnp.zeros_like(a, sycl_queue=q1)
r = Range(100)

from numba.core import config

current_captured_error_style = config.CAPTURED_ERRORS
config.CAPTURED_ERRORS = "new_style"

Expand Down
Loading