Skip to content

Commit

Permalink
Merge pull request #1227 from IntelPython/experimental/inteumliteral
Browse files Browse the repository at this point in the history
Adds a new literal type to store IntEnum as Literal types.
  • Loading branch information
Diptorup Deb authored Dec 9, 2023
2 parents f1be213 + b1ac8d6 commit a92e9b3
Show file tree
Hide file tree
Showing 13 changed files with 332 additions and 8 deletions.
12 changes: 12 additions & 0 deletions numba_dpex/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,3 +386,15 @@ def __init__(self, extra_msg=None) -> None:
if extra_msg:
self.message += " due to " + extra_msg
super().__init__(self.message)


class IllegalIntEnumLiteralValueError(Exception):
"""Exception raised when an IntEnumLiteral is attempted to be created from
a non FlagEnum attribute.
"""

def __init__(self) -> None:
self.message = (
"An IntEnumLiteral can only be initialized from a FlagEnum member"
)
super().__init__(self.message)
12 changes: 10 additions & 2 deletions numba_dpex/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

from numba.core.imputils import Registry

from .decorators import kernel
from .decorators import device_func, kernel
from .kernel_dispatcher import KernelDispatcher
from .launcher import call_kernel, call_kernel_async
from .literal_intenum_type import IntEnumLiteral
from .models import *
from .types import KernelDispatcherType

Expand All @@ -26,4 +27,11 @@ def dpex_dispatcher_const(context):
return context.get_dummy_value()


__all__ = ["kernel", "KernelDispatcher", "call_kernel", "call_kernel_async"]
__all__ = [
"device_func",
"kernel",
"call_kernel",
"call_kernel_async",
"IntEnumLiteral",
"KernelDispatcher",
]
25 changes: 25 additions & 0 deletions numba_dpex/experimental/flag_enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# SPDX-FileCopyrightText: 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""Provides a FlagEnum class to help distinguish IntEnum types that numba-dpex
intends to use as Integer literal types inside the compiler type inferring
infrastructure.
"""
from enum import IntEnum


class FlagEnum(IntEnum):
"""Helper class to distinguish IntEnum types that numba-dpex should consider
as Numba Literal types.
"""

@classmethod
def basetype(cls) -> int:
"""Returns an dummy int object that helps numba-dpex infer the type of
an instance of a FlagEnum class.
Returns:
int: Dummy int value
"""
return int(0)
6 changes: 3 additions & 3 deletions numba_dpex/experimental/kernel_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,12 +262,12 @@ def add_overload(self, cres):
args = tuple(cres.signature.args)
self.overloads[args] = cres

def get_overload_device_ir(self, sig):
def get_overload_kcres(self, sig) -> _KernelCompileResult:
"""
Return the compiled device bitcode for the given signature.
Return the compiled function for the given signature.
"""
args, _ = sigutils.normalize_signature(sig)
return self.overloads[tuple(args)].kernel_device_ir_module
return self.overloads[tuple(args)]

def compile(self, sig) -> any:
disp = self._get_dispatcher_for_current_target()
Expand Down
4 changes: 2 additions & 2 deletions numba_dpex/experimental/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,9 @@ def _submit_kernel(
# codegen
kernel_dispatcher: KernelDispatcher = ty_kernel_fn.dispatcher
kernel_dispatcher.compile(kernel_sig)
kernel_module: _KernelModule = kernel_dispatcher.get_overload_device_ir(
kernel_module: _KernelModule = kernel_dispatcher.get_overload_kcres(
kernel_sig
)
).kernel_device_ir_module
kernel_targetctx = kernel_dispatcher.targetctx

def codegen(cgctx, builder, sig, llargs):
Expand Down
55 changes: 55 additions & 0 deletions numba_dpex/experimental/literal_intenum_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# SPDX-FileCopyrightText: 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""Definition of a new Literal type in numba-dpex that allows treating IntEnum
members as integer literals inside a JIT compiled function.
"""
from enum import IntEnum

from numba.core.pythonapi import box
from numba.core.typeconv import Conversion
from numba.core.types import Integer, Literal
from numba.core.typing.typeof import typeof

from numba_dpex.core.exceptions import IllegalIntEnumLiteralValueError
from numba_dpex.experimental.flag_enum import FlagEnum


class IntEnumLiteral(Literal, Integer):
"""A Literal type for IntEnum objects. The type contains the original Python
value of the IntEnum class in it.
"""

# pylint: disable=W0231
def __init__(self, value):
self._literal_init(value)
self.name = f"Literal[IntEnum]({value})"
if issubclass(value, FlagEnum):
basetype = typeof(value.basetype())
Integer.__init__(
self,
name=self.name,
bitwidth=basetype.bitwidth,
signed=basetype.signed,
)
else:
raise IllegalIntEnumLiteralValueError

def can_convert_to(self, typingctx, other) -> bool:
conv = typingctx.can_convert(self.literal_type, other)
if conv is not None:
return max(conv, Conversion.promote)
return False


Literal.ctor_map[IntEnum] = IntEnumLiteral


@box(IntEnumLiteral)
def box_literal_integer(typ, val, c):
"""Defines how a Numba representation for an IntEnumLiteral object should
be converted to a PyObject* object and returned back to Python.
"""
val = c.context.cast(c.builder, val, typ, typ.literal_type)
return c.box(typ.literal_type, val)
16 changes: 15 additions & 1 deletion numba_dpex/experimental/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,28 @@
numba_dpex.experimental module.
"""

from llvmlite import ir as llvmir
from numba.core.datamodel import DataModelManager, models
from numba.core.datamodel.models import PrimitiveModel
from numba.core.extending import register_model

import numba_dpex.core.datamodel.models as dpex_core_models

from .literal_intenum_type import IntEnumLiteral
from .types import KernelDispatcherType


class IntEnumLiteralModel(PrimitiveModel):
"""Representation of an object of LiteralIntEnum type using Numba's
PrimitiveModel that can be represented natively in the target in all
usage contexts.
"""

def __init__(self, dmm, fe_type):
be_type = llvmir.IntType(fe_type.bitwidth)
super().__init__(dmm, fe_type, be_type)


def _init_exp_data_model_manager() -> DataModelManager:
"""Initializes a DpexExpKernelTarget-specific data model manager.
Expand All @@ -28,7 +42,7 @@ def _init_exp_data_model_manager() -> DataModelManager:
dmm = dpex_core_models.dpex_data_model_manager.copy()

# Register the types and data model in the DpexExpTargetContext
# Add here...
dmm.register(IntEnumLiteral, IntEnumLiteralModel)

return dmm

Expand Down
55 changes: 55 additions & 0 deletions numba_dpex/experimental/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@

from functools import cached_property

from llvmlite import ir as llvmir
from numba.core import types
from numba.core.descriptors import TargetDescriptor
from numba.core.target_extension import GPU, target_registry
from numba.core.types.scalars import IntEnumClass

from numba_dpex.core.descriptor import DpexTargetOptions
from numba_dpex.core.targets.kernel_target import (
Expand All @@ -18,6 +21,9 @@
)
from numba_dpex.experimental.models import exp_dmm

from .flag_enum import FlagEnum
from .literal_intenum_type import IntEnumLiteral


# pylint: disable=R0903
class SyclDeviceExp(GPU):
Expand All @@ -39,6 +45,37 @@ class DpexExpKernelTypingContext(DpexKernelTypingContext):
are stable enough to be migrated to DpexKernelTypingContext.
"""

def resolve_value_type(self, val):
"""
Return the numba type of a Python value that is being used
as a runtime constant.
ValueError is raised for unsupported types.
"""

ty = super().resolve_value_type(val)

if isinstance(ty, IntEnumClass) and issubclass(val, FlagEnum):
ty = IntEnumLiteral(val)

return ty

def resolve_getattr(self, typ, attr):
"""
Resolve getting the attribute *attr* (a string) on the Numba type.
The attribute's type is returned, or None if resolution failed.
"""
ty = None

if isinstance(typ, IntEnumLiteral):
try:
attrval = getattr(typ.literal_value, attr).value
ty = types.IntegerLiteral(attrval)
except ValueError:
pass
else:
ty = super().resolve_getattr(typ, attr)
return ty


# pylint: disable=W0223
# FIXME: Remove the pylint disablement once we add an override for
Expand All @@ -52,10 +89,28 @@ class DpexExpKernelTargetContext(DpexKernelTargetContext):
they are stable enough to be migrated to DpexKernelTargetContext.
"""

allow_dynamic_globals = True

def __init__(self, typingctx, target=DPEX_KERNEL_EXP_TARGET_NAME):
super().__init__(typingctx, target)
self.data_model_manager = exp_dmm

def get_getattr(self, typ, attr):
"""
Overrides the get_getattr function to provide an implementation for
getattr call on an IntegerEnumLiteral type.
"""

if isinstance(typ, IntEnumLiteral):
# pylint: disable=W0613
def enum_literal_getattr_imp(context, builder, typ, val, attr):
enum_attr_value = getattr(typ.literal_value, attr).value
return llvmir.Constant(llvmir.IntType(64), enum_attr_value)

return enum_literal_getattr_imp

return super().get_getattr(typ, attr)


class DpexExpKernelTarget(TargetDescriptor):
"""
Expand Down
Empty file.
36 changes: 36 additions & 0 deletions numba_dpex/tests/experimental/IntEnumLiteral/test_compilation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

import dpnp

import numba_dpex.experimental as exp_dpex
from numba_dpex import Range
from numba_dpex.experimental.flag_enum import FlagEnum


class MockFlags(FlagEnum):
FLAG1 = 100
FLAG2 = 200


@exp_dpex.kernel(
release_gil=False,
no_compile=True,
no_cpython_wrapper=True,
no_cfunc_wrapper=True,
)
def update_with_flag(a):
a[0] = MockFlags.FLAG1
a[1] = MockFlags.FLAG2


def test_compilation_of_flag_enum():
"""Tests if a FlagEnum subclass can be used inside a kernel function."""
a = dpnp.ones(10, dtype=dpnp.int64)
exp_dpex.call_kernel(update_with_flag, Range(10), a)

assert a[0] == MockFlags.FLAG1
assert a[1] == MockFlags.FLAG2
for idx in range(2, 9):
assert a[idx] == 1
30 changes: 30 additions & 0 deletions numba_dpex/tests/experimental/IntEnumLiteral/test_type_creation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# SPDX-FileCopyrightText: 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

from enum import IntEnum

import pytest

from numba_dpex.core.exceptions import IllegalIntEnumLiteralValueError
from numba_dpex.experimental import IntEnumLiteral
from numba_dpex.experimental.flag_enum import FlagEnum


def test_intenumliteral_creation():
"""Tests the creation of an IntEnumLiteral type."""

class DummyFlags(FlagEnum):
DUMMY = 0

try:
IntEnumLiteral(DummyFlags)
except:
pytest.fail("Unexpected failure in IntEnumLiteral initialization")

with pytest.raises(IllegalIntEnumLiteralValueError):

class SomeKindOfUnknownEnum(IntEnum):
UNKNOWN_FLAG = 1

IntEnumLiteral(SomeKindOfUnknownEnum)
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from numba.core.datamodel import default_manager

from numba_dpex.core.datamodel.models import dpex_data_model_manager
from numba_dpex.experimental import IntEnumLiteral
from numba_dpex.experimental.flag_enum import FlagEnum
from numba_dpex.experimental.models import exp_dmm


def test_data_model_registration():
"""Tests that the IntEnumLiteral type is only registered with the
DpexExpKernelTargetContext target.
"""

class DummyFlags(FlagEnum):
DUMMY = 0

dummy = IntEnumLiteral(DummyFlags)

with pytest.raises(KeyError):
default_manager.lookup(dummy)

with pytest.raises(KeyError):
dpex_data_model_manager.lookup(dummy)

try:
exp_dmm.lookup(dummy)
except:
pytest.fail(
"IntEnumLiteral type lookup failed in experimental "
"data model manager"
)
Loading

0 comments on commit a92e9b3

Please sign in to comment.