Skip to content

Commit

Permalink
Merge pull request #1317 from IntelPython/refactor/migrate_int_enum_l…
Browse files Browse the repository at this point in the history
…iteral

Migrate IntEnumLiteral into core.
  • Loading branch information
Diptorup Deb authored Feb 6, 2024
2 parents ccb7606 + 269e08e commit cf4b631
Show file tree
Hide file tree
Showing 11 changed files with 75 additions and 83 deletions.
51 changes: 50 additions & 1 deletion numba_dpex/_kernel_api_impl/spirv/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
from numba.core.registry import cpu_target
from numba.core.target_extension import GPU, target_registry
from numba.core.types import Array as NpArrayType
from numba.core.types.scalars import IntEnumClass

from numba_dpex.core.datamodel.models import _init_data_model_manager
from numba_dpex.core.exceptions import UnsupportedKernelArgumentError
from numba_dpex.core.typeconv import to_usm_ndarray
from numba_dpex.core.types import USMNdArray
from numba_dpex.core.types import IntEnumLiteral, USMNdArray
from numba_dpex.core.utils import get_info_from_suai
from numba_dpex.kernel_api.flag_enum import FlagEnum
from numba_dpex.utils import address_space, calling_conv

from . import codegen
Expand Down Expand Up @@ -64,6 +66,37 @@ class SPIRVTypingContext(typing.BaseContext):
"""

def resolve_value_type(self, val):
"""
Return the numba type of a Python value that is being used
as a runtime constant.
ValueError is raised for unsupported types.
"""

typ = super().resolve_value_type(val)

if isinstance(typ, IntEnumClass) and issubclass(val, FlagEnum):
typ = IntEnumLiteral(val)

return typ

def resolve_getattr(self, typ, attr):
"""
Resolve getting the attribute *attr* (a string) on the Numba type.
The attribute's type is returned, or None if resolution failed.
"""
retty = None

if isinstance(typ, IntEnumLiteral):
try:
attrval = getattr(typ.literal_value, attr).value
retty = types.IntegerLiteral(attrval)
except ValueError:
pass
else:
retty = super().resolve_getattr(typ, attr)
return retty

def resolve_argument_type(self, val):
"""Return the Numba type of a Python value used as a function argument.
Expand Down Expand Up @@ -269,6 +302,22 @@ def init(self):

self.ufunc_db = _dpnp_ufunc_db

def get_getattr(self, typ, attr):
"""
Overrides the get_getattr function to provide an implementation for
getattr call on an IntegerEnumLiteral type.
"""

if isinstance(typ, IntEnumLiteral):
# pylint: disable=W0613
def enum_literal_getattr_imp(context, builder, typ, val, attr):
enum_attr_value = getattr(typ.literal_value, attr).value
return llvmir.Constant(llvmir.IntType(64), enum_attr_value)

return enum_literal_getattr_imp

return super().get_getattr(typ, attr)

def create_module(self, name):
return self._internal_codegen._create_empty_module(name)

Expand Down
20 changes: 17 additions & 3 deletions numba_dpex/core/datamodel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from llvmlite import ir as llvmir
from numba.core import datamodel, types
from numba.core.datamodel.models import PrimitiveModel, StructModel
from numba.core.extending import register_model
Expand All @@ -14,6 +15,7 @@
DpctlSyclEvent,
DpctlSyclQueue,
DpnpNdArray,
IntEnumLiteral,
NdRangeType,
RangeType,
USMNdArray,
Expand Down Expand Up @@ -55,6 +57,17 @@ def __init__(self, dmm, fe_type):
super(GenericPointerModel, self).__init__(dmm, fe_type, be_type)


class IntEnumLiteralModel(PrimitiveModel):
"""Representation of an object of LiteralIntEnum type using Numba's
PrimitiveModel that can be represented natively in the target in all
usage contexts.
"""

def __init__(self, dmm, fe_type):
be_type = llvmir.IntType(fe_type.bitwidth)
super().__init__(dmm, fe_type, be_type)


class USMArrayDeviceModel(StructModel):
"""A data model to represent a usm array type in the LLVM IR generated for a
device-only kernel function.
Expand Down Expand Up @@ -237,7 +250,7 @@ def flattened_field_count(self):


def _init_data_model_manager() -> datamodel.DataModelManager:
"""Initializes a DpexKernelTarget-specific data model manager.
"""Initializes a data model manager used by the SPRIVTarget.
SPIRV kernel functions for certain types of devices require an explicit
address space qualifier for pointers. For OpenCL HD Graphics
Expand All @@ -252,8 +265,7 @@ def _init_data_model_manager() -> datamodel.DataModelManager:
a dpnp.ndarray object can be passed to any other regular function.
Returns:
DataModelManager: A numba-dpex DpexKernelTarget-specific data model
manager
DataModelManager: A numba-dpex SPIRVTarget-specific data model manager
"""
dmm = datamodel.default_manager.copy()
dmm.register(types.CPointer, GenericPointerModel)
Expand All @@ -271,6 +283,8 @@ def _init_data_model_manager() -> datamodel.DataModelManager:
# model manager. The dpex_data_model_manager is used by the DpexKernelTarget
dmm.register(DpctlSyclQueue, SyclQueueModel)

dmm.register(IntEnumLiteral, IntEnumLiteralModel)

return dmm


Expand Down
6 changes: 4 additions & 2 deletions numba_dpex/core/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .array_type import Array
from .dpctl_types import DpctlSyclEvent, DpctlSyclQueue
from .dpnp_ndarray_type import DpnpNdArray
from .kernel_api.literal_intenum import IntEnumLiteral
from .kernel_api.ranges import NdRangeType, RangeType
from .numba_types_short_names import (
b1,
Expand Down Expand Up @@ -36,8 +37,9 @@
"DpctlSyclQueue",
"DpctlSyclEvent",
"DpnpNdArray",
"RangeType",
"IntEnumLiteral",
"NdRangeType",
"RangeType",
"USMNdArray",
"none",
"boolean",
Expand All @@ -57,6 +59,6 @@
"f8",
"float_",
"double",
"void",
"usm_ndarray",
"void",
]
2 changes: 0 additions & 2 deletions numba_dpex/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
)
from .decorators import device_func, kernel
from .launcher import call_kernel, call_kernel_async
from .literal_intenum_type import IntEnumLiteral
from .models import *
from .types import KernelDispatcherType

Expand All @@ -41,6 +40,5 @@ def dpex_dispatcher_const(context):
"kernel",
"call_kernel",
"call_kernel_async",
"IntEnumLiteral",
"SPIRVKernelDispatcher",
]
16 changes: 1 addition & 15 deletions numba_dpex/experimental/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
numba_dpex.experimental module.
"""

from llvmlite import ir as llvmir
from numba.core import types
from numba.core.datamodel import DataModelManager, models
from numba.core.datamodel.models import PrimitiveModel, StructModel
from numba.core.datamodel.models import StructModel
from numba.core.extending import register_model

import numba_dpex.core.datamodel.models as dpex_core_models
Expand All @@ -20,7 +19,6 @@
)

from .dpcpp_types import AtomicRefType
from .literal_intenum_type import IntEnumLiteral
from .types import KernelDispatcherType


Expand All @@ -37,17 +35,6 @@ def __init__(self, dmm, fe_type):
super().__init__(dmm, fe_type, members)


class IntEnumLiteralModel(PrimitiveModel):
"""Representation of an object of LiteralIntEnum type using Numba's
PrimitiveModel that can be represented natively in the target in all
usage contexts.
"""

def __init__(self, dmm, fe_type):
be_type = llvmir.IntType(fe_type.bitwidth)
super().__init__(dmm, fe_type, be_type)


class EmptyStructModel(StructModel):
"""Data model that does not take space. Intended to be used with types that
are presented only at typing stage and not represented physically."""
Expand All @@ -71,7 +58,6 @@ def _init_exp_data_model_manager() -> DataModelManager:
dmm = dpex_core_models.dpex_data_model_manager.copy()

# Register the types and data model in the DpexExpTargetContext
dmm.register(IntEnumLiteral, IntEnumLiteralModel)
dmm.register(AtomicRefType, AtomicRefModel)

# Register the GroupType type
Expand Down
53 changes: 0 additions & 53 deletions numba_dpex/experimental/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,15 @@

from functools import cached_property

from llvmlite import ir as llvmir
from numba.core import types
from numba.core.descriptors import TargetDescriptor
from numba.core.target_extension import GPU, target_registry
from numba.core.types.scalars import IntEnumClass

from numba_dpex._kernel_api_impl.spirv.target import (
SPIRVTargetContext,
SPIRVTypingContext,
)
from numba_dpex.core.descriptor import DpexTargetOptions
from numba_dpex.experimental.models import exp_dmm
from numba_dpex.kernel_api.flag_enum import FlagEnum

from .literal_intenum_type import IntEnumLiteral


# pylint: disable=R0903
Expand All @@ -45,37 +39,6 @@ class DpexExpKernelTypingContext(SPIRVTypingContext):
are stable enough to be migrated to DpexKernelTypingContext.
"""

def resolve_value_type(self, val):
"""
Return the numba type of a Python value that is being used
as a runtime constant.
ValueError is raised for unsupported types.
"""

typ = super().resolve_value_type(val)

if isinstance(typ, IntEnumClass) and issubclass(val, FlagEnum):
typ = IntEnumLiteral(val)

return typ

def resolve_getattr(self, typ, attr):
"""
Resolve getting the attribute *attr* (a string) on the Numba type.
The attribute's type is returned, or None if resolution failed.
"""
retty = None

if isinstance(typ, IntEnumLiteral):
try:
attrval = getattr(typ.literal_value, attr).value
retty = types.IntegerLiteral(attrval)
except ValueError:
pass
else:
retty = super().resolve_getattr(typ, attr)
return retty


# pylint: disable=W0223
# FIXME: Remove the pylint disablement once we add an override for
Expand All @@ -95,22 +58,6 @@ def __init__(self, typingctx, target=DPEX_KERNEL_EXP_TARGET_NAME):
super().__init__(typingctx, target)
self.data_model_manager = exp_dmm

def get_getattr(self, typ, attr):
"""
Overrides the get_getattr function to provide an implementation for
getattr call on an IntegerEnumLiteral type.
"""

if isinstance(typ, IntEnumLiteral):
# pylint: disable=W0613
def enum_literal_getattr_imp(context, builder, typ, val, attr):
enum_attr_value = getattr(typ.literal_value, attr).value
return llvmir.Constant(llvmir.IntType(64), enum_attr_value)

return enum_literal_getattr_imp

return super().get_getattr(typ, attr)


class DpexExpKernelTarget(TargetDescriptor):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

from numba_dpex.core.exceptions import IllegalIntEnumLiteralValueError
from numba_dpex.experimental import IntEnumLiteral
from numba_dpex.core.types import IntEnumLiteral
from numba_dpex.kernel_api.flag_enum import FlagEnum


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from numba.core.datamodel import default_manager

from numba_dpex.core.datamodel.models import dpex_data_model_manager
from numba_dpex.experimental import IntEnumLiteral
from numba_dpex.experimental.models import exp_dmm
from numba_dpex.core.types import IntEnumLiteral
from numba_dpex.kernel_api.flag_enum import FlagEnum


Expand All @@ -24,11 +23,8 @@ class DummyFlags(FlagEnum):
with pytest.raises(KeyError):
default_manager.lookup(dummy)

with pytest.raises(KeyError):
dpex_data_model_manager.lookup(dummy)

try:
exp_dmm.lookup(dummy)
dpex_data_model_manager.lookup(dummy)
except:
pytest.fail(
"IntEnumLiteral type lookup failed in experimental "
Expand Down

0 comments on commit cf4b631

Please sign in to comment.