From 7e449027775c7672397b1ae8376ea18ad0f00196 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sat, 30 Sep 2023 10:25:10 -0500 Subject: [PATCH] Address review comments. - A static variable Range.UNDEFINED_DIMENSION is now used instead of the magic -1 value to indicate that the dimension extent is undefined. - An ndim_obj need not be created when boxing a Range. --- numba_dpex/core/kernel_interface/indexers.py | 29 +++++++++++++------- numba_dpex/core/types/range_types.py | 5 ---- numba_dpex/core/typing/typeof.py | 8 ++---- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/numba_dpex/core/kernel_interface/indexers.py b/numba_dpex/core/kernel_interface/indexers.py index 7843fdf160..33498afbbb 100644 --- a/numba_dpex/core/kernel_interface/indexers.py +++ b/numba_dpex/core/kernel_interface/indexers.py @@ -23,6 +23,8 @@ class Range(tuple): the behavior of `sycl::range`. """ + UNDEFINED_DIMENSION = -1 + def __new__(cls, dim0, dim1=None, dim2=None): """Constructs a 1, 2, or 3 dimensional range. @@ -107,8 +109,8 @@ def dim1(self) -> int: """ try: return self[1] - except: - return -1 + except IndexError: + return Range.UNDEFINED_DIMENSION @property def dim2(self) -> int: @@ -120,8 +122,8 @@ def dim2(self) -> int: """ try: return self[2] - except: - return -1 + except IndexError: + return Range.UNDEFINED_DIMENSION class NdRange: @@ -220,10 +222,13 @@ def __repr__(self): return self.__str__() def __eq__(self, other): - return ( - self.global_range == other.global_range - and self.local_range == other.local_range - ) + if isinstance(other, NdRange): + return ( + self.global_range == other.global_range + and self.local_range == other.local_range + ) + else: + return False @intrinsic @@ -245,12 +250,16 @@ def codegen(context, builder, sig, args): if not isinstance(sig.args[1], types.NoneType): range_struct.dim1 = dim1 else: - range_struct.dim1 = llvmir.Constant(llvmir.types.IntType(64), -1) + range_struct.dim1 = llvmir.Constant( + llvmir.types.IntType(64), Range.UNDEFINED_DIMENSION + ) if not isinstance(sig.args[2], types.NoneType): range_struct.dim2 = dim2 else: - range_struct.dim2 = llvmir.Constant(llvmir.types.IntType(64), -1) + range_struct.dim2 = llvmir.Constant( + llvmir.types.IntType(64), Range.UNDEFINED_DIMENSION + ) range_struct.ndim = llvmir.Constant(llvmir.types.IntType(64), typ.ndim) diff --git a/numba_dpex/core/types/range_types.py b/numba_dpex/core/types/range_types.py index a5ee9428d2..81d451215b 100644 --- a/numba_dpex/core/types/range_types.py +++ b/numba_dpex/core/types/range_types.py @@ -169,9 +169,6 @@ def box_range(typ, val, c): c.context, c.builder, value=val ) - ndim_obj = c.box(types.int64, range_struct.ndim) - with cgutils.early_exit_if_null(c.builder, stack, ndim_obj): - c.builder.store(fail_obj, ret_ptr) dim0_obj = c.box(types.int64, range_struct.dim0) with cgutils.early_exit_if_null(c.builder, stack, dim0_obj): c.builder.store(fail_obj, ret_ptr) @@ -184,7 +181,6 @@ def box_range(typ, val, c): class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Range)) with cgutils.early_exit_if_null(c.builder, stack, class_obj): - c.pyapi.decref(ndim_obj) c.pyapi.decref(dim0_obj) c.pyapi.decref(dim1_obj) c.pyapi.decref(dim2_obj) @@ -203,7 +199,6 @@ def box_range(typ, val, c): else: raise ValueError("Cannot unbox Range instance.") - c.pyapi.decref(ndim_obj) c.pyapi.decref(dim0_obj) c.pyapi.decref(dim1_obj) c.pyapi.decref(dim2_obj) diff --git a/numba_dpex/core/typing/typeof.py b/numba_dpex/core/typing/typeof.py index 4668f8d9f4..5ef88a8d0c 100644 --- a/numba_dpex/core/typing/typeof.py +++ b/numba_dpex/core/typing/typeof.py @@ -10,12 +10,8 @@ from numba_dpex.utils import address_space -<<<<<<< HEAD -from ..types.dpctl_types import DpctlSyclEvent, DpctlSyclQueue -======= from ..kernel_interface.indexers import NdRange, Range -from ..types.dpctl_types import DpctlSyclQueue ->>>>>>> 1d5800cf8 (Adds Range and NdRange as supported types in numba_dpex.dpjit.) +from ..types.dpctl_types import DpctlSyclEvent, DpctlSyclQueue from ..types.dpnp_ndarray_type import DpnpNdArray from ..types.range_types import NdRangeType, RangeType from ..types.usm_ndarray_type import USMNdArray @@ -126,6 +122,8 @@ def typeof_dpctl_sycl_event(val, c): Returns: A numba_dpex.core.types.dpctl_types.DpctlSyclEvent instance. """ return DpctlSyclEvent(val) + + @typeof_impl.register(Range) def typeof_range(val, c): """Registers the type inference implementation function for a