Skip to content

Commit

Permalink
WIP unit tests...
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb committed Dec 8, 2023
1 parent 1ade40d commit 7d894a5
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 4 deletions.
3 changes: 2 additions & 1 deletion numba_dpex/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from numba.core.imputils import Registry

from .decorators import kernel
from .decorators import device_func, kernel
from .kernel_dispatcher import KernelDispatcher
from .launcher import call_kernel, call_kernel_async
from .literal_intenum_type import IntEnumLiteral
Expand All @@ -28,6 +28,7 @@ def dpex_dispatcher_const(context):


__all__ = [
"device_func",
"kernel",
"call_kernel",
"call_kernel_async",
Expand Down
2 changes: 2 additions & 0 deletions numba_dpex/experimental/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ class DpexExpKernelTargetContext(DpexKernelTargetContext):
they are stable enough to be migrated to DpexKernelTargetContext.
"""

allow_dynamic_globals = True

def __init__(self, typingctx, target=DPEX_KERNEL_EXP_TARGET_NAME):
super().__init__(typingctx, target)
self.data_model_manager = exp_dmm
Expand Down
29 changes: 26 additions & 3 deletions numba_dpex/tests/experimental/IntEnumLiteral/test_compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@
#
# SPDX-License-Identifier: Apache-2.0

import dpctl
import dpnp
from numba.core import types

import numba_dpex.experimental as exp_dpex
from numba_dpex import Range
from numba_dpex import DpctlSyclQueue, DpnpNdArray, Range, int64
from numba_dpex.experimental.flag_enum import FlagEnum


class MockFlags(FlagEnum):
FLAG1 = 100
FLAG2 = 200
FLAG1 = 1
FLAG2 = 2


@exp_dpex.kernel(
Expand All @@ -34,3 +36,24 @@ def test_compilation_of_flag_enum():
assert a[1] == MockFlags.FLAG2
for idx in range(2, 9):
assert a[idx] == 1


def test_compilation_as_literal_constant():
@exp_dpex.device_func
def bitwise_or_flags(flag1, flag2):
return flag1 | flag2

def pass_flags_to_func(a):
f1 = MockFlags.FLAG1
f2 = MockFlags.FLAG2
a[0] = bitwise_or_flags(f1, f2)

queue_ty = DpctlSyclQueue(dpctl.SyclQueue())
i64arr_ty = DpnpNdArray(ndim=1, dtype=int64, layout="C", queue=queue_ty)
kernel_sig = types.void(i64arr_ty)

disp = exp_dpex.kernel(pass_flags_to_func)
disp.compile(kernel_sig)
kcres = disp.overloads[kernel_sig.args]
llvm_ir_mod = kcres.library._final_module
print(llvm_ir_mod)

0 comments on commit 7d894a5

Please sign in to comment.