Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removes the numpy_usm_shared module from numba_dpex. #841

Merged
merged 1 commit into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,4 @@ exclude =
.git,
__pycache__,
_version.py,
numpy_usm_shared.py,
lowerer.py,
110 changes: 45 additions & 65 deletions numba_dpex/dpctl_iface/kernel_launch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from numba.core import cgutils, types
from numba.core.ir_utils import legalize_names

from numba_dpex import numpy_usm_shared as nus
from numba_dpex import utils
from numba_dpex.dpctl_iface import DpctlCAPIFnBuilder
from numba_dpex.dpctl_iface._helpers import numba_type_to_dpctl_typenum
Expand Down Expand Up @@ -251,79 +250,60 @@ def process_kernel_arg(
context=self.context, type=types.voidptr
)

if isinstance(arg_type, nus.UsmSharedArrayType):
self._form_kernel_arg_and_arg_ty(
malloc_fn = DpctlCAPIFnBuilder.get_dpctl_malloc_shared(
builder=self.builder, context=self.context
)
memcpy_fn = DpctlCAPIFnBuilder.get_dpctl_queue_memcpy(
builder=self.builder, context=self.context
)
event_del_fn = DpctlCAPIFnBuilder.get_dpctl_event_delete(
builder=self.builder, context=self.context
)
event_wait_fn = DpctlCAPIFnBuilder.get_dpctl_event_wait(
builder=self.builder, context=self.context
)

# Not known to be USM so we need to copy to USM.
buffer_name = "buffer_ptr" + str(self.cur_arg)
# Create void * to hold new USM buffer.
buffer_ptr = cgutils.alloca_once(
self.builder,
utils.get_llvm_type(context=self.context, type=types.voidptr),
name=buffer_name,
)
# Setup the args to the USM allocator, size and SYCL queue.
args = [
self.builder.load(total_size),
self.builder.load(sycl_queue_val),
]
# Call USM shared allocator and store in buffer_ptr.
self.builder.store(self.builder.call(malloc_fn, args), buffer_ptr)

if legal_names[var] in modified_arrays:
self.write_buffs.append((buffer_ptr, total_size, data_member))
else:
self.read_only_buffs.append(
(buffer_ptr, total_size, data_member)
)

# We really need to detect when an array needs to be copied over
if index < self.num_inputs:
args = [
self.builder.load(sycl_queue_val),
self.builder.load(buffer_ptr),
self.builder.bitcast(
self.builder.load(data_member),
utils.get_llvm_type(
context=self.context, type=types.voidptr
),
),
ty,
)
else:
malloc_fn = DpctlCAPIFnBuilder.get_dpctl_malloc_shared(
builder=self.builder, context=self.context
)
memcpy_fn = DpctlCAPIFnBuilder.get_dpctl_queue_memcpy(
builder=self.builder, context=self.context
)
event_del_fn = DpctlCAPIFnBuilder.get_dpctl_event_delete(
builder=self.builder, context=self.context
)
event_wait_fn = DpctlCAPIFnBuilder.get_dpctl_event_wait(
builder=self.builder, context=self.context
)

# Not known to be USM so we need to copy to USM.
buffer_name = "buffer_ptr" + str(self.cur_arg)
# Create void * to hold new USM buffer.
buffer_ptr = cgutils.alloca_once(
self.builder,
utils.get_llvm_type(
context=self.context, type=types.voidptr
),
name=buffer_name,
)
# Setup the args to the USM allocator, size and SYCL queue.
args = [
self.builder.load(total_size),
self.builder.load(sycl_queue_val),
]
# Call USM shared allocator and store in buffer_ptr.
self.builder.store(
self.builder.call(malloc_fn, args), buffer_ptr
)

if legal_names[var] in modified_arrays:
self.write_buffs.append(
(buffer_ptr, total_size, data_member)
)
else:
self.read_only_buffs.append(
(buffer_ptr, total_size, data_member)
)

# We really need to detect when an array needs to be copied over
if index < self.num_inputs:
args = [
self.builder.load(sycl_queue_val),
self.builder.load(buffer_ptr),
self.builder.bitcast(
self.builder.load(data_member),
utils.get_llvm_type(
context=self.context, type=types.voidptr
),
),
self.builder.load(total_size),
]
event_ref = self.builder.call(memcpy_fn, args)
self.builder.call(event_wait_fn, [event_ref])
self.builder.call(event_del_fn, [event_ref])
event_ref = self.builder.call(memcpy_fn, args)
self.builder.call(event_wait_fn, [event_ref])
self.builder.call(event_del_fn, [event_ref])

self._form_kernel_arg_and_arg_ty(
self.builder.load(buffer_ptr), ty
)
self._form_kernel_arg_and_arg_ty(self.builder.load(buffer_ptr), ty)

# Handle shape
shape_member = self.builder.gep(
Expand Down
Loading