diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1753b7a3fe..2411834df1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -52,7 +52,7 @@ repos: - id: pylint name: pylint entry: pylint - files: ^numba_dpex/experimental + files: ^numba_dpex/experimental|^numba_dpex/core/utils/kernel_launcher.py language: system types: [python] require_serial: true diff --git a/numba_dpex/core/parfors/kernel_builder.py b/numba_dpex/core/parfors/kernel_builder.py index c27ae43b3d..931357bfeb 100644 --- a/numba_dpex/core/parfors/kernel_builder.py +++ b/numba_dpex/core/parfors/kernel_builder.py @@ -40,7 +40,7 @@ def __init__( signature, kernel_args, kernel_arg_types, - queue, + queue: dpctl.SyclQueue, ): self.name = name self.kernel = kernel @@ -244,7 +244,7 @@ def create_kernel_for_parfor( has_aliases, races, parfor_outputs, -): +) -> ParforKernel: """ Creates a numba_dpex.kernel function for a parfor node. @@ -422,7 +422,7 @@ def create_kernel_for_parfor( # arrays are on same device. We can take the queue from the first input # array and use that to compile the kernel. - exec_queue = None + exec_queue: dpctl.SyclQueue = None for arg in parfor_args: obj = typemap[arg] diff --git a/numba_dpex/core/parfors/parfor_lowerer.py b/numba_dpex/core/parfors/parfor_lowerer.py index eaead837a7..10e634cb63 100644 --- a/numba_dpex/core/parfors/parfor_lowerer.py +++ b/numba_dpex/core/parfors/parfor_lowerer.py @@ -13,26 +13,24 @@ ) from numba_dpex import config -from numba_dpex.core.datamodel.models import dpex_data_model_manager as dpex_dmm from numba_dpex.core.parfors.reduction_helper import ( ReductionHelper, ReductionKernelVariables, ) from numba_dpex.core.utils.kernel_launcher import KernelLaunchIRBuilder +from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl +from numba_dpex.core.datamodel.models import ( + dpex_data_model_manager as kernel_dmm, +) from ..exceptions import UnsupportedParforError from ..types.dpnp_ndarray_type import DpnpNdArray -from .kernel_builder import create_kernel_for_parfor +from .kernel_builder import ParforKernel, create_kernel_for_parfor from .reduction_kernel_builder import ( create_reduction_main_kernel_for_parfor, create_reduction_remainder_kernel_for_parfor, ) -_KernelArgs = namedtuple( - "_KernelArgs", - ["num_flattened_args", "arg_vals", "arg_types"], -) - # A global list of kernels to keep the objects alive indefinitely. keep_alive_kernels = [] @@ -68,11 +66,8 @@ def _getvar(lowerer, x): var_val = lowerer.varmap[x] if var_val: - if not isinstance(var_val.type, llvmir.PointerType): - with lowerer.builder.goto_entry_block(): - var_val_ptr = lowerer.builder.alloca(var_val.type) - lowerer.builder.store(var_val, var_val_ptr) - return var_val_ptr + if isinstance(var_val.type, llvmir.PointerType): + return lowerer.builder.load(var_val) else: return var_val else: @@ -91,75 +86,15 @@ class ParforLowerImpl: for a parfor and submits it to a queue. """ - def _build_kernel_arglist(self, kernel_fn, lowerer, kernel_builder): - """Creates local variables for all the arguments and the argument types - that are passes to the kernel function. - - Args: - kernel_fn: Kernel function to be launched. - lowerer: The Numba lowerer used to generate the LLVM IR - - Raises: - AssertionError: If the LLVM IR Value for an argument defined in - Numba IR is not found. - """ - num_flattened_args = 0 - - # Compute number of args to be passed to the kernel. Note that the - # actual number of kernel arguments is greater than the count of - # kernel_fn.kernel_args as arrays get flattened. - for arg_type in kernel_fn.kernel_arg_types: - if isinstance(arg_type, DpnpNdArray): - datamodel = dpex_dmm.lookup(arg_type) - num_flattened_args += datamodel.flattened_field_count - elif arg_type == types.complex64 or arg_type == types.complex128: - num_flattened_args += 2 - else: - num_flattened_args += 1 - - # Create LLVM values for the kernel args list and kernel arg types list - args_list = kernel_builder.allocate_kernel_arg_array(num_flattened_args) - args_ty_list = kernel_builder.allocate_kernel_arg_ty_array( - num_flattened_args - ) - callargs_ptrs = [] - for arg in kernel_fn.kernel_args: - callargs_ptrs.append(_getvar(lowerer, arg)) - - kernel_builder.populate_kernel_args_and_args_ty_arrays( - kernel_argtys=kernel_fn.kernel_arg_types, - callargs_ptrs=callargs_ptrs, - args_list=args_list, - args_ty_list=args_ty_list, - ) - - return _KernelArgs( - num_flattened_args=num_flattened_args, - arg_vals=args_list, - arg_types=args_ty_list, - ) - - def _submit_parfor_kernel( + def _loop_ranges( self, lowerer, - kernel_fn, loop_ranges, ): - """ - Adds a call to submit a kernel function into the function body of the - current Numba JIT compiled function. - """ - # Ensure that the Python arguments are kept alive for the duration of - # the kernel execution - keep_alive_kernels.append(kernel_fn.kernel) - kernel_builder = KernelLaunchIRBuilder(lowerer.context, lowerer.builder) - - ptr_to_queue_ref = kernel_builder.get_queue(exec_queue=kernel_fn.queue) - args = self._build_kernel_arglist(kernel_fn, lowerer, kernel_builder) - # Create a global range over which to submit the kernel based on the # loop_ranges of the parfor global_range = [] + # SYCL ranges can have at max 3 dimension. If the parfor is of a higher # dimension then the indexing for the higher dimensions is done inside # the kernel. @@ -173,48 +108,19 @@ def _submit_parfor_kernel( "non-unit strides are not yet supported." ) global_range.append(stop) - + # For now the local_range is always an empty list as numba_dpex always + # submits kernels generated for parfor nodes as range kernels. + # The provision is kept here if in future there is newer functionality + # to submit these kernels as ndrange. local_range = [] - kernel_ref_addr = kernel_fn.kernel.addressof_ref() - kernel_ref = lowerer.builder.inttoptr( - lowerer.context.get_constant(types.uintp, kernel_ref_addr), - cgutils.voidptr_t, - ) - curr_queue_ref = lowerer.builder.load(ptr_to_queue_ref) - - # Submit a synchronous kernel - kernel_builder.submit_sycl_kernel( - sycl_kernel_ref=kernel_ref, - sycl_queue_ref=curr_queue_ref, - total_kernel_args=args.num_flattened_args, - arg_list=args.arg_vals, - arg_ty_list=args.arg_types, - global_range=global_range, - local_range=local_range, - ) - - # At this point we can free the DPCTLSyclQueueRef (curr_queue) - kernel_builder.free_queue(ptr_to_sycl_queue_ref=ptr_to_queue_ref) + return global_range, local_range - def _submit_reduction_main_parfor_kernel( + def _reduction_ranges( self, lowerer, - kernel_fn, reductionHelper=None, ): - """ - Adds a call to submit the main kernel of a parfor reduction into the - function body of the current Numba JIT compiled function. - """ - # Ensure that the Python arguments are kept alive for the duration of - # the kernel execution - keep_alive_kernels.append(kernel_fn.kernel) - kernel_builder = KernelLaunchIRBuilder(lowerer.context, lowerer.builder) - - ptr_to_queue_ref = kernel_builder.get_queue(exec_queue=kernel_fn.queue) - - args = self._build_kernel_arglist(kernel_fn, lowerer, kernel_builder) # Create a global range over which to submit the kernel based on the # loop_ranges of the parfor global_range = [] @@ -228,75 +134,63 @@ def _submit_reduction_main_parfor_kernel( _load_range(lowerer, reductionHelper.work_group_size) ) - kernel_ref_addr = kernel_fn.kernel.addressof_ref() - kernel_ref = lowerer.builder.inttoptr( - lowerer.context.get_constant(types.uintp, kernel_ref_addr), - cgutils.voidptr_t, - ) - curr_queue_ref = lowerer.builder.load(ptr_to_queue_ref) - - # Submit a synchronous kernel - kernel_builder.submit_sycl_kernel( - sycl_kernel_ref=kernel_ref, - sycl_queue_ref=curr_queue_ref, - total_kernel_args=args.num_flattened_args, - arg_list=args.arg_vals, - arg_ty_list=args.arg_types, - global_range=global_range, - local_range=local_range, - ) + return global_range, local_range - # At this point we can free the DPCTLSyclQueueRef (curr_queue) - kernel_builder.free_queue(ptr_to_sycl_queue_ref=ptr_to_queue_ref) + def _remainder_ranges(self, lowerer): + # Create a global range over which to submit the kernel based on the + # loop_ranges of the parfor + global_range = [] - def _submit_reduction_remainder_parfor_kernel( + stop = _load_range(lowerer, 1) + + global_range.append(stop) + + local_range = [] + + return global_range, local_range + + def _submit_parfor_kernel( self, lowerer, - kernel_fn, + kernel_fn: ParforKernel, + global_range, + local_range, ): """ - Adds a call to submit the remainder kernel of a parfor reduction into - the function body of the current Numba JIT compiled function. + Adds a call to submit a kernel function into the function body of the + current Numba JIT compiled function. """ # Ensure that the Python arguments are kept alive for the duration of # the kernel execution keep_alive_kernels.append(kernel_fn.kernel) + kl_builder = KernelLaunchIRBuilder( + lowerer.context, lowerer.builder, kernel_dmm + ) - kernel_builder = KernelLaunchIRBuilder(lowerer.context, lowerer.builder) - - ptr_to_queue_ref = kernel_builder.get_queue(exec_queue=kernel_fn.queue) - - args = self._build_kernel_arglist(kernel_fn, lowerer, kernel_builder) - # Create a global range over which to submit the kernel based on the - # loop_ranges of the parfor - global_range = [] - - stop = _load_range(lowerer, 1) + queue_ref = kl_builder.get_queue(exec_queue=kernel_fn.queue) - global_range.append(stop) - - local_range = [] + kernel_args = [] + for arg in kernel_fn.kernel_args: + kernel_args.append(_getvar(lowerer, arg)) kernel_ref_addr = kernel_fn.kernel.addressof_ref() kernel_ref = lowerer.builder.inttoptr( lowerer.context.get_constant(types.uintp, kernel_ref_addr), cgutils.voidptr_t, ) - curr_queue_ref = lowerer.builder.load(ptr_to_queue_ref) - - # Submit a synchronous kernel - kernel_builder.submit_sycl_kernel( - sycl_kernel_ref=kernel_ref, - sycl_queue_ref=curr_queue_ref, - total_kernel_args=args.num_flattened_args, - arg_list=args.arg_vals, - arg_ty_list=args.arg_types, - global_range=global_range, - local_range=local_range, + + kl_builder.set_kernel(kernel_ref) + kl_builder.set_queue(queue_ref) + kl_builder.set_range(global_range, local_range) + kl_builder.set_arguments( + kernel_fn.kernel_arg_types, kernel_args=kernel_args ) + kl_builder.set_dependant_event_list(dep_events=[]) + event_ref = kl_builder.submit() - # At this point we can free the DPCTLSyclQueueRef (curr_queue) - kernel_builder.free_queue(ptr_to_sycl_queue_ref=ptr_to_queue_ref) + sycl.dpctl_event_wait(lowerer.builder, event_ref) + sycl.dpctl_event_delete(lowerer.builder, event_ref) + sycl.dpctl_queue_delete(lowerer.builder, queue_ref) def _reduction_codegen( self, @@ -360,10 +254,15 @@ def _reduction_codegen( parfor_reddict, ) - self._submit_reduction_main_parfor_kernel( + global_range, local_range = self._reduction_ranges( + lowerer, reductionHelperList[0] + ) + + self._submit_parfor_kernel( lowerer, parfor_kernel, - reductionHelperList[0], + global_range, + local_range, ) parfor_kernel = create_reduction_remainder_kernel_for_parfor( @@ -376,9 +275,13 @@ def _reduction_codegen( reductionHelperList, ) - self._submit_reduction_remainder_parfor_kernel( + global_range, local_range = self._remainder_ranges(lowerer) + + self._submit_parfor_kernel( lowerer, parfor_kernel, + global_range, + local_range, ) reductionKernelVar.copy_final_sum_to_host(parfor_kernel) @@ -492,11 +395,14 @@ def _lower_parfor_as_kernel(self, lowerer, parfor): # FIXME: Make the exception more informative raise UnsupportedParforError + global_range, local_range = self._loop_ranges(lowerer, loop_ranges) + # Finally submit the kernel self._submit_parfor_kernel( lowerer, parfor_kernel, - loop_ranges, + global_range, + local_range, ) # TODO: free the kernel at this point diff --git a/numba_dpex/core/parfors/reduction_helper.py b/numba_dpex/core/parfors/reduction_helper.py index 93d9a147c0..f47f29009f 100644 --- a/numba_dpex/core/parfors/reduction_helper.py +++ b/numba_dpex/core/parfors/reduction_helper.py @@ -17,6 +17,9 @@ from numba.parfors.parfor_lowering_utils import ParforLoweringBuilder from numba_dpex import utils +from numba_dpex.core.datamodel.models import ( + dpex_data_model_manager as kernel_dmm, +) from numba_dpex.core.utils.kernel_launcher import KernelLaunchIRBuilder from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl @@ -395,11 +398,13 @@ def work_group_size(self): def copy_final_sum_to_host(self, parfor_kernel): lowerer = self.lowerer - ir_builder = KernelLaunchIRBuilder(lowerer.context, lowerer.builder) + kl_builder = KernelLaunchIRBuilder( + lowerer.context, lowerer.builder, kernel_dmm + ) # Create a local variable storing a pointer to a DPCTLSyclQueueRef # pointer. - curr_queue = ir_builder.get_queue(exec_queue=parfor_kernel.queue) + queue_ref = kl_builder.get_queue(exec_queue=parfor_kernel.queue) builder = lowerer.builder context = lowerer.context @@ -433,7 +438,7 @@ def copy_final_sum_to_host(self, parfor_kernel): ) args = [ - builder.load(curr_queue), + queue_ref, dest, src, builder.load(item_size), @@ -443,4 +448,4 @@ def copy_final_sum_to_host(self, parfor_kernel): sycl.dpctl_event_wait(builder, event_ref) sycl.dpctl_event_delete(builder, event_ref) - ir_builder.free_queue(ptr_to_sycl_queue_ref=curr_queue) + sycl.dpctl_queue_delete(builder, queue_ref) diff --git a/numba_dpex/core/utils/kernel_launcher.py b/numba_dpex/core/utils/kernel_launcher.py index 416bed3f84..42512ea7ff 100644 --- a/numba_dpex/core/utils/kernel_launcher.py +++ b/numba_dpex/core/utils/kernel_launcher.py @@ -2,8 +2,16 @@ # # SPDX-License-Identifier: Apache-2.0 +"""Module that contains numba style wrapper around sycl kernel submit.""" + +from dataclasses import dataclass + +import dpctl from llvmlite import ir as llvmir +from llvmlite.ir.builder import IRBuilder from numba.core import cgutils, types +from numba.core.cpu import CPUContext +from numba.core.datamodel import DataModelManager from numba_dpex import config, utils from numba_dpex.core.runtime.context import DpexRTContext @@ -11,6 +19,48 @@ from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl from numba_dpex.dpctl_iface._helpers import numba_type_to_dpctl_typenum +MAX_SIZE_OF_SYCL_RANGE = 3 + + +@dataclass +class _KernelLaunchIRArguments: # pylint: disable=too-many-instance-attributes + """List of kernel launch arguments used in sycl.dpctl_queue_submit_range and + sycl.dpctl_queue_submit_ndrange.""" + + sycl_kernel_ref: llvmir.Instruction = None + sycl_queue_ref: llvmir.Instruction = None + arg_list: llvmir.Instruction = None + arg_ty_list: llvmir.Instruction = None + total_kernel_args: llvmir.Instruction = None + global_range: llvmir.Instruction = None + local_range: llvmir.Instruction = None + range_size: llvmir.Instruction = None + dep_events: llvmir.Instruction = None + dep_events_len: llvmir.Instruction = None + + def to_list(self): + """Returns list of arguments in the right order to pass to + sycl.dpctl_queue_submit_range or sycl.dpctl_queue_submit_ndrange.""" + res = [ + self.sycl_kernel_ref, + self.sycl_queue_ref, + self.arg_list, + self.arg_ty_list, + self.total_kernel_args, + self.global_range, + ] + + if self.local_range is not None: + res += [self.local_range] + + res += [ + self.range_size, + self.dep_events, + self.dep_events_len, + ] + + return res + class KernelLaunchIRBuilder: """ @@ -21,7 +71,12 @@ class KernelLaunchIRBuilder: for submitting kernels. The LLVM Values that """ - def __init__(self, context, builder): + def __init__( + self, + context: CPUContext, + builder: IRBuilder, + kernel_dmm: DataModelManager, + ): """Create a KernelLauncher for the specified kernel. Args: @@ -32,6 +87,8 @@ def __init__(self, context, builder): self.context = context self.builder = builder self.rtctx = DpexRTContext(self.context) + self.arguments = _KernelLaunchIRArguments() + self.kernel_dmm = kernel_dmm def _build_nullptr(self): """Builds the LLVM IR to represent a null pointer. @@ -44,7 +101,7 @@ def _build_nullptr(self): zero, utils.get_llvm_type(context=self.context, type=types.voidptr) ) - def _build_array_attr_arg( + def _build_array_attr_arg( # pylint: disable=too-many-arguments self, array_val, array_attr_pos, @@ -68,7 +125,7 @@ def _build_array_attr_arg( ): array_attr = self.builder.load(array_attr) - self.build_arg( + self._build_arg( val=array_attr, ty=array_attr_ty, arg_list=arg_list, @@ -76,7 +133,7 @@ def _build_array_attr_arg( arg_num=arg_num, ) - def _build_unituple_member_arg( + def _build_unituple_member_arg( # pylint: disable=too-many-arguments self, array_val, array_attr_pos, ndims, arg_list, args_ty_list, arg_num ): array_attr = self.builder.gep( @@ -97,7 +154,9 @@ def _build_unituple_member_arg( arg_num=arg_num + ndim, ) - def build_arg(self, val, ty, arg_list, args_ty_list, arg_num): + def _build_arg( + self, val, ty, arg_list, args_ty_list, arg_num + ): # pylint: disable=too-many-arguments """Stores the kernel arguments and the kernel argument types into arrays that will be passed to DPCTLQueue_SubmitRange. @@ -128,7 +187,9 @@ def build_arg(self, val, ty, arg_list, args_ty_list, arg_num): numba_type_to_dpctl_typenum(self.context, ty), kernel_arg_ty_dst ) - def build_complex_arg(self, val, ty, arg_list, args_ty_list, arg_num): + def _build_complex_arg( + self, val, ty, arg_list, args_ty_list, arg_num + ): # pylint: disable=too-many-arguments """Creates a list of LLVM Values for an unpacked complex kernel argument. """ @@ -151,11 +212,10 @@ def build_complex_arg(self, val, ty, arg_list, args_ty_list, arg_num): ) arg_num += 1 - def build_array_arg( + def _build_array_arg( # pylint: disable=too-many-arguments self, array_val, array_data_model, - array_rank, arg_list, args_ty_list, arg_num, @@ -168,7 +228,7 @@ def build_array_arg( """ # Argument 1: Null pointer for the NRT_MemInfo attribute of the array nullptr = self._build_nullptr() - self.build_arg( + self._build_arg( val=nullptr, ty=types.int64, arg_list=arg_list, @@ -178,7 +238,7 @@ def build_array_arg( arg_num += 1 # Argument 2: Null pointer for the Parent attribute of the array nullptr = self._build_nullptr() - self.build_arg( + self._build_arg( val=nullptr, ty=types.int64, arg_list=arg_list, @@ -218,7 +278,7 @@ def build_array_arg( arg_num += 1 # Argument sycl_queue: as the queue pointer is not to be used in a # kernel we always pass in a nullptr - self.build_arg( + self._build_arg( val=nullptr, ty=types.int64, arg_list=arg_list, @@ -249,7 +309,8 @@ def build_array_arg( ) arg_num += stride_member.count - def get_queue(self, exec_queue): + # TODO: remove, not part of the builder + def get_queue(self, exec_queue: dpctl.SyclQueue) -> llvmir.Instruction: """Allocates memory on the stack to store a DPCTLSyclQueueRef. Returns: A LLVM Value storing the pointer to the SYCL queue created @@ -273,20 +334,9 @@ def get_queue(self, exec_queue): ), sycl_queue_val, ) - return sycl_queue_val - - def free_queue(self, ptr_to_sycl_queue_ref): - """ - Frees the ``DPCTLSyclQueueRef`` pointer that was used to launch the - kernels. - - Args: - sycl_queue_val: The SYCL queue pointer to be freed. - """ - qref = self.builder.load(ptr_to_sycl_queue_ref) - sycl.dpctl_queue_delete(self.builder, qref) + return self.builder.load(sycl_queue_val) - def allocate_kernel_arg_array(self, num_kernel_args): + def _allocate_kernel_arg_array(self, num_kernel_args): """Allocates an array to store the LLVM Value for every kernel argument. Args: @@ -298,13 +348,13 @@ def allocate_kernel_arg_array(self, num_kernel_args): """ args_list = cgutils.alloca_once( self.builder, - utils.get_llvm_type(context=self.context, type=types.voidptr), + utils.LLVMTypes.byte_ptr_t, size=self.context.get_constant(types.uintp, num_kernel_args), ) return args_list - def allocate_kernel_arg_ty_array(self, num_kernel_args): + def _allocate_kernel_arg_ty_array(self, num_kernel_args): """Allocates an array to store the LLVM Value for the typenum for every kernel argument. @@ -335,8 +385,6 @@ def _create_sycl_range(self, idx_range): intp_ptr_t = utils.get_llvm_ptr_type(intp_t) num_dim = len(idx_range) - MAX_SIZE_OF_SYCL_RANGE = 3 - # form the global range range_list = cgutils.alloca_once( self.builder, @@ -361,118 +409,112 @@ def _create_sycl_range(self, idx_range): return self.builder.bitcast(range_list, intp_ptr_t) - def submit_kernel( + def set_kernel(self, sycl_kernel_ref: llvmir.Instruction): + """Sets kernel to the argument list.""" + self.arguments.sycl_kernel_ref = sycl_kernel_ref + + def set_queue(self, sycl_queue_ref: llvmir.Instruction): + """Sets queue to the argument list.""" + self.arguments.sycl_queue_ref = sycl_queue_ref + + def set_range( + self, + global_range: list, + local_range: list = None, + ): + """Sets global and local range if provided to the argument list.""" + self.arguments.global_range = self._create_sycl_range(global_range) + if local_range is not None and len(local_range) > 0: + self.arguments.local_range = self._create_sycl_range(local_range) + self.arguments.range_size = self.context.get_constant( + types.uintp, len(global_range) + ) + + def set_arguments( self, - kernel_ref: llvmir.CallInstr, - queue_ref: llvmir.PointerType, - kernel_args: list, ty_kernel_args: list, - global_range_extents: list, - local_range_extents: list, + kernel_args: list, ): + """Sets flattened kernel args, kernel arg types and number of those + arguments to the argument list.""" if config.DEBUG_KERNEL_LAUNCHER: cgutils.printf( self.builder, "DPEX-DEBUG: Populating kernel args and arg type arrays.\n", ) - num_flattened_kernel_args = self.get_num_flattened_kernel_args( + num_flattened_kernel_args = self._get_num_flattened_kernel_args( kernel_argtys=ty_kernel_args, ) # Create LLVM values for the kernel args list and kernel arg types list - args_list = self.allocate_kernel_arg_array(num_flattened_kernel_args) + args_list = self._allocate_kernel_arg_array(num_flattened_kernel_args) - args_ty_list = self.allocate_kernel_arg_ty_array( + args_ty_list = self._allocate_kernel_arg_ty_array( num_flattened_kernel_args ) kernel_args_ptrs = [] for arg in kernel_args: - ptr = self.builder.alloca(arg.type) + with self.builder.goto_entry_block(): + ptr = self.builder.alloca(arg.type) self.builder.store(arg, ptr) kernel_args_ptrs.append(ptr) # Populate the args_list and the args_ty_list LLVM arrays - self.populate_kernel_args_and_args_ty_arrays( + self._populate_kernel_args_and_args_ty_arrays( callargs_ptrs=kernel_args_ptrs, kernel_argtys=ty_kernel_args, args_list=args_list, args_ty_list=args_ty_list, ) - if config.DEBUG_KERNEL_LAUNCHER: - cgutils.printf(self._builder, "DPEX-DEBUG: Submit kernel.\n") - - return self.submit_sycl_kernel( - sycl_kernel_ref=kernel_ref, - sycl_queue_ref=queue_ref, - total_kernel_args=num_flattened_kernel_args, - arg_list=args_list, - arg_ty_list=args_ty_list, - global_range=global_range_extents, - local_range=local_range_extents, - wait_before_return=False, + self.arguments.arg_list = args_list + self.arguments.arg_ty_list = args_ty_list + self.arguments.total_kernel_args = self.context.get_constant( + types.uintp, num_flattened_kernel_args ) - def submit_sycl_kernel( - self, - sycl_kernel_ref, - sycl_queue_ref, - total_kernel_args, - arg_list, - arg_ty_list, - global_range, - local_range=[], - wait_before_return=True, - ) -> llvmir.PointerType(llvmir.IntType(8)): - """ - Submits the kernel to the specified queue, waits by default. - """ - eref = None - gr = self._create_sycl_range(global_range) - args1 = [ - sycl_kernel_ref, - sycl_queue_ref, - arg_list, - arg_ty_list, - self.context.get_constant(types.uintp, total_kernel_args), - gr, - ] - args2 = [ - self.context.get_constant(types.uintp, len(global_range)), - self.builder.bitcast( - utils.create_null_ptr( - builder=self.builder, context=self.context - ), - utils.get_llvm_type(context=self.context, type=types.voidptr), - ), - self.context.get_constant(types.uintp, 0), - ] - args = [] - if len(local_range) == 0: - args = args1 + args2 + def set_dependant_event_list(self, dep_events: list[llvmir.Instruction]): + """Sets dependant events to the argument list.""" + if self.arguments.dep_events is not None: + return + + if len(dep_events) > 0: + # TODO: implement for non zero input + raise NotImplementedError + + self.arguments.dep_events = self.builder.bitcast( + utils.create_null_ptr(builder=self.builder, context=self.context), + utils.get_llvm_type(context=self.context, type=types.voidptr), + ) + self.arguments.dep_events_len = self.context.get_constant( + types.uintp, 0 + ) + + def submit(self) -> llvmir.Instruction: + """Submits kernel by calling sycl.dpctl_queue_submit_range or + sycl.dpctl_queue_submit_ndrange. Must be called after all arguments + set.""" + args = self.arguments.to_list() + + if self.arguments.local_range is None: eref = sycl.dpctl_queue_submit_range(self.builder, *args) else: - lr = self._create_sycl_range(local_range) - args = args1 + [lr] + args2 eref = sycl.dpctl_queue_submit_ndrange(self.builder, *args) - if wait_before_return: - sycl.dpctl_event_wait(self.builder, eref) - sycl.dpctl_event_delete(self.builder, eref) - return None - else: - return eref + return eref - def get_num_flattened_kernel_args( + def _get_num_flattened_kernel_args( self, kernel_argtys: tuple[types.Type, ...], - ): + ) -> int: + """Returns number of flattened arguments based on the numba types. + flattens dpnp arrays and complex values.""" num_flattened_kernel_args = 0 for arg_type in kernel_argtys: if isinstance(arg_type, DpnpNdArray): - datamodel = self.context.data_model_manager.lookup(arg_type) + datamodel = self.kernel_dmm.lookup(arg_type) num_flattened_kernel_args += datamodel.flattened_field_count elif arg_type in [types.complex64, types.complex128]: num_flattened_kernel_args += 2 @@ -481,7 +523,7 @@ def get_num_flattened_kernel_args( return num_flattened_kernel_args - def populate_kernel_args_and_args_ty_arrays( + def _populate_kernel_args_and_args_ty_arrays( self, kernel_argtys, callargs_ptrs, @@ -492,11 +534,10 @@ def populate_kernel_args_and_args_ty_arrays( for arg_num, argtype in enumerate(kernel_argtys): llvm_val = callargs_ptrs[arg_num] if isinstance(argtype, DpnpNdArray): - datamodel = self.context.data_model_manager.lookup(argtype) - self.build_array_arg( + datamodel = self.kernel_dmm.lookup(argtype) + self._build_array_arg( array_val=llvm_val, array_data_model=datamodel, - array_rank=argtype.ndim, arg_list=args_list, args_ty_list=args_ty_list, arg_num=kernel_arg_num, @@ -504,7 +545,7 @@ def populate_kernel_args_and_args_ty_arrays( kernel_arg_num += datamodel.flattened_field_count else: if argtype == types.complex64: - self.build_complex_arg( + self._build_complex_arg( llvm_val, types.float32, args_list, @@ -513,7 +554,7 @@ def populate_kernel_args_and_args_ty_arrays( ) kernel_arg_num += 2 elif argtype == types.complex128: - self.build_complex_arg( + self._build_complex_arg( llvm_val, types.float64, args_list, @@ -522,7 +563,7 @@ def populate_kernel_args_and_args_ty_arrays( ) kernel_arg_num += 2 else: - self.build_arg( + self._build_arg( llvm_val, argtype, args_list, diff --git a/numba_dpex/experimental/launcher.py b/numba_dpex/experimental/launcher.py index 63389fcfb5..9d6ee3213a 100644 --- a/numba_dpex/experimental/launcher.py +++ b/numba_dpex/experimental/launcher.py @@ -19,6 +19,7 @@ from numba_dpex import config, dpjit, utils from numba_dpex.core.exceptions import UnreachableError from numba_dpex.core.runtime.context import DpexRTContext +from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext from numba_dpex.core.types import ( DpctlSyclEvent, DpnpNdArray, @@ -283,7 +284,7 @@ def _submit_kernel( kernel_module: _KernelModule = kernel_dispatcher.get_overload_kcres( kernel_sig ).kernel_device_ir_module - kernel_targetctx = kernel_dispatcher.targetctx + kernel_targetctx: DpexKernelTargetContext = kernel_dispatcher.targetctx def codegen(cgctx, builder, sig, llargs): # llargs[0] is kernel function that we don't need anymore (see above) @@ -317,16 +318,17 @@ def codegen(cgctx, builder, sig, llargs): index_arg=ll_index_space, ) - device_event_ref = kl.KernelLaunchIRBuilder( - kernel_targetctx, builder - ).submit_kernel( - kernel_ref=kernel_ref, - queue_ref=queue_ref, - kernel_args=kernel_args, - ty_kernel_args=ty_kernel_args_tuple, - global_range_extents=ll_range.global_range_extents, - local_range_extents=ll_range.local_range_extents, + kl_builder = kl.KernelLaunchIRBuilder( + cgctx, builder, kernel_targetctx.data_model_manager + ) + kl_builder.set_kernel(kernel_ref) + kl_builder.set_queue(queue_ref) + kl_builder.set_range( + ll_range.global_range_extents, ll_range.local_range_extents ) + kl_builder.set_arguments(ty_kernel_args_tuple, kernel_args) + kl_builder.set_dependant_event_list(dep_events=[]) + device_event_ref = kl_builder.submit() # Clean up sycl.dpctl_kernel_delete(builder, kernel_ref)