Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/dpctl slices #1425

Merged
merged 2 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions numba_dpex/core/targets/dpjit_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from numba.core.target_extension import CPU, target_registry

from numba_dpex.core.datamodel.models import _init_dpjit_data_model_manager
from numba_dpex.dpctl_iface import dpctlimpl
from numba_dpex.dpnp_iface import dpnp_ufunc_db


Expand Down Expand Up @@ -69,6 +70,7 @@ def load_additional_registries(self):
Load dpjit-specific registries.
"""
self.install_registry(dpex_function_registry)
self.install_registry(dpctlimpl.registry)

# loading CPU specific registries
super().load_additional_registries()
Expand Down
2 changes: 2 additions & 0 deletions numba_dpex/dpctl_iface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@
The ``dpctl_iface`` module implements Numba's interface to the libsyclinterface
library that provides C bindings to DPC++'s SYCL runtime API.
"""

from . import arrayobj
148 changes: 148 additions & 0 deletions numba_dpex/dpctl_iface/arrayobj.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# SPDX-FileCopyrightText: 2020 - 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

import operator

from llvmlite.ir import IRBuilder
from numba.core import cgutils, errors, imputils, types
from numba.core.imputils import impl_ret_borrowed
from numba.extending import intrinsic, overload_attribute
from numba.np.arrayobj import _getitem_array_generic as np_getitem_array_generic
from numba.np.arrayobj import make_array

from numba_dpex.core.types import DpnpNdArray, USMNdArray
from numba_dpex.core.types.dpctl_types import DpctlSyclQueue
from numba_dpex.kernel_api_impl.spirv.arrayobj import (
_getitem_array_generic as kernel_getitem_array_generic,
)
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext

from .dpctlimpl import lower_builtin

# Name of the Numba target that the intrinsic and the attribute overload in
# this module are registered for. It is spelled out as a literal because
# importing the name would create a circular import.
DPEX_TARGET_NAME = "dpex"

# =========================================================================
# Indexing and ``sycl_queue`` attribute support for USM arrays
# =========================================================================


# TODO: target specific
@lower_builtin(operator.getitem, USMNdArray, types.Integer)
@lower_builtin(operator.getitem, USMNdArray, types.SliceType)
def getitem_arraynd_intp(context, builder, sig, args):
    """Lower ``array[index]`` for a USM array and an integer or slice index.

    The indexing itself is delegated to a ``_getitem_array_generic``
    helper: the SPIR-V kernel flavour when ``context`` is a
    ``SPIRVTargetContext``, Numba's NumPy implementation otherwise.

    Outside of a kernel, when the result is itself a ``USMNdArray``, the
    ``sycl_queue`` member of the input array's LLVM struct is copied into
    the result, so the returned array carries the same queue reference as
    the array it was taken from.

    Args:
        context: Target context the call is being lowered for.
        builder: LLVM IR builder used to emit instructions.
        sig: Signature of the call; ``sig.args`` holds the array type
            followed by the index type.
        args: LLVM values of the array and of the index, in that order.

    Returns:
        The LLVM value holding the result of the indexing operation.
    """
    source_ty, index_ty = sig.args
    source_val, index_val = args
    result_ty = sig.return_type

    assert source_ty.ndim >= 1

    # Kernel code needs the SPIR-V specific indexing helper.
    inside_kernel = isinstance(context, SPIRVTargetContext)
    if inside_kernel:
        generic_getitem = kernel_getitem_array_generic
    else:
        generic_getitem = np_getitem_array_generic

    source_struct = make_array(source_ty)(context, builder, source_val)
    indexed = generic_getitem(
        context,
        builder,
        result_ty,
        source_ty,
        source_struct,
        (index_ty,),
        (index_val,),
    )
    result = impl_ret_borrowed(context, builder, result_ty, indexed)

    # Only array results produced outside of a kernel get the queue copied.
    if inside_kernel or not isinstance(result_ty, USMNdArray):
        return result

    # The generic helper does not fill in the ``sycl_queue`` member, so it
    # is taken from the raw input array value and written into the result.
    source_model = context.data_model_manager.lookup(source_ty)
    queue_pos = source_model.get_field_position("sycl_queue")
    queue_ref = builder.extract_value(source_val, queue_pos)
    return builder.insert_value(result, queue_ref, queue_pos)


@intrinsic(target=DPEX_TARGET_NAME)
def ol_usm_nd_array_sycl_queue(
    ty_context,
    ty_dpnp_nd_array: DpnpNdArray,
):
    """Intrinsic that builds a queue value from an array's ``sycl_queue``.

    Args:
        ty_context: Typing context supplied by Numba (unused here).
        ty_dpnp_nd_array: Numba type of the array argument; its ``queue``
            attribute provides the return type.

    Returns:
        tuple: The signature ``queue_type(array_type)`` and the code
        generation function implementing it.

    Raises:
        errors.TypingError: If the argument type is not a ``DpnpNdArray``.
    """
    # NOTE(review): the function name and the ``sycl_queue`` attribute
    # overload in this module refer to USMNdArray, while only DpnpNdArray
    # is accepted here -- confirm that this restriction is intended.
    if not isinstance(ty_dpnp_nd_array, DpnpNdArray):
        raise errors.TypingError("Argument must be DpnpNdArray")

    ty_queue: DpctlSyclQueue = ty_dpnp_nd_array.queue

    def codegen(context, builder: IRBuilder, sig, args: list):
        array_struct = cgutils.create_struct_proxy(ty_dpnp_nd_array)(
            context, builder, value=args[0]
        )
        array_queue_ref = array_struct.sycl_queue

        queue_struct = cgutils.create_struct_proxy(ty_queue)(context, builder)
        queue_struct.queue_ref = array_queue_ref
        # The queue value does not get a meminfo of its own; it shares the
        # one that belongs to the array.
        queue_struct.meminfo = array_struct.meminfo

        # Lifetime note: because the meminfo is shared, the whole array is
        # kept alive for as long as the returned queue value is in use.
        # This is considered acceptable: wherever the queue is passed as an
        # argument the callee makes its own copy, so the reference is never
        # stolen. The attribute is treated as part of its owner object and
        # its lifetime is tied to that object.
        #
        # The alternative would be to copy the queue reference and manage
        # the copy through a newly allocated meminfo, but that extra
        # allocation can negatively affect performance.
        #
        # TODO: store a full queue struct instead of a bare queue reference
        # as the attribute of DpnpNdArray.

        # Returning the value as borrowed increfs the shared meminfo, which
        # prevents the parent array (and with it the queue being used) from
        # being destroyed while the result is alive.
        return imputils.impl_ret_borrowed(
            context, builder, ty_queue, queue_struct._getvalue()
        )

    return ty_queue(ty_dpnp_nd_array), codegen


@overload_attribute(USMNdArray, "sycl_queue", target=DPEX_TARGET_NAME)
def dpnp_nd_array_sycl_queue(arr):
    """Implement the ``sycl_queue`` attribute for ``USMNdArray`` typed arrays.

    Args:
        arr: Numba type of the array whose ``sycl_queue`` attribute is
            being accessed.

    Returns:
        function: Implementation that passes the array to the
        ``ol_usm_nd_array_sycl_queue`` intrinsic and returns its result.
    """

    def sycl_queue_impl(arr):
        return ol_usm_nd_array_sycl_queue(arr)

    return sycl_queue_impl
15 changes: 15 additions & 0 deletions numba_dpex/dpctl_iface/dpctlimpl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# SPDX-FileCopyrightText: 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

from numba.core.imputils import Registry

# Lowering registry for the dpctl interface. Implementations registered
# through the aliases below are collected here; a target context makes them
# available by installing ``registry``.
registry = Registry("dpctlimpl")

# Module-level shortcuts to the registry's registration methods, so other
# modules can import them directly (for example to use
# ``@lower_builtin(...)`` as a decorator).
lower_builtin = registry.lower
lower_getattr = registry.lower_getattr
lower_getattr_generic = registry.lower_getattr_generic
lower_setattr = registry.lower_setattr
lower_setattr_generic = registry.lower_setattr_generic
lower_cast = registry.lower_cast
lower_constant = registry.lower_constant
61 changes: 1 addition & 60 deletions numba_dpex/dpnp_iface/_intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@

from dpctl import get_device_cached_queue
from llvmlite import ir as llvmir
from llvmlite.ir import Constant, IRBuilder
from llvmlite.ir import Constant
from llvmlite.ir.types import DoubleType, FloatType
from numba import types
from numba.core import cgutils
from numba.core import config as numba_config
from numba.core import errors, imputils
from numba.core.typing import signature
from numba.extending import intrinsic, overload_classmethod
from numba.np.arrayobj import (
Expand Down Expand Up @@ -1081,61 +1080,3 @@ def codegen(context, builder, sig, args):
return ary._getvalue()

return signature, codegen


@intrinsic(target=DPEX_TARGET_NAME)
def ol_dpnp_nd_array_sycl_queue(
    ty_context,
    ty_dpnp_nd_array: DpnpNdArray,
):
    """Intrinsic that builds a queue value from a dpnp array's ``sycl_queue``.

    Args:
        ty_context: Typing context supplied by Numba (unused here).
        ty_dpnp_nd_array: Numba type of the array argument; its ``queue``
            attribute provides the return type.

    Returns:
        tuple: The signature ``queue_type(array_type)`` and the code
        generation function implementing it.

    Raises:
        errors.TypingError: If the argument type is not a ``DpnpNdArray``.
    """
    if not isinstance(ty_dpnp_nd_array, DpnpNdArray):
        raise errors.TypingError("Argument must be DpnpNdArray")

    ty_queue: DpctlSyclQueue = ty_dpnp_nd_array.queue

    def codegen(context, builder: IRBuilder, sig, args: list):
        make_array_struct = cgutils.create_struct_proxy(ty_dpnp_nd_array)
        array_struct = make_array_struct(context, builder, value=args[0])
        array_queue_ref = array_struct.sycl_queue

        make_queue_struct = cgutils.create_struct_proxy(ty_queue)
        queue_struct = make_queue_struct(context, builder)
        queue_struct.queue_ref = array_queue_ref
        # The array's meminfo is reused; no separate one is allocated for
        # the queue value.
        queue_struct.meminfo = array_struct.meminfo

        # Lifetime note: sharing the meminfo keeps the whole array alive for
        # as long as the returned queue value is in use. This is considered
        # acceptable because a callee that receives the queue as an argument
        # makes its own copy and therefore never steals the reference; the
        # attribute is treated as part of its owner object.
        #
        # Copying the queue reference and managing the copy with a newly
        # allocated meminfo would avoid this, at the price of an additional
        # allocation that can negatively affect performance.
        #
        # TODO: store a full queue struct instead of a bare queue reference
        # as the attribute of DpnpNdArray.
        queue_value = queue_struct._getvalue()

        # Returning the value as borrowed increfs the shared meminfo, which
        # prevents the parent array (and with it the queue being used) from
        # being destroyed while the result is alive.
        return imputils.impl_ret_borrowed(
            context, builder, ty_queue, queue_value
        )

    return ty_queue(ty_dpnp_nd_array), codegen
74 changes: 1 addition & 73 deletions numba_dpex/dpnp_iface/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,16 @@
#
# SPDX-License-Identifier: Apache-2.0

import operator

import dpnp
from numba import errors, types
from numba.core.imputils import impl_ret_borrowed, lower_builtin
from numba.core.types import scalars
from numba.core.types.containers import UniTuple
from numba.core.typing.npydecl import parse_dtype as _ty_parse_dtype
from numba.core.typing.npydecl import parse_shape as _ty_parse_shape
from numba.extending import overload, overload_attribute
from numba.np.arrayobj import _getitem_array_generic as np_getitem_array_generic
from numba.np.arrayobj import make_array
from numba.extending import overload
from numba.np.numpy_support import is_nonelike

from numba_dpex.core.types import DpnpNdArray
from numba_dpex.kernel_api_impl.spirv.arrayobj import (
_getitem_array_generic as kernel_getitem_array_generic,
)
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext

from ._intrinsic import (
impl_dpnp_empty,
Expand All @@ -31,7 +22,6 @@
impl_dpnp_ones_like,
impl_dpnp_zeros,
impl_dpnp_zeros_like,
ol_dpnp_nd_array_sycl_queue,
)

# can't import name because of the circular import
Expand Down Expand Up @@ -1067,65 +1057,3 @@ def impl(
"Cannot parse input types to "
+ f"function dpnp.full_like({x1}, {fill_value}, {dtype}, ...)."
)


# TODO: target specific
@lower_builtin(operator.getitem, DpnpNdArray, types.Integer)
@lower_builtin(operator.getitem, DpnpNdArray, types.SliceType)
def getitem_arraynd_intp(context, builder, sig, args):
    """Lower ``array[index]`` for a dpnp array and an integer or slice index.

    The indexing itself is delegated to a ``_getitem_array_generic``
    helper: the SPIR-V kernel flavour when ``context`` is a
    ``SPIRVTargetContext``, Numba's NumPy implementation otherwise.

    Outside of a kernel, when the result is itself a ``DpnpNdArray``, the
    ``sycl_queue`` member of the input array's LLVM struct is copied into
    the result, so the returned array carries the same queue reference as
    the array it was taken from.

    Args:
        context: Target context the call is being lowered for.
        builder: LLVM IR builder used to emit instructions.
        sig: Signature of the call; ``sig.args`` holds the array type
            followed by the index type.
        args: LLVM values of the array and of the index, in that order.

    Returns:
        The LLVM value holding the result of the indexing operation.
    """
    in_kernel = isinstance(context, SPIRVTargetContext)
    # Kernel code needs the SPIR-V specific indexing helper.
    index_array = (
        kernel_getitem_array_generic if in_kernel else np_getitem_array_generic
    )

    input_ty, key_ty = sig.args
    input_val, key_val = args
    out_ty = sig.return_type

    assert input_ty.ndim >= 1
    input_struct = make_array(input_ty)(context, builder, input_val)

    raw_result = index_array(
        context, builder, out_ty, input_ty, input_struct, (key_ty,), (key_val,)
    )
    out_val = impl_ret_borrowed(context, builder, out_ty, raw_result)

    # Only array results produced outside of a kernel get the queue copied.
    if in_kernel or not isinstance(out_ty, DpnpNdArray):
        return out_val

    # The generic helper does not fill in the ``sycl_queue`` member, so it
    # is taken from the raw input array value and written into the result.
    queue_index = context.data_model_manager.lookup(
        input_ty
    ).get_field_position("sycl_queue")
    input_queue = builder.extract_value(input_val, queue_index)
    return builder.insert_value(out_val, input_queue, queue_index)


@overload_attribute(DpnpNdArray, "sycl_queue", target=DPEX_TARGET_NAME)
def dpnp_nd_array_sycl_queue(arr):
    """Implement the ``sycl_queue`` attribute for ``DpnpNdArray`` typed arrays.

    Args:
        arr (numba_dpex.core.types.DpnpNdArray): Numba type of the array
            whose ``sycl_queue`` attribute is being accessed.

    Returns:
        function: Implementation that passes the array to the
        ``ol_dpnp_nd_array_sycl_queue`` intrinsic and returns its result.
    """

    def sycl_queue_impl(arr):
        return ol_dpnp_nd_array_sycl_queue(arr)

    return sycl_queue_impl
2 changes: 2 additions & 0 deletions numba_dpex/kernel_api_impl/spirv/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,13 +293,15 @@ def load_additional_registries(self):
"""
# pylint: disable=import-outside-toplevel
from numba_dpex import printimpl
from numba_dpex.dpctl_iface import dpctlimpl
from numba_dpex.dpnp_iface import dpnpimpl
from numba_dpex.ocl import mathimpl, oclimpl

self.insert_func_defn(oclimpl.registry.functions)
self.insert_func_defn(mathimpl.registry.functions)
self.insert_func_defn(dpnpimpl.registry.functions)
self.install_registry(printimpl.registry)
self.install_registry(dpctlimpl.registry)
self.install_registry(spirv_registry)
# Replace dpnp math functions with their OpenCL versions.
self.replace_dpnp_ufunc_with_ocl_intrinsics()
Expand Down
Loading
Loading