Skip to content

Commit

Permalink
Address review comments.
Browse files Browse the repository at this point in the history
    - A static variable Range.UNDEFINED_DIMENSION is now used
      instead of the magic -1 value to indicate that the dimension
      extent is undefined.
    - An ndim_obj need not be created when boxing a Range.
  • Loading branch information
Diptorup Deb committed Sep 30, 2023
1 parent 2929350 commit 7e44902
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 20 deletions.
29 changes: 19 additions & 10 deletions numba_dpex/core/kernel_interface/indexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class Range(tuple):
the behavior of `sycl::range`.
"""

UNDEFINED_DIMENSION = -1

def __new__(cls, dim0, dim1=None, dim2=None):
"""Constructs a 1, 2, or 3 dimensional range.
Expand Down Expand Up @@ -107,8 +109,8 @@ def dim1(self) -> int:
"""
try:
return self[1]
except:
return -1
except IndexError:
return Range.UNDEFINED_DIMENSION

@property
def dim2(self) -> int:
Expand All @@ -120,8 +122,8 @@ def dim2(self) -> int:
"""
try:
return self[2]
except:
return -1
except IndexError:
return Range.UNDEFINED_DIMENSION


class NdRange:
Expand Down Expand Up @@ -220,10 +222,13 @@ def __repr__(self):
return self.__str__()

def __eq__(self, other):
return (
self.global_range == other.global_range
and self.local_range == other.local_range
)
if isinstance(other, NdRange):
return (
self.global_range == other.global_range
and self.local_range == other.local_range
)
else:
return False


@intrinsic
Expand All @@ -245,12 +250,16 @@ def codegen(context, builder, sig, args):
if not isinstance(sig.args[1], types.NoneType):
range_struct.dim1 = dim1
else:
range_struct.dim1 = llvmir.Constant(llvmir.types.IntType(64), -1)
range_struct.dim1 = llvmir.Constant(
llvmir.types.IntType(64), Range.UNDEFINED_DIMENSION
)

if not isinstance(sig.args[2], types.NoneType):
range_struct.dim2 = dim2
else:
range_struct.dim2 = llvmir.Constant(llvmir.types.IntType(64), -1)
range_struct.dim2 = llvmir.Constant(
llvmir.types.IntType(64), Range.UNDEFINED_DIMENSION
)

range_struct.ndim = llvmir.Constant(llvmir.types.IntType(64), typ.ndim)

Expand Down
5 changes: 0 additions & 5 deletions numba_dpex/core/types/range_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,6 @@ def box_range(typ, val, c):
c.context, c.builder, value=val
)

ndim_obj = c.box(types.int64, range_struct.ndim)
with cgutils.early_exit_if_null(c.builder, stack, ndim_obj):
c.builder.store(fail_obj, ret_ptr)
dim0_obj = c.box(types.int64, range_struct.dim0)
with cgutils.early_exit_if_null(c.builder, stack, dim0_obj):
c.builder.store(fail_obj, ret_ptr)
Expand All @@ -184,7 +181,6 @@ def box_range(typ, val, c):

class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Range))
with cgutils.early_exit_if_null(c.builder, stack, class_obj):
c.pyapi.decref(ndim_obj)
c.pyapi.decref(dim0_obj)
c.pyapi.decref(dim1_obj)
c.pyapi.decref(dim2_obj)
Expand All @@ -203,7 +199,6 @@ def box_range(typ, val, c):
else:
raise ValueError("Cannot unbox Range instance.")

c.pyapi.decref(ndim_obj)
c.pyapi.decref(dim0_obj)
c.pyapi.decref(dim1_obj)
c.pyapi.decref(dim2_obj)
Expand Down
8 changes: 3 additions & 5 deletions numba_dpex/core/typing/typeof.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,8 @@

from numba_dpex.utils import address_space

<<<<<<< HEAD
from ..types.dpctl_types import DpctlSyclEvent, DpctlSyclQueue
=======
from ..kernel_interface.indexers import NdRange, Range
from ..types.dpctl_types import DpctlSyclQueue
>>>>>>> 1d5800cf8 (Adds Range and NdRange as supported types in numba_dpex.dpjit.)
from ..types.dpctl_types import DpctlSyclEvent, DpctlSyclQueue
from ..types.dpnp_ndarray_type import DpnpNdArray
from ..types.range_types import NdRangeType, RangeType
from ..types.usm_ndarray_type import USMNdArray
Expand Down Expand Up @@ -126,6 +122,8 @@ def typeof_dpctl_sycl_event(val, c):
Returns: A numba_dpex.core.types.dpctl_types.DpctlSyclEvent instance.
"""
return DpctlSyclEvent(val)


@typeof_impl.register(Range)
def typeof_range(val, c):
"""Registers the type inference implementation function for a
Expand Down

0 comments on commit 7e44902

Please sign in to comment.