Skip to content

Commit

Permalink
WIP: use dpjit specific data model
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Jan 9, 2024
1 parent 8c3cbed commit 048933f
Show file tree
Hide file tree
Showing 15 changed files with 84 additions and 62 deletions.
1 change: 1 addition & 0 deletions .env
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ONEAPI_DEVICE_SELECTOR=opencl:cpu
39 changes: 22 additions & 17 deletions numba_dpex/core/datamodel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from numba.core import datamodel, types
from numba.core.datamodel.models import PrimitiveModel, StructModel
from numba.core.extending import register_model

from numba_dpex.core.exceptions import UnreachableError
from numba_dpex.utils import address_space
Expand Down Expand Up @@ -244,7 +243,7 @@ def flattened_field_count(self):
return _get_flattened_member_count(self)


def _init_data_model_manager() -> datamodel.DataModelManager:
def _init_kernel_data_model_manager() -> datamodel.DataModelManager:
"""Initializes a DpexKernelTarget-specific data model manager.
SPIRV kernel functions for certain types of devices require an explicit
Expand Down Expand Up @@ -282,25 +281,31 @@ def _init_data_model_manager() -> datamodel.DataModelManager:
return dmm


dpex_data_model_manager = _init_data_model_manager()
def _init_dpjit_data_model_manager() -> datamodel.DataModelManager:
dmm = datamodel.default_manager.copy()

# Register the USMNdArray type to USMArrayDeviceModel in numba's default data
# model manager
dmm.register(USMNdArray, USMArrayHostModel)

# Register the DpnpNdArray type to USMArrayHostModel in numba's default data
# model manager
dmm.register(DpnpNdArray, USMArrayHostModel)

# Register the USMNdArray type to USMArrayDeviceModel in numba's default data
# model manager
register_model(USMNdArray)(USMArrayHostModel)
# Register the DpctlSyclQueue type
dmm.register(DpctlSyclQueue, SyclQueueModel)

# Register the DpnpNdArray type to USMArrayHostModel in numba's default data
# model manager
register_model(DpnpNdArray)(USMArrayHostModel)
# Register the DpctlSyclEvent type
dmm.register(DpctlSyclEvent, SyclEventModel)

# Register the DpctlSyclQueue type
register_model(DpctlSyclQueue)(SyclQueueModel)
# Register the RangeType type
dmm.register(RangeType, RangeModel)

# Register the DpctlSyclEvent type
register_model(DpctlSyclEvent)(SyclEventModel)
# Register the NdRangeType type
dmm.register(NdRangeType, NdRangeModel)

return dmm

# Register the RangeType type
register_model(RangeType)(RangeModel)

# Register the NdRangeType type
register_model(NdRangeType)(NdRangeModel)
dpex_data_model_manager = _init_kernel_data_model_manager()
dpjit_data_model_manager = _init_dpjit_data_model_manager()
3 changes: 1 addition & 2 deletions numba_dpex/core/kernel_interface/indexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from llvmlite import ir as llvmir
from numba.core import cgutils, errors, types
from numba.core.datamodel import default_manager
from numba.extending import intrinsic, overload

# can't import name because of the circular import
Expand Down Expand Up @@ -281,10 +280,10 @@ def _intrin_ndrange_alloc(
ty_local_range,
ty_ndrange,
)
range_datamodel = default_manager.lookup(ty_global_range)

def codegen(context, builder, sig, args):
typ = sig.return_type
range_datamodel = context.data_model_manager.lookup(ty_global_range)

global_range, local_range, _ = args
ndrange_struct = cgutils.create_struct_proxy(typ)(context, builder)
Expand Down
7 changes: 7 additions & 0 deletions numba_dpex/core/targets/dpjit_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
from numba.core.imputils import Registry
from numba.core.target_extension import CPU, target_registry

from numba_dpex.core.datamodel.models import (
_init_dpjit_data_model_manager,
dpjit_data_model_manager,
)
from numba_dpex.dpnp_iface import dpnp_ufunc_db


Expand All @@ -37,6 +41,9 @@ def init(self):
self.lower_extensions = {}
super().init()

# self.data_model_manager = _init_dpjit_data_model_manager()
self.data_model_manager = dpjit_data_model_manager

# TODO: initialize nrt once switched to nrt from drt. Most likely we
# call it somewhere. Double check.
# https://github.com/IntelPython/numba-dpex/issues/1175
Expand Down
9 changes: 6 additions & 3 deletions numba_dpex/core/targets/kernel_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
from numba.core import cgutils, funcdesc, types, typing, utils
from numba.core.base import BaseContext
from numba.core.callconv import MinimalCallConv
from numba.core.registry import cpu_target
from numba.core.target_extension import GPU, target_registry
from numba.core.types import Array as NpArrayType

from numba_dpex.core.datamodel.models import _init_data_model_manager
from numba_dpex.core.datamodel.models import (
_init_kernel_data_model_manager,
dpex_data_model_manager,
)
from numba_dpex.core.exceptions import UnsupportedKernelArgumentError
from numba_dpex.core.typeconv import to_usm_ndarray
from numba_dpex.core.types import USMNdArray
Expand Down Expand Up @@ -253,7 +255,8 @@ def init(self):
)

# Override data model manager to SPIR model
self.data_model_manager = _init_data_model_manager()
# self.data_model_manager = _init_kernel_data_model_manager()
self.data_model_manager = dpex_data_model_manager
self.extra_compile_options = dict()

from numba_dpex.dpnp_iface.dpnp_ufunc_db import _lazy_init_dpnp_db
Expand Down
5 changes: 3 additions & 2 deletions numba_dpex/core/types/range_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from contextlib import ExitStack

from numba.core import cgutils, errors, types
from numba.core.datamodel import default_manager
from numba.extending import NativeValue, box, unbox

from ..kernel_interface.indexers import NdRange, Range
Expand Down Expand Up @@ -121,7 +120,9 @@ def unbox_ndrange(typ, obj, c):
].value
local_range_struct = ndrange_attr_native_value_map["local_range"].value

range_datamodel = default_manager.lookup(RangeType(typ.ndim))
range_datamodel = c.context.data_model_manager.lookup(
RangeType(typ.ndim)
)
ndrange_struct.ndim = c.builder.extract_value(
global_range_struct,
range_datamodel.get_field_position("ndim"),
Expand Down
3 changes: 1 addition & 2 deletions numba_dpex/dpctl_iface/_intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import dpctl
from llvmlite.ir import IRBuilder
from numba import types
from numba.core.datamodel import default_manager
from numba.extending import intrinsic, overload, overload_method

import numba_dpex.dpctl_iface.libsyclinterface_bindings as sycl
Expand Down Expand Up @@ -45,7 +44,7 @@ def sycl_event_wait(typingctx, ty_event: dpex_types.DpctlSyclEvent):

# defines the custom code generation
def codegen(context, builder, signature, args):
sycl_event_dm = default_manager.lookup(ty_event)
sycl_event_dm = context.data_model_manager.lookup(ty_event)
event_ref = builder.extract_value(
args[0],
sycl_event_dm.get_field_position("event_ref"),
Expand Down
5 changes: 4 additions & 1 deletion numba_dpex/experimental/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from numba.core.extending import register_model

import numba_dpex.core.datamodel.models as dpex_core_models
from numba_dpex.core.datamodel.models import dpjit_data_model_manager

from .dpcpp_types import AtomicRefType
from .literal_intenum_type import IntEnumLiteral
Expand Down Expand Up @@ -66,4 +67,6 @@ def _init_exp_data_model_manager() -> DataModelManager:
exp_dmm = _init_exp_data_model_manager()

# Register any new type that should go into numba.core.datamodel.default_manager
register_model(KernelDispatcherType)(models.OpaqueModel)
# register_model(KernelDispatcherType)(models.OpaqueModel)
dpjit_data_model_manager.register(KernelDispatcherType)(models.OpaqueModel)
# dmm.register(KernelDispatcherType, models.OpaqueModel)
8 changes: 3 additions & 5 deletions numba_dpex/tests/core/types/DpctlSyclEvent/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
#
# SPDX-License-Identifier: Apache-2.0

import dpctl
from numba import types
from numba.core.datamodel import default_manager, models
from numba.core.datamodel import models

from numba_dpex.core.datamodel.models import (
SyclEventModel,
dpex_data_model_manager,
dpjit_data_model_manager,
)
from numba_dpex.core.types.dpctl_types import DpctlSyclEvent

Expand All @@ -18,7 +16,7 @@ def test_model_for_DpctlSyclEvent():
default data model manager.
"""
sycl_event = DpctlSyclEvent()
default_model = default_manager.lookup(sycl_event)
default_model = dpjit_data_model_manager.lookup(sycl_event)
assert isinstance(default_model, SyclEventModel)


Expand Down
4 changes: 4 additions & 0 deletions numba_dpex/tests/core/types/DpctlSyclQueue/test_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,7 @@ def func() -> SyclQueue:
q: SyclQueue = func()

assert len(q.sycl_device.filter_string) > 0


if __name__ == "__main__":
test_boxing_without_parent()
16 changes: 8 additions & 8 deletions numba_dpex/tests/core/types/DpnpNdArray/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
# SPDX-License-Identifier: Apache-2.0

from numba import types
from numba.core.datamodel import default_manager, models
from numba.core.registry import cpu_target
from numba.core.datamodel import models

from numba_dpex.core.datamodel.models import (
USMArrayDeviceModel,
USMArrayHostModel,
dpex_data_model_manager,
dpjit_data_model_manager,
)
from numba_dpex.core.descriptor import dpex_kernel_target
from numba_dpex.core.descriptor import dpex_kernel_target, dpex_target
from numba_dpex.core.types.dpnp_ndarray_type import DpnpNdArray


Expand All @@ -22,7 +22,7 @@ def test_model_for_DpnpNdArray():
dpnp_ndarray = DpnpNdArray(ndim=1, dtype=types.float64, layout="C")
model = dpex_data_model_manager.lookup(dpnp_ndarray)
assert isinstance(model, USMArrayDeviceModel)
default_model = default_manager.lookup(dpnp_ndarray)
default_model = dpjit_data_model_manager.lookup(dpnp_ndarray)
assert isinstance(default_model, USMArrayHostModel)


Expand All @@ -40,15 +40,15 @@ def test_flattened_member_count():
flattened args generated by the CpuTarget's ArgPacker.
"""

cputargetctx = cpu_target.target_context
kerneltargetctx = dpex_kernel_target.target_context
dpex_dmm = kerneltargetctx.data_model_manager
dpex_target_ctx = dpex_target.target_context
kernel_target_ctx = dpex_kernel_target.target_context
dpex_dmm = kernel_target_ctx.data_model_manager

for ndim in range(4):
dty = DpnpNdArray(ndim)
argty_tuple = tuple([dty])
datamodel = dpex_dmm.lookup(dty)
num_flattened_args = datamodel.flattened_field_count
ap = cputargetctx.get_arg_packer(argty_tuple)
ap = dpex_target_ctx.get_arg_packer(argty_tuple)

assert num_flattened_args == len(ap._be_args)
12 changes: 5 additions & 7 deletions numba_dpex/tests/core/types/USMNdArray/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from numba.core.registry import cpu_target

from numba_dpex.core.descriptor import dpex_kernel_target
from numba_dpex.core.descriptor import dpex_kernel_target, dpex_target
from numba_dpex.core.types.usm_ndarray_type import USMNdArray


Expand All @@ -13,15 +11,15 @@ def test_flattened_member_count():
flattened args generated by the CpuTarget's ArgPacker.
"""

cputargetctx = cpu_target.target_context
kerneltargetctx = dpex_kernel_target.target_context
dpex_dmm = kerneltargetctx.data_model_manager
dpex_target_ctx = dpex_target.target_context
kernel_target_ctx = dpex_kernel_target.target_context
dpex_dmm = kernel_target_ctx.data_model_manager

for ndim in range(4):
dty = USMNdArray(ndim)
argty_tuple = tuple([dty])
datamodel = dpex_dmm.lookup(dty)
num_flattened_args = datamodel.flattened_field_count
ap = cputargetctx.get_arg_packer(argty_tuple)
ap = dpex_target_ctx.get_arg_packer(argty_tuple)

assert num_flattened_args == len(ap._be_args)
22 changes: 10 additions & 12 deletions numba_dpex/tests/core/types/range_types/test_data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
from numba.core.datamodel import default_manager
from numba.core.registry import cpu_target

from numba_dpex.core.datamodel.models import (
NdRangeModel,
RangeModel,
dpex_data_model_manager,
dpjit_data_model_manager,
)
from numba_dpex.core.descriptor import dpex_kernel_target
from numba_dpex.core.descriptor import dpex_kernel_target, dpex_target
from numba_dpex.core.types.range_types import NdRangeType, RangeType

rfields = ["ndim", "dim0", "dim1", "dim2"]
Expand All @@ -30,8 +29,8 @@ def test_datamodel_registration():
dpex_data_model_manager.lookup(range_ty)
dpex_data_model_manager.lookup(ndrange_ty)

default_range_model = default_manager.lookup(range_ty)
default_ndrange_model = default_manager.lookup(ndrange_ty)
default_range_model = dpjit_data_model_manager.lookup(range_ty)
default_ndrange_model = dpjit_data_model_manager.lookup(ndrange_ty)

assert isinstance(default_range_model, RangeModel)
assert isinstance(default_ndrange_model, NdRangeModel)
Expand All @@ -43,7 +42,7 @@ def test_range_model_fields(field):
RangeType
"""
range_ty = RangeType(ndim=1)
dm = default_manager.lookup(range_ty)
dm = dpjit_data_model_manager.lookup(range_ty)
try:
dm.get_field_position(field)
except:
Expand All @@ -56,7 +55,7 @@ def test_ndrange_model_fields(field):
NdRangeType
"""
ndrange_ty = NdRangeType(ndim=1)
dm = default_manager.lookup(ndrange_ty)
dm = dpjit_data_model_manager.lookup(ndrange_ty)
try:
dm.get_field_position(field)
except:
Expand All @@ -69,15 +68,14 @@ def test_flattened_member_count(range_type):
flattened args generated by the CpuTarget's ArgPacker.
"""

cputargetctx = cpu_target.target_context
kerneltargetctx = dpex_kernel_target.target_context
dpex_dmm = kerneltargetctx.data_model_manager
dpjit_target_ctx = dpex_target.target_context
dpjit_dmm = dpjit_target_ctx.data_model_manager

for ndim in range(1, 3):
dty = range_type(ndim)
argty_tuple = tuple([dty])
datamodel = dpex_dmm.lookup(dty)
datamodel = dpjit_dmm.lookup(dty)
num_flattened_args = datamodel.flattened_field_count
ap = cputargetctx.get_arg_packer(argty_tuple)
ap = dpjit_target_ctx.get_arg_packer(argty_tuple)

assert num_flattened_args == len(ap._be_args)
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
from numba.core.datamodel import default_manager

from numba_dpex.core.datamodel.models import dpex_data_model_manager
from numba_dpex.core.datamodel.models import (
dpex_data_model_manager,
dpjit_data_model_manager,
)
from numba_dpex.experimental import IntEnumLiteral
from numba_dpex.experimental.flag_enum import FlagEnum
from numba_dpex.experimental.models import exp_dmm
Expand All @@ -22,7 +24,7 @@ class DummyFlags(FlagEnum):
dummy = IntEnumLiteral(DummyFlags)

with pytest.raises(KeyError):
default_manager.lookup(dummy)
dpjit_data_model_manager.lookup(dummy)

with pytest.raises(KeyError):
dpex_data_model_manager.lookup(dummy)
Expand Down
4 changes: 4 additions & 0 deletions numba_dpex/tests/kernel_tests/test_func_specialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ def kernel_function(a, b):
assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1)


if __name__ == "__main__":
test_basic()


def test_single_signature():
"""Basic test with single signature"""

Expand Down

0 comments on commit 048933f

Please sign in to comment.