Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds Range and NdRange as supported types in numba_dpex.dpjit. #1148

Merged
merged 6 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions numba_dpex/core/datamodel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
DpctlSyclEvent,
DpctlSyclQueue,
DpnpNdArray,
NdRangeType,
RangeType,
USMNdArray,
)

Expand Down Expand Up @@ -195,6 +197,39 @@ def __init__(self, dmm, fe_type):
super(SyclEventModel, self).__init__(dmm, fe_type, members)


class RangeModel(StructModel):
"""The native data model for a
numba_dpex.core.kernel_interface.indexers.Range PyObject.
"""

def __init__(self, dmm, fe_type):
members = [
("ndim", types.int64),
("dim0", types.int64),
("dim1", types.int64),
("dim2", types.int64),
]
super(RangeModel, self).__init__(dmm, fe_type, members)


class NdRangeModel(StructModel):
"""The native data model for a
numba_dpex.core.kernel_interface.indexers.NdRange PyObject.
"""

def __init__(self, dmm, fe_type):
members = [
("ndim", types.int64),
("gdim0", types.int64),
("gdim1", types.int64),
("gdim2", types.int64),
("ldim0", types.int64),
("ldim1", types.int64),
("ldim2", types.int64),
]
super(NdRangeModel, self).__init__(dmm, fe_type, members)


def _init_data_model_manager() -> datamodel.DataModelManager:
"""Initializes a DpexKernelTarget-specific data model manager.

Expand Down Expand Up @@ -249,3 +284,8 @@ def _init_data_model_manager() -> datamodel.DataModelManager:

# Register the DpctlSyclEvent type
register_model(DpctlSyclEvent)(SyclEventModel)
# Register the RangeType type
register_model(RangeType)(RangeModel)
ZzEeKkAa marked this conversation as resolved.
Show resolved Hide resolved

# Register the NdRangeType type
register_model(NdRangeType)(NdRangeModel)
218 changes: 218 additions & 0 deletions numba_dpex/core/kernel_interface/indexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

from collections.abc import Iterable

from llvmlite import ir as llvmir
from numba.core import cgutils, errors, types
from numba.core.datamodel import default_manager
from numba.extending import intrinsic, overload


class Range(tuple):
"""A data structure to encapsulate a single kernel launch parameter.
Expand All @@ -18,6 +23,8 @@ class Range(tuple):
the behavior of `sycl::range`.
"""

UNDEFINED_DIMENSION = -1

def __new__(cls, dim0, dim1=None, dim2=None):
"""Constructs a 1, 2, or 3 dimensional range.

Expand Down Expand Up @@ -74,6 +81,50 @@ def size(self):
else:
return self[0]

@property
def ndim(self) -> int:
"""Returns the rank of a Range object.

Returns:
int: Number of dimensions in the Range object
"""
return len(self)

@property
def dim0(self) -> int:
"""Return the extent of the first dimension for the Range object.

Returns:
int: Extent of first dimension for the Range object
"""
return self[0]

@property
def dim1(self) -> int:
"""Return the extent of the second dimension for the Range object.

Returns:
int: Extent of second dimension for the Range object or -1 for 1D
Range
"""
try:
return self[1]
except IndexError:
return Range.UNDEFINED_DIMENSION

@property
def dim2(self) -> int:
"""Return the extent of the second dimension for the Range object.

Returns:
int: Extent of second dimension for the Range object or -1 for 1D or
2D Range
"""
try:
return self[2]
except IndexError:
return Range.UNDEFINED_DIMENSION


class NdRange:
"""A class to encapsulate all kernel launch parameters.
Expand Down Expand Up @@ -169,3 +220,170 @@ def __repr__(self):
str: str representation for NdRange class.
"""
return self.__str__()

def __eq__(self, other):
ZzEeKkAa marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(other, NdRange):
return (
self.global_range == other.global_range
and self.local_range == other.local_range
)
else:
return False


@intrinsic
def _intrin_range_alloc(typingctx, ty_dim0, ty_dim1, ty_dim2, ty_range):
ty_retty = ty_range.instance_type
sig = ty_retty(
ty_dim0,
ty_dim1,
ty_dim2,
ty_range,
)

def codegen(context, builder, sig, args):
typ = sig.return_type
dim0, dim1, dim2, _ = args
range_struct = cgutils.create_struct_proxy(typ)(context, builder)
range_struct.dim0 = dim0

if not isinstance(sig.args[1], types.NoneType):
range_struct.dim1 = dim1
else:
range_struct.dim1 = llvmir.Constant(
llvmir.types.IntType(64), Range.UNDEFINED_DIMENSION
)

if not isinstance(sig.args[2], types.NoneType):
range_struct.dim2 = dim2
else:
range_struct.dim2 = llvmir.Constant(
llvmir.types.IntType(64), Range.UNDEFINED_DIMENSION
)

range_struct.ndim = llvmir.Constant(llvmir.types.IntType(64), typ.ndim)

return range_struct._getvalue()

return sig, codegen


@intrinsic
def _intrin_ndrange_alloc(
typingctx, ty_global_range, ty_local_range, ty_ndrange
):
ty_retty = ty_ndrange.instance_type
sig = ty_retty(
ty_global_range,
ty_local_range,
ty_ndrange,
)
range_datamodel = default_manager.lookup(ty_global_range)

def codegen(context, builder, sig, args):
typ = sig.return_type

global_range, local_range, _ = args
ndrange_struct = cgutils.create_struct_proxy(typ)(context, builder)
ndrange_struct.ndim = llvmir.Constant(
llvmir.types.IntType(64), typ.ndim
)
ndrange_struct.gdim0 = builder.extract_value(
global_range,
range_datamodel.get_field_position("dim0"),
)
ndrange_struct.gdim1 = builder.extract_value(
global_range,
range_datamodel.get_field_position("dim1"),
)
ndrange_struct.gdim2 = builder.extract_value(
global_range,
range_datamodel.get_field_position("dim2"),
)
ndrange_struct.ldim0 = builder.extract_value(
local_range,
range_datamodel.get_field_position("dim0"),
)
ndrange_struct.ldim1 = builder.extract_value(
local_range,
range_datamodel.get_field_position("dim1"),
)
ndrange_struct.ldim2 = builder.extract_value(
local_range,
range_datamodel.get_field_position("dim2"),
)

return ndrange_struct._getvalue()

return sig, codegen


@overload(Range)
def _ol_range_init(dim0, dim1=None, dim2=None):
"""Numba overload of the Range constructor to make it usable inside an
njit and dpjit decorated function.

"""
from numba_dpex.core.types import RangeType

ndims = 1
ty_optional_dims = (dim1, dim2)

# A Range should at least have the 0th dimension populated
if not isinstance(dim0, types.Integer):
raise errors.TypingError(
"Expected a Range's dimension should to be an Integer value, but "
"encountered " + dim0.name
)

for ty_dim in ty_optional_dims:
if isinstance(ty_dim, types.Integer):
ndims += 1
elif ty_dim is not None:
raise errors.TypingError(
"Expected a Range's dimension to be an Integer value, "
f"but {type(ty_dim)} was provided."
)

ret_ty = RangeType(ndims)

def impl(dim0, dim1=None, dim2=None):
return _intrin_range_alloc(dim0, dim1, dim2, ret_ty)

return impl


@overload(NdRange)
def _ol_ndrange_init(global_range, local_range):
"""Numba overload of the NdRange constructor to make it usable inside an
njit and dpjit decorated function.

"""
from numba_dpex.core.exceptions import UnmatchedNumberOfRangeDimsError
from numba_dpex.core.types import NdRangeType, RangeType

if not isinstance(global_range, RangeType):
raise errors.TypingError(
"Only global range values specified as a Range are "
"supported inside dpjit"
)

if not isinstance(local_range, RangeType):
raise errors.TypingError(
"Only local range values specified as a Range are "
"supported inside dpjit"
)

if not global_range.ndim == local_range.ndim:
raise UnmatchedNumberOfRangeDimsError(
kernel_name="",
global_ndims=global_range.ndim,
local_ndims=local_range.ndim,
)

ret_ty = NdRangeType(global_range.ndim)

def impl(global_range, local_range):
return _intrin_ndrange_alloc(global_range, local_range, ret_ty)

return impl
3 changes: 3 additions & 0 deletions numba_dpex/core/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
uint64,
void,
)
from .range_types import NdRangeType, RangeType
from .usm_ndarray_type import USMNdArray

usm_ndarray = USMNdArray
Expand All @@ -35,6 +36,8 @@
"DpctlSyclQueue",
"DpctlSyclEvent",
"DpnpNdArray",
"RangeType",
"NdRangeType",
"USMNdArray",
"none",
"boolean",
Expand Down
Loading
Loading