-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
312 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
104 changes: 104 additions & 0 deletions
104
numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_private_array_overloads.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
""" | ||
Implements the SPIR-V overloads for the kernel_api.AtomicRef class methods. | ||
""" | ||
|
||
|
||
import llvmlite.ir as llvmir | ||
from llvmlite.ir.builder import IRBuilder | ||
from numba.core import errors, types | ||
from numba.core.typing.templates import Signature | ||
from numba.extending import intrinsic, overload | ||
|
||
from numba_dpex.core.types import USMNdArray | ||
from numba_dpex.experimental.target import DpexExpKernelTypingContext | ||
from numba_dpex.kernel_api import PrivateArray | ||
from numba_dpex.kernel_api_impl.spirv.arrayobj import ( | ||
make_spirv_generic_array_on_stack, | ||
require_literal, | ||
) | ||
from numba_dpex.utils import address_space as AddressSpace | ||
|
||
from ..target import DPEX_KERNEL_EXP_TARGET_NAME | ||
|
||
|
||
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME) | ||
def _intrinsic_private_array_ctor( | ||
ty_context, ty_shape, ty_dtype # pylint: disable=unused-argument | ||
): | ||
require_literal(ty_shape) | ||
|
||
if not isinstance(ty_dtype, types.DType): | ||
raise errors.TypingError("Second argument must be instance of DType") | ||
|
||
ndim = 1 | ||
if hasattr(ty_shape, "__len__"): | ||
ndim = len(ty_shape) | ||
|
||
ty_array = USMNdArray( | ||
dtype=ty_dtype.dtype, | ||
ndim=ndim, | ||
layout="C", | ||
addrspace=AddressSpace.PRIVATE, | ||
) | ||
|
||
sig = ty_array(ty_shape, ty_dtype) | ||
|
||
def codegen( | ||
context: DpexExpKernelTypingContext, | ||
builder: IRBuilder, | ||
sig: Signature, | ||
args: list[llvmir.Value], | ||
): | ||
shape = args[0] | ||
ty_shape = sig.args[0] | ||
ty_array = sig.return_type | ||
|
||
ary = make_spirv_generic_array_on_stack( | ||
context, builder, ty_array, ty_shape, shape | ||
) | ||
return ary._getvalue() # pylint: disable=protected-access | ||
|
||
return ( | ||
sig, | ||
codegen, | ||
) | ||
|
||
|
||
@overload( | ||
PrivateArray, | ||
prefer_literal=True, | ||
target=DPEX_KERNEL_EXP_TARGET_NAME, | ||
) | ||
def ol_private_array_ctor( | ||
shape, | ||
dtype, | ||
): | ||
"""Overload of the constructor for the class | ||
class:`numba_dpex.kernel_api.AtomicRef`. | ||
Raises: | ||
errors.TypingError: If the `ref` argument is not a UsmNdArray type. | ||
errors.TypingError: If the dtype of the `ref` is not supported in an | ||
AtomicRef. | ||
errors.TypingError: If the device does not support atomic operations on | ||
the dtype of the `ref`. | ||
errors.TypingError: If the `memory_order`, `address_type`, or | ||
`memory_scope` arguments could not be parsed as integer literals. | ||
errors.TypingError: If the `address_space` argument is different from | ||
the address space attribute of the `ref` argument. | ||
errors.TypingError: If the address space is PRIVATE. | ||
""" | ||
|
||
def ol_private_array_ctor_impl( | ||
shape, | ||
dtype, | ||
): | ||
# pylint: disable=no-value-for-parameter | ||
return _intrinsic_private_array_ctor(shape, dtype) | ||
|
||
return ol_private_array_ctor_impl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# SPDX-FileCopyrightText: 2023 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Implements a Python analogue to SYCL's local_accessor class. The class is | ||
intended to be used in pure Python code when prototyping a kernel function | ||
and to be passed to an actual kernel function for local memory allocation. | ||
""" | ||
|
||
from dpctl.tensor import usm_ndarray | ||
|
||
KernelUseOnlyError = NotImplementedError("Only for use inside kernel") | ||
|
||
|
||
class PrivateArray(usm_ndarray): | ||
""" | ||
The ``LocalAccessor`` class is analogous to SYCL's ``local_accessor`` | ||
class. The class acts a s proxy to allocating device local memory and | ||
accessing that memory from within a :func:`numba_dpex.kernel` decorated | ||
function. | ||
""" | ||
|
||
def __init__(self, shape, dtype) -> None: | ||
"""Creates a new LocalAccessor instance of the given shape and dtype.""" | ||
|
||
raise KernelUseOnlyError | ||
|
||
def __getitem__(self, idx_obj): | ||
"""Returns the value stored at the position represented by idx_obj in | ||
the self._data ndarray. | ||
""" | ||
|
||
raise KernelUseOnlyError | ||
|
||
def __setitem__(self, idx_obj, val): | ||
"""Assigns a new value to the position represented by idx_obj in | ||
the self._data ndarray. | ||
""" | ||
|
||
raise KernelUseOnlyError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
"""Contains spriv specific array functions.""" | ||
|
||
import operator | ||
from functools import reduce | ||
from typing import Union | ||
|
||
import llvmlite.ir as llvmir | ||
from llvmlite.ir.builder import IRBuilder | ||
from numba.core import cgutils, errors, types | ||
from numba.core.base import BaseContext | ||
|
||
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext | ||
from numba_dpex.ocl.oclimpl import _get_target_data | ||
|
||
|
||
def get_itemsize(context: SPIRVTargetContext, array_type: types.Array): | ||
""" | ||
Return the item size for the given array or buffer type. | ||
Same as numba.np.arrayobj.get_itemsize, but using spirv data. | ||
""" | ||
targetdata = _get_target_data(context) | ||
lldtype = context.get_data_type(array_type.dtype) | ||
return lldtype.get_abi_size(targetdata) | ||
|
||
|
||
def require_literal(literal_type: types.Type): | ||
"""Checks if the numba type is Literal. If iterable object is passed, | ||
checks that every element is Literal. | ||
Raises: | ||
TypingError: When argument is not Iterable. | ||
""" | ||
if not hasattr(literal_type, "__len__"): | ||
if not isinstance(literal_type, types.Literal): | ||
raise errors.TypingError("requires literal type") | ||
return | ||
|
||
for i, _ in enumerate(literal_type): | ||
if not isinstance(literal_type[i], types.Literal): | ||
raise errors.TypingError("requires literal type") | ||
|
||
|
||
def make_spirv_array( # pylint: disable=too-many-arguments | ||
context: SPIRVTargetContext, | ||
builder: IRBuilder, | ||
ty_array: types.Array, | ||
ty_shape: Union[types.IntegerLiteral, types.BaseTuple], | ||
shape: llvmir.Value, | ||
data: llvmir.Value, | ||
): | ||
"""Makes SPIR-V array and fills it data.""" | ||
# Create array object | ||
ary = context.make_array(ty_array)(context, builder) | ||
|
||
itemsize = get_itemsize(context, ty_array) | ||
ll_itemsize = cgutils.intp_t(itemsize) | ||
|
||
if isinstance(ty_shape, types.BaseTuple): | ||
shapes = cgutils.unpack_tuple(builder, shape) | ||
else: | ||
ty_shape = (ty_shape,) | ||
shapes = (shape,) | ||
shapes = [ | ||
context.cast(builder, value, fromty, types.intp) | ||
for fromty, value in zip(ty_shape, shapes) | ||
] | ||
|
||
off = ll_itemsize | ||
strides = [] | ||
if ty_array.layout == "F": | ||
for s in shapes: | ||
strides.append(off) | ||
off = builder.mul(off, s) | ||
else: | ||
for s in reversed(shapes): | ||
strides.append(off) | ||
off = builder.mul(off, s) | ||
strides.reverse() | ||
|
||
context.populate_array( | ||
ary, | ||
data=data, | ||
shape=shapes, | ||
strides=strides, | ||
itemsize=ll_itemsize, | ||
) | ||
|
||
return ary | ||
|
||
|
||
def allocate_array_data_on_stack( | ||
context: BaseContext, | ||
builder: IRBuilder, | ||
ty_array: types.Array, | ||
ty_shape: Union[types.IntegerLiteral, types.BaseTuple], | ||
): | ||
"""Allocates flat array of given shape on the stack.""" | ||
if not isinstance(ty_shape, types.BaseTuple): | ||
ty_shape = (ty_shape,) | ||
|
||
return cgutils.alloca_once( | ||
builder, | ||
context.get_data_type(ty_array.dtype), | ||
size=reduce(operator.mul, [s.literal_value for s in ty_shape]), | ||
) | ||
|
||
|
||
def make_spirv_generic_array_on_stack( | ||
context: SPIRVTargetContext, | ||
builder: IRBuilder, | ||
ty_array: types.Array, | ||
ty_shape: Union[types.IntegerLiteral, types.BaseTuple], | ||
shape: llvmir.Value, | ||
): | ||
"""Makes SPIR-V array of given shape with empty data.""" | ||
data = allocate_array_data_on_stack(context, builder, ty_array, ty_shape) | ||
return make_spirv_array(context, builder, ty_array, ty_shape, shape, data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import dpnp | ||
import numpy as np | ||
import pytest | ||
|
||
import numba_dpex.experimental as dpex_exp | ||
from numba_dpex.kernel_api import Item, PrivateArray, Range | ||
|
||
|
||
@dpex_exp.kernel | ||
def private_array_kernel(item: Item, a): | ||
i = item.get_linear_id() | ||
p = PrivateArray(10, a.dtype) | ||
|
||
for j in range(10): | ||
p[j] = j * j | ||
|
||
a[i] = 0 | ||
for j in range(10): | ||
a[i] += p[j] | ||
|
||
|
||
@dpex_exp.kernel | ||
def private_2d_array_kernel(item: Item, a): | ||
i = item.get_linear_id() | ||
p = PrivateArray(shape=(5, 2), dtype=a.dtype) | ||
|
||
for j in range(10): | ||
p[j % 5, j // 5] = j * j | ||
|
||
a[i] = 0 | ||
for j in range(10): | ||
a[i] += p[j % 5, j // 5] | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"kernel", [private_array_kernel, private_2d_array_kernel] | ||
) | ||
def test_private_array(kernel): | ||
a = dpnp.empty(10, dtype=dpnp.float32) | ||
dpex_exp.call_kernel(kernel, Range(a.size), a) | ||
|
||
want = np.full(a.size, (9) * (9 + 1) * (2 * 9 + 1) / 6, dtype=np.float32) | ||
|
||
assert np.array_equal(want, a.asnumpy()) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_private_array(private_array_kernel) |