Skip to content

Commit

Permalink
Merge pull request #1148 from IntelPython/feature/range_ndrange_type_…
Browse files Browse the repository at this point in the history
…support

Adds Range and NdRange as supported types in numba_dpex.dpjit.
  • Loading branch information
Diptorup Deb authored Oct 2, 2023
2 parents 4334448 + 7e44902 commit cd5332f
Show file tree
Hide file tree
Showing 9 changed files with 685 additions and 0 deletions.
40 changes: 40 additions & 0 deletions numba_dpex/core/datamodel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
DpctlSyclEvent,
DpctlSyclQueue,
DpnpNdArray,
NdRangeType,
RangeType,
USMNdArray,
)

Expand Down Expand Up @@ -195,6 +197,39 @@ def __init__(self, dmm, fe_type):
super(SyclEventModel, self).__init__(dmm, fe_type, members)


class RangeModel(StructModel):
"""The native data model for a
numba_dpex.core.kernel_interface.indexers.Range PyObject.
"""

def __init__(self, dmm, fe_type):
members = [
("ndim", types.int64),
("dim0", types.int64),
("dim1", types.int64),
("dim2", types.int64),
]
super(RangeModel, self).__init__(dmm, fe_type, members)


class NdRangeModel(StructModel):
"""The native data model for a
numba_dpex.core.kernel_interface.indexers.NdRange PyObject.
"""

def __init__(self, dmm, fe_type):
members = [
("ndim", types.int64),
("gdim0", types.int64),
("gdim1", types.int64),
("gdim2", types.int64),
("ldim0", types.int64),
("ldim1", types.int64),
("ldim2", types.int64),
]
super(NdRangeModel, self).__init__(dmm, fe_type, members)


def _init_data_model_manager() -> datamodel.DataModelManager:
"""Initializes a DpexKernelTarget-specific data model manager.
Expand Down Expand Up @@ -249,3 +284,8 @@ def _init_data_model_manager() -> datamodel.DataModelManager:

# Register the DpctlSyclEvent type
register_model(DpctlSyclEvent)(SyclEventModel)
# Register the RangeType type
register_model(RangeType)(RangeModel)

# Register the NdRangeType type
register_model(NdRangeType)(NdRangeModel)
218 changes: 218 additions & 0 deletions numba_dpex/core/kernel_interface/indexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

from collections.abc import Iterable

from llvmlite import ir as llvmir
from numba.core import cgutils, errors, types
from numba.core.datamodel import default_manager
from numba.extending import intrinsic, overload


class Range(tuple):
"""A data structure to encapsulate a single kernel launch parameter.
Expand All @@ -18,6 +23,8 @@ class Range(tuple):
the behavior of `sycl::range`.
"""

UNDEFINED_DIMENSION = -1

def __new__(cls, dim0, dim1=None, dim2=None):
"""Constructs a 1, 2, or 3 dimensional range.
Expand Down Expand Up @@ -74,6 +81,50 @@ def size(self):
else:
return self[0]

@property
def ndim(self) -> int:
"""Returns the rank of a Range object.
Returns:
int: Number of dimensions in the Range object
"""
return len(self)

@property
def dim0(self) -> int:
"""Return the extent of the first dimension for the Range object.
Returns:
int: Extent of first dimension for the Range object
"""
return self[0]

@property
def dim1(self) -> int:
"""Return the extent of the second dimension for the Range object.
Returns:
int: Extent of second dimension for the Range object or -1 for 1D
Range
"""
try:
return self[1]
except IndexError:
return Range.UNDEFINED_DIMENSION

@property
def dim2(self) -> int:
"""Return the extent of the second dimension for the Range object.
Returns:
int: Extent of second dimension for the Range object or -1 for 1D or
2D Range
"""
try:
return self[2]
except IndexError:
return Range.UNDEFINED_DIMENSION


class NdRange:
"""A class to encapsulate all kernel launch parameters.
Expand Down Expand Up @@ -169,3 +220,170 @@ def __repr__(self):
str: str representation for NdRange class.
"""
return self.__str__()

def __eq__(self, other):
if isinstance(other, NdRange):
return (
self.global_range == other.global_range
and self.local_range == other.local_range
)
else:
return False


@intrinsic
def _intrin_range_alloc(typingctx, ty_dim0, ty_dim1, ty_dim2, ty_range):
ty_retty = ty_range.instance_type
sig = ty_retty(
ty_dim0,
ty_dim1,
ty_dim2,
ty_range,
)

def codegen(context, builder, sig, args):
typ = sig.return_type
dim0, dim1, dim2, _ = args
range_struct = cgutils.create_struct_proxy(typ)(context, builder)
range_struct.dim0 = dim0

if not isinstance(sig.args[1], types.NoneType):
range_struct.dim1 = dim1
else:
range_struct.dim1 = llvmir.Constant(
llvmir.types.IntType(64), Range.UNDEFINED_DIMENSION
)

if not isinstance(sig.args[2], types.NoneType):
range_struct.dim2 = dim2
else:
range_struct.dim2 = llvmir.Constant(
llvmir.types.IntType(64), Range.UNDEFINED_DIMENSION
)

range_struct.ndim = llvmir.Constant(llvmir.types.IntType(64), typ.ndim)

return range_struct._getvalue()

return sig, codegen


@intrinsic
def _intrin_ndrange_alloc(
typingctx, ty_global_range, ty_local_range, ty_ndrange
):
ty_retty = ty_ndrange.instance_type
sig = ty_retty(
ty_global_range,
ty_local_range,
ty_ndrange,
)
range_datamodel = default_manager.lookup(ty_global_range)

def codegen(context, builder, sig, args):
typ = sig.return_type

global_range, local_range, _ = args
ndrange_struct = cgutils.create_struct_proxy(typ)(context, builder)
ndrange_struct.ndim = llvmir.Constant(
llvmir.types.IntType(64), typ.ndim
)
ndrange_struct.gdim0 = builder.extract_value(
global_range,
range_datamodel.get_field_position("dim0"),
)
ndrange_struct.gdim1 = builder.extract_value(
global_range,
range_datamodel.get_field_position("dim1"),
)
ndrange_struct.gdim2 = builder.extract_value(
global_range,
range_datamodel.get_field_position("dim2"),
)
ndrange_struct.ldim0 = builder.extract_value(
local_range,
range_datamodel.get_field_position("dim0"),
)
ndrange_struct.ldim1 = builder.extract_value(
local_range,
range_datamodel.get_field_position("dim1"),
)
ndrange_struct.ldim2 = builder.extract_value(
local_range,
range_datamodel.get_field_position("dim2"),
)

return ndrange_struct._getvalue()

return sig, codegen


@overload(Range)
def _ol_range_init(dim0, dim1=None, dim2=None):
"""Numba overload of the Range constructor to make it usable inside an
njit and dpjit decorated function.
"""
from numba_dpex.core.types import RangeType

ndims = 1
ty_optional_dims = (dim1, dim2)

# A Range should at least have the 0th dimension populated
if not isinstance(dim0, types.Integer):
raise errors.TypingError(
"Expected a Range's dimension should to be an Integer value, but "
"encountered " + dim0.name
)

for ty_dim in ty_optional_dims:
if isinstance(ty_dim, types.Integer):
ndims += 1
elif ty_dim is not None:
raise errors.TypingError(
"Expected a Range's dimension to be an Integer value, "
f"but {type(ty_dim)} was provided."
)

ret_ty = RangeType(ndims)

def impl(dim0, dim1=None, dim2=None):
return _intrin_range_alloc(dim0, dim1, dim2, ret_ty)

return impl


@overload(NdRange)
def _ol_ndrange_init(global_range, local_range):
"""Numba overload of the NdRange constructor to make it usable inside an
njit and dpjit decorated function.
"""
from numba_dpex.core.exceptions import UnmatchedNumberOfRangeDimsError
from numba_dpex.core.types import NdRangeType, RangeType

if not isinstance(global_range, RangeType):
raise errors.TypingError(
"Only global range values specified as a Range are "
"supported inside dpjit"
)

if not isinstance(local_range, RangeType):
raise errors.TypingError(
"Only local range values specified as a Range are "
"supported inside dpjit"
)

if not global_range.ndim == local_range.ndim:
raise UnmatchedNumberOfRangeDimsError(
kernel_name="",
global_ndims=global_range.ndim,
local_ndims=local_range.ndim,
)

ret_ty = NdRangeType(global_range.ndim)

def impl(global_range, local_range):
return _intrin_ndrange_alloc(global_range, local_range, ret_ty)

return impl
3 changes: 3 additions & 0 deletions numba_dpex/core/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
uint64,
void,
)
from .range_types import NdRangeType, RangeType
from .usm_ndarray_type import USMNdArray

usm_ndarray = USMNdArray
Expand All @@ -35,6 +36,8 @@
"DpctlSyclQueue",
"DpctlSyclEvent",
"DpnpNdArray",
"RangeType",
"NdRangeType",
"USMNdArray",
"none",
"boolean",
Expand Down
Loading

0 comments on commit cd5332f

Please sign in to comment.