Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port the parfor range kernel template to new API. #1416

Merged
merged 1 commit into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions numba_dpex/core/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ def typing_context(self):
"""
return self._toplevel_typing_context

@property
def target_name(self):
return self._target_name


class DpexTarget(TargetDescriptor):
"""
Expand Down
33 changes: 21 additions & 12 deletions numba_dpex/core/parfors/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
rename_labels,
replace_var_names,
)
from numba.core.target_extension import target_override
from numba.core.typing import signature
from numba.parfors import parfor

from numba_dpex.core import config
from numba_dpex.core.types.kernel_api.index_space_ids import ItemType
from numba_dpex.kernel_api_impl.spirv import spirv_generator

from ..descriptor import dpex_kernel_target
Expand Down Expand Up @@ -66,18 +68,18 @@ def _print_body(body_dict):
def _compile_kernel_parfor(
sycl_queue, kernel_name, func_ir, argtypes, debug=False
):

cres = compile_numba_ir_with_dpex(
pyfunc=func_ir,
pyfunc_name=kernel_name,
args=argtypes,
return_type=None,
debug=debug,
is_kernel=True,
typing_context=dpex_kernel_target.typing_context,
target_context=dpex_kernel_target.target_context,
extra_compile_flags=None,
)
with target_override(dpex_kernel_target.target_context.target_name):
cres = compile_numba_ir_with_dpex(
pyfunc=func_ir,
pyfunc_name=kernel_name,
args=argtypes,
return_type=None,
debug=debug,
is_kernel=True,
typing_context=dpex_kernel_target.typing_context,
target_context=dpex_kernel_target.target_context,
extra_compile_flags=None,
)
cres.library.inline_threshold = config.INLINE_THRESHOLD
cres.library._optimize_final_module()
func = cres.library.get_function(cres.fndesc.llvm_func_name)
Expand Down Expand Up @@ -420,6 +422,13 @@ def create_kernel_for_parfor(
print("kernel_ir after remove dead")
kernel_ir.dump()

# The first argument to a range kernel is a kernel_api.Item object. The
# ``Item`` object is used by the kernel_api.spirv backend to generate the
# correct SPIR-V indexing instructions. Since, the argument is not something
# available originally in the kernel_param_types, we add it at this point to
# make sure the kernel signature matches the actual generated code.
ty_item = ItemType(parfor_dim)
kernel_param_types = (ty_item, *kernel_param_types)
kernel_sig = signature(types.none, *kernel_param_types)

if config.DEBUG_ARRAY_OPT:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,15 @@ def _generate_kernel_stub_as_string(self):

# Create the dpex kernel function.
kernel_txt += "def " + self._kernel_name
kernel_txt += "(" + (", ".join(self._kernel_params)) + "):\n"
kernel_txt += "(item, " + (", ".join(self._kernel_params)) + "):\n"
global_id_dim = 0
for_loop_dim = self._kernel_rank
global_id_dim = self._kernel_rank

for dim in range(global_id_dim):
dimstr = str(dim)
kernel_txt += (
f" {self._ivar_names[dim]} = dpex.get_global_id({dimstr})\n"
f" {self._ivar_names[dim]} = item.get_id({dimstr})\n"
)

for dim in range(global_id_dim, for_loop_dim):
Expand Down
Loading