Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds a new literal type to store IntEnum as Literal types. #1227

Merged
merged 7 commits into from
Dec 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions numba_dpex/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,3 +386,15 @@ def __init__(self, extra_msg=None) -> None:
if extra_msg:
self.message += " due to " + extra_msg
super().__init__(self.message)


class IllegalIntEnumLiteralValueError(Exception):
"""Exception raised when an IntEnumLiteral is attempted to be created from
a non FlagEnum attribute.
"""

def __init__(self) -> None:
self.message = (
"An IntEnumLiteral can only be initialized from a FlagEnum member"
)
super().__init__(self.message)
12 changes: 10 additions & 2 deletions numba_dpex/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

from numba.core.imputils import Registry

from .decorators import kernel
from .decorators import device_func, kernel
from .kernel_dispatcher import KernelDispatcher
from .launcher import call_kernel, call_kernel_async
from .literal_intenum_type import IntEnumLiteral
from .models import *
from .types import KernelDispatcherType

Expand All @@ -26,4 +27,11 @@ def dpex_dispatcher_const(context):
return context.get_dummy_value()


__all__ = ["kernel", "KernelDispatcher", "call_kernel", "call_kernel_async"]
__all__ = [
"device_func",
"kernel",
"call_kernel",
"call_kernel_async",
"IntEnumLiteral",
"KernelDispatcher",
]
25 changes: 25 additions & 0 deletions numba_dpex/experimental/flag_enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# SPDX-FileCopyrightText: 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""Provides a FlagEnum class to help distinguish IntEnum types that numba-dpex
intends to use as Integer literal types inside the compiler type inferring
infrastructure.
"""
from enum import IntEnum


class FlagEnum(IntEnum):
"""Helper class to distinguish IntEnum types that numba-dpex should consider
as Numba Literal types.
"""

@classmethod
def basetype(cls) -> int:
"""Returns an dummy int object that helps numba-dpex infer the type of
an instance of a FlagEnum class.

Returns:
int: Dummy int value
"""
return int(0)
6 changes: 3 additions & 3 deletions numba_dpex/experimental/kernel_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,12 +262,12 @@ def add_overload(self, cres):
args = tuple(cres.signature.args)
self.overloads[args] = cres

def get_overload_device_ir(self, sig):
def get_overload_kcres(self, sig) -> _KernelCompileResult:
"""
Return the compiled device bitcode for the given signature.
Return the compiled function for the given signature.
"""
args, _ = sigutils.normalize_signature(sig)
return self.overloads[tuple(args)].kernel_device_ir_module
return self.overloads[tuple(args)]

def compile(self, sig) -> any:
disp = self._get_dispatcher_for_current_target()
Expand Down
4 changes: 2 additions & 2 deletions numba_dpex/experimental/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,9 @@ def _submit_kernel(
# codegen
kernel_dispatcher: KernelDispatcher = ty_kernel_fn.dispatcher
kernel_dispatcher.compile(kernel_sig)
kernel_module: _KernelModule = kernel_dispatcher.get_overload_device_ir(
kernel_module: _KernelModule = kernel_dispatcher.get_overload_kcres(
kernel_sig
)
).kernel_device_ir_module
kernel_targetctx = kernel_dispatcher.targetctx

def codegen(cgctx, builder, sig, llargs):
Expand Down
55 changes: 55 additions & 0 deletions numba_dpex/experimental/literal_intenum_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# SPDX-FileCopyrightText: 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""Definition of a new Literal type in numba-dpex that allows treating IntEnum
members as integer literals inside a JIT compiled function.
"""
from enum import IntEnum

from numba.core.pythonapi import box
from numba.core.typeconv import Conversion
from numba.core.types import Integer, Literal
from numba.core.typing.typeof import typeof

from numba_dpex.core.exceptions import IllegalIntEnumLiteralValueError
from numba_dpex.experimental.flag_enum import FlagEnum


class IntEnumLiteral(Literal, Integer):
"""A Literal type for IntEnum objects. The type contains the original Python
value of the IntEnum class in it.
"""

# pylint: disable=W0231
def __init__(self, value):
self._literal_init(value)
self.name = f"Literal[IntEnum]({value})"
if issubclass(value, FlagEnum):
basetype = typeof(value.basetype())
Integer.__init__(
self,
name=self.name,
bitwidth=basetype.bitwidth,
signed=basetype.signed,
)
else:
raise IllegalIntEnumLiteralValueError

def can_convert_to(self, typingctx, other) -> bool:
conv = typingctx.can_convert(self.literal_type, other)
if conv is not None:
return max(conv, Conversion.promote)
return False


Literal.ctor_map[IntEnum] = IntEnumLiteral


@box(IntEnumLiteral)
def box_literal_integer(typ, val, c):
"""Defines how a Numba representation for an IntEnumLiteral object should
be converted to a PyObject* object and returned back to Python.
"""
val = c.context.cast(c.builder, val, typ, typ.literal_type)
return c.box(typ.literal_type, val)
16 changes: 15 additions & 1 deletion numba_dpex/experimental/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,28 @@
numba_dpex.experimental module.
"""

from llvmlite import ir as llvmir
from numba.core.datamodel import DataModelManager, models
from numba.core.datamodel.models import PrimitiveModel
from numba.core.extending import register_model

import numba_dpex.core.datamodel.models as dpex_core_models

from .literal_intenum_type import IntEnumLiteral
from .types import KernelDispatcherType


class IntEnumLiteralModel(PrimitiveModel):
"""Representation of an object of LiteralIntEnum type using Numba's
PrimitiveModel that can be represented natively in the target in all
usage contexts.
"""

def __init__(self, dmm, fe_type):
be_type = llvmir.IntType(fe_type.bitwidth)
super().__init__(dmm, fe_type, be_type)


def _init_exp_data_model_manager() -> DataModelManager:
"""Initializes a DpexExpKernelTarget-specific data model manager.

Expand All @@ -28,7 +42,7 @@ def _init_exp_data_model_manager() -> DataModelManager:
dmm = dpex_core_models.dpex_data_model_manager.copy()

# Register the types and data model in the DpexExpTargetContext
# Add here...
dmm.register(IntEnumLiteral, IntEnumLiteralModel)

return dmm

Expand Down
55 changes: 55 additions & 0 deletions numba_dpex/experimental/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@

from functools import cached_property

from llvmlite import ir as llvmir
from numba.core import types
from numba.core.descriptors import TargetDescriptor
from numba.core.target_extension import GPU, target_registry
from numba.core.types.scalars import IntEnumClass

from numba_dpex.core.descriptor import DpexTargetOptions
from numba_dpex.core.targets.kernel_target import (
Expand All @@ -18,6 +21,9 @@
)
from numba_dpex.experimental.models import exp_dmm

from .flag_enum import FlagEnum
from .literal_intenum_type import IntEnumLiteral


# pylint: disable=R0903
class SyclDeviceExp(GPU):
Expand All @@ -39,6 +45,37 @@ class DpexExpKernelTypingContext(DpexKernelTypingContext):
are stable enough to be migrated to DpexKernelTypingContext.
"""

def resolve_value_type(self, val):
"""
Return the numba type of a Python value that is being used
as a runtime constant.
ValueError is raised for unsupported types.
"""

ty = super().resolve_value_type(val)

if isinstance(ty, IntEnumClass) and issubclass(val, FlagEnum):
ty = IntEnumLiteral(val)

return ty

def resolve_getattr(self, typ, attr):
"""
Resolve getting the attribute *attr* (a string) on the Numba type.
The attribute's type is returned, or None if resolution failed.
"""
ty = None

if isinstance(typ, IntEnumLiteral):
try:
attrval = getattr(typ.literal_value, attr).value
ty = types.IntegerLiteral(attrval)
except ValueError:
pass
else:
ty = super().resolve_getattr(typ, attr)
return ty


# pylint: disable=W0223
# FIXME: Remove the pylint disablement once we add an override for
Expand All @@ -52,10 +89,28 @@ class DpexExpKernelTargetContext(DpexKernelTargetContext):
they are stable enough to be migrated to DpexKernelTargetContext.
"""

allow_dynamic_globals = True

def __init__(self, typingctx, target=DPEX_KERNEL_EXP_TARGET_NAME):
super().__init__(typingctx, target)
self.data_model_manager = exp_dmm

def get_getattr(self, typ, attr):
"""
Overrides the get_getattr function to provide an implementation for
getattr call on an IntegerEnumLiteral type.
"""

if isinstance(typ, IntEnumLiteral):
# pylint: disable=W0613
def enum_literal_getattr_imp(context, builder, typ, val, attr):
enum_attr_value = getattr(typ.literal_value, attr).value
return llvmir.Constant(llvmir.IntType(64), enum_attr_value)

return enum_literal_getattr_imp

return super().get_getattr(typ, attr)


class DpexExpKernelTarget(TargetDescriptor):
"""
Expand Down
Empty file.
36 changes: 36 additions & 0 deletions numba_dpex/tests/experimental/IntEnumLiteral/test_compilation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

import dpnp

import numba_dpex.experimental as exp_dpex
from numba_dpex import Range
from numba_dpex.experimental.flag_enum import FlagEnum


class MockFlags(FlagEnum):
FLAG1 = 100
FLAG2 = 200


@exp_dpex.kernel(
release_gil=False,
no_compile=True,
no_cpython_wrapper=True,
no_cfunc_wrapper=True,
)
def update_with_flag(a):
a[0] = MockFlags.FLAG1
a[1] = MockFlags.FLAG2


def test_compilation_of_flag_enum():
"""Tests if a FlagEnum subclass can be used inside a kernel function."""
a = dpnp.ones(10, dtype=dpnp.int64)
exp_dpex.call_kernel(update_with_flag, Range(10), a)

assert a[0] == MockFlags.FLAG1
assert a[1] == MockFlags.FLAG2
for idx in range(2, 9):
assert a[idx] == 1
30 changes: 30 additions & 0 deletions numba_dpex/tests/experimental/IntEnumLiteral/test_type_creation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# SPDX-FileCopyrightText: 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

from enum import IntEnum

import pytest

from numba_dpex.core.exceptions import IllegalIntEnumLiteralValueError
from numba_dpex.experimental import IntEnumLiteral
from numba_dpex.experimental.flag_enum import FlagEnum


def test_intenumliteral_creation():
"""Tests the creation of an IntEnumLiteral type."""

class DummyFlags(FlagEnum):
DUMMY = 0

try:
IntEnumLiteral(DummyFlags)
except:
pytest.fail("Unexpected failure in IntEnumLiteral initialization")

with pytest.raises(IllegalIntEnumLiteralValueError):

class SomeKindOfUnknownEnum(IntEnum):
UNKNOWN_FLAG = 1

IntEnumLiteral(SomeKindOfUnknownEnum)
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from numba.core.datamodel import default_manager

from numba_dpex.core.datamodel.models import dpex_data_model_manager
from numba_dpex.experimental import IntEnumLiteral
from numba_dpex.experimental.flag_enum import FlagEnum
from numba_dpex.experimental.models import exp_dmm


def test_data_model_registration():
"""Tests that the IntEnumLiteral type is only registered with the
DpexExpKernelTargetContext target.
"""

class DummyFlags(FlagEnum):
DUMMY = 0

dummy = IntEnumLiteral(DummyFlags)

with pytest.raises(KeyError):
default_manager.lookup(dummy)

with pytest.raises(KeyError):
dpex_data_model_manager.lookup(dummy)

try:
exp_dmm.lookup(dummy)
except:
pytest.fail(
"IntEnumLiteral type lookup failed in experimental "
"data model manager"
)
Loading
Loading