Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PrivateArray kernel_api #1370

Merged
merged 1 commit into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions numba_dpex/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
_atomic_ref_overloads,
_group_barrier_overloads,
_index_space_id_overloads,
_private_array_overloads,
)
from .decorators import device_func, kernel
from .launcher import call_kernel, call_kernel_async
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# SPDX-FileCopyrightText: 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""
Implements the SPIR-V overloads for the kernel_api.PrivateArray class.
"""


import llvmlite.ir as llvmir
from llvmlite.ir.builder import IRBuilder
from numba.core.typing.npydecl import parse_dtype as _ty_parse_dtype
from numba.core.typing.npydecl import parse_shape as _ty_parse_shape
from numba.core.typing.templates import Signature
from numba.extending import intrinsic, overload

from numba_dpex.core.types import USMNdArray
from numba_dpex.experimental.target import DpexExpKernelTypingContext
from numba_dpex.kernel_api import PrivateArray
from numba_dpex.kernel_api_impl.spirv.arrayobj import (
make_spirv_generic_array_on_stack,
require_literal,
)
from numba_dpex.utils import address_space as AddressSpace

from ..target import DPEX_KERNEL_EXP_TARGET_NAME


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_private_array_ctor(
ty_context, ty_shape, ty_dtype # pylint: disable=unused-argument
):
require_literal(ty_shape)

ty_array = USMNdArray(
dtype=_ty_parse_dtype(ty_dtype),
ndim=_ty_parse_shape(ty_shape),
layout="C",
addrspace=AddressSpace.PRIVATE,
)

sig = ty_array(ty_shape, ty_dtype)

def codegen(
context: DpexExpKernelTypingContext,
builder: IRBuilder,
sig: Signature,
args: list[llvmir.Value],
):
shape = args[0]
ty_shape = sig.args[0]
ty_array = sig.return_type

ary = make_spirv_generic_array_on_stack(
context, builder, ty_array, ty_shape, shape
)
return ary._getvalue() # pylint: disable=protected-access

return (
sig,
codegen,
)


@overload(
PrivateArray,
prefer_literal=True,
target=DPEX_KERNEL_EXP_TARGET_NAME,
)
def ol_private_array_ctor(
shape,
dtype,
):
"""Overload of the constructor for the class
class:`numba_dpex.kernel_api.PrivateArray`.

Raises:
errors.TypingError: If the shape argument is not a shape compatible
type.
errors.TypingError: If the dtype argument is not a dtype compatible
type.
"""

def ol_private_array_ctor_impl(
shape,
dtype,
):
# pylint: disable=no-value-for-parameter
return _intrinsic_private_array_ctor(shape, dtype)

return ol_private_array_ctor_impl
2 changes: 2 additions & 0 deletions numba_dpex/kernel_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .index_space_ids import Group, Item, NdItem
from .launcher import call_kernel
from .memory_enums import AddressSpace, MemoryOrder, MemoryScope
from .private_array import PrivateArray
from .ranges import NdRange, Range

__all__ = [
Expand All @@ -28,6 +29,7 @@
"Group",
"NdItem",
"Item",
"PrivateArray",
"group_barrier",
"call_kernel",
]
36 changes: 36 additions & 0 deletions numba_dpex/kernel_api/private_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""Implements a simple array intended to be used inside kernel work item.
Implementation is intended to be used in pure Python code when prototyping a
kernel function.
"""

from numpy import ndarray


class PrivateArray:
"""
The ``PrivateArray`` class is an simple version of array intended to be used
inside kernel work item.
"""

def __init__(self, shape, dtype) -> None:
"""Creates a new PrivateArray instance of the given shape and dtype."""

self._data = ndarray(shape=shape, dtype=dtype)

def __getitem__(self, idx_obj):
"""Returns the value stored at the position represented by idx_obj in
the self._data ndarray.
"""

return self._data[idx_obj]

def __setitem__(self, idx_obj, val):
"""Assigns a new value to the position represented by idx_obj in
the self._data ndarray.
"""

self._data[idx_obj] = val
121 changes: 121 additions & 0 deletions numba_dpex/kernel_api_impl/spirv/arrayobj.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# SPDX-FileCopyrightText: 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""Contains SPIR-V specific array functions."""

import operator
from functools import reduce
from typing import Union

import llvmlite.ir as llvmir
from llvmlite.ir.builder import IRBuilder
from numba.core import cgutils, errors, types
from numba.core.base import BaseContext

from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext
from numba_dpex.ocl.oclimpl import _get_target_data


def get_itemsize(context: SPIRVTargetContext, array_type: types.Array):
"""
Return the item size for the given array or buffer type.
Same as numba.np.arrayobj.get_itemsize, but using spirv data.
"""
targetdata = _get_target_data(context)
lldtype = context.get_data_type(array_type.dtype)
return lldtype.get_abi_size(targetdata)


def require_literal(literal_type: types.Type):
"""Checks if the numba type is Literal. If iterable object is passed,
checks that every element is Literal.

Raises:
TypingError: When argument is not Iterable.
"""
if not hasattr(literal_type, "__len__"):
if not isinstance(literal_type, types.Literal):
raise errors.TypingError("requires literal type")
return

for i, _ in enumerate(literal_type):
if not isinstance(literal_type[i], types.Literal):
raise errors.TypingError("requires literal type")


def make_spirv_array( # pylint: disable=too-many-arguments
context: SPIRVTargetContext,
builder: IRBuilder,
ty_array: types.Array,
ty_shape: Union[types.IntegerLiteral, types.BaseTuple],
shape: llvmir.Value,
data: llvmir.Value,
):
"""Makes SPIR-V array and fills it data."""
# Create array object
ary = context.make_array(ty_array)(context, builder)

itemsize = get_itemsize(context, ty_array)
diptorupd marked this conversation as resolved.
Show resolved Hide resolved
ll_itemsize = cgutils.intp_t(itemsize)

if isinstance(ty_shape, types.BaseTuple):
shapes = cgutils.unpack_tuple(builder, shape)
else:
ty_shape = (ty_shape,)
shapes = (shape,)
shapes = [
context.cast(builder, value, fromty, types.intp)
for fromty, value in zip(ty_shape, shapes)
]

off = ll_itemsize
strides = []
if ty_array.layout == "F":
for s in shapes:
strides.append(off)
off = builder.mul(off, s)
else:
for s in reversed(shapes):
strides.append(off)
off = builder.mul(off, s)
strides.reverse()

context.populate_array(
ary,
data=data,
shape=shapes,
strides=strides,
itemsize=ll_itemsize,
)

return ary


def allocate_array_data_on_stack(
context: BaseContext,
builder: IRBuilder,
ty_array: types.Array,
ty_shape: Union[types.IntegerLiteral, types.BaseTuple],
):
"""Allocates flat array of given shape on the stack."""
if not isinstance(ty_shape, types.BaseTuple):
ty_shape = (ty_shape,)

return cgutils.alloca_once(
diptorupd marked this conversation as resolved.
Show resolved Hide resolved
builder,
context.get_data_type(ty_array.dtype),
size=reduce(operator.mul, [s.literal_value for s in ty_shape]),
)


def make_spirv_generic_array_on_stack(
context: SPIRVTargetContext,
builder: IRBuilder,
ty_array: types.Array,
ty_shape: Union[types.IntegerLiteral, types.BaseTuple],
shape: llvmir.Value,
):
"""Makes SPIR-V array of given shape with empty data."""
data = allocate_array_data_on_stack(context, builder, ty_array, ty_shape)
return make_spirv_array(context, builder, ty_array, ty_shape, shape, data)
54 changes: 54 additions & 0 deletions numba_dpex/tests/experimental/test_private_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# SPDX-FileCopyrightText: 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

import dpnp
import numpy as np
import pytest

import numba_dpex.experimental as dpex_exp
from numba_dpex.kernel_api import Item, PrivateArray, Range
from numba_dpex.kernel_api import call_kernel as kapi_call_kernel


def private_array_kernel(item: Item, a):
i = item.get_linear_id()
p = PrivateArray(10, a.dtype)

for j in range(10):
p[j] = j * j

a[i] = 0
for j in range(10):
a[i] += p[j]


def private_2d_array_kernel(item: Item, a):
i = item.get_linear_id()
p = PrivateArray(shape=(5, 2), dtype=a.dtype)

for j in range(10):
p[j % 5, j // 5] = j * j

a[i] = 0
for j in range(10):
a[i] += p[j % 5, j // 5]


@pytest.mark.parametrize(
"kernel", [private_array_kernel, private_2d_array_kernel]
)
@pytest.mark.parametrize(
"call_kernel, decorator",
[(dpex_exp.call_kernel, dpex_exp.kernel), (kapi_call_kernel, lambda a: a)],
)
def test_private_array(call_kernel, decorator, kernel):
kernel = decorator(kernel)

a = dpnp.empty(10, dtype=dpnp.float32)
call_kernel(kernel, Range(a.size), a)

# sum of squares from 1 to n: n*(n+1)*(2*n+1)/6
want = np.full(a.size, (9) * (9 + 1) * (2 * 9 + 1) / 6, dtype=np.float32)

assert np.array_equal(want, a.asnumpy())
Loading