Skip to content

Commit

Permalink
Set inline threshold default value to 2
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Mar 15, 2024
1 parent ee1c731 commit f846dd0
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion numba_dpex/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,6 @@ def __getattr__(name):

DPEX_OPT = _readenv("NUMBA_DPEX_OPT", int, 2)

INLINE_THRESHOLD = _readenv("NUMBA_DPEX_INLINE_THRESHOLD", int, None)
INLINE_THRESHOLD = _readenv("NUMBA_DPEX_INLINE_THRESHOLD", int, 2)

USE_MLIR = _readenv("NUMBA_DPEX_USE_MLIR", int, 0)
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_codegen_with_max_inline_threshold():
and pipeline to compile both host callable "kernels" and device-only
"device_func" functions.
Unless the inline_threshold is set to 3, the `spir_func` function is not
Unless the inline_threshold is set to >0, the `spir_func` function is not
inlined into the wrapper function. The test checks if the `spir_func`
function is fully inlined into the wrapper. The test is rather rudimentary
and only checks the count of function in the generated module.
Expand All @@ -39,7 +39,7 @@ def test_codegen_with_max_inline_threshold():
i64arr_ty = DpnpNdArray(ndim=1, dtype=int64, layout="C", queue=queue_ty)
kernel_sig = types.void(ItemType(1), i64arr_ty, i64arr_ty, i64arr_ty)

disp = dpex_exp.kernel(inline_threshold=3)(kernel_func)
disp = dpex_exp.kernel(inline_threshold=1)(kernel_func)
disp.compile(kernel_sig)
kcres = disp.overloads[kernel_sig.args]
llvm_ir_mod = kcres.library._final_module
Expand All @@ -60,7 +60,7 @@ def test_codegen_without_max_inline_threshold():
i64arr_ty = DpnpNdArray(ndim=1, dtype=int64, layout="C", queue=queue_ty)
kernel_sig = types.void(ItemType(1), i64arr_ty, i64arr_ty, i64arr_ty)

disp = dpex_exp.kernel(kernel_func)
disp = dpex_exp.kernel(inline_threshold=0)(kernel_func)
disp.compile(kernel_sig)
kcres = disp.overloads[kernel_sig.args]
llvm_ir_mod = kcres.library._final_module
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def pass_flags_to_func(a):
i64arr_ty = DpnpNdArray(ndim=1, dtype=int64, layout="C", queue=queue_ty)
kernel_sig = types.void(i64arr_ty)

disp = exp_dpex.kernel(pass_flags_to_func)
disp = exp_dpex.kernel(inline_threshold=0)(pass_flags_to_func)
disp.compile(kernel_sig)
kcres = disp.overloads[kernel_sig.args]
llvm_ir_mod = kcres.library._final_module.__str__()
Expand Down

0 comments on commit f846dd0

Please sign in to comment.