Skip to content

Commit

Permalink
Add fill_zeros to private array
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Mar 6, 2024
1 parent fa4b04d commit 48595d4
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import llvmlite.ir as llvmir
from llvmlite.ir.builder import IRBuilder
from numba.core import cgutils
from numba.core.typing.npydecl import parse_dtype as _ty_parse_dtype
from numba.core.typing.npydecl import parse_shape as _ty_parse_shape
from numba.core.typing.templates import Signature
Expand All @@ -31,9 +32,13 @@
inline="always",
)
def _intrinsic_private_array_ctor(
ty_context, ty_shape, ty_dtype # pylint: disable=unused-argument
ty_context, # pylint: disable=unused-argument
ty_shape,
ty_dtype,
ty_fill_zeros,
):
require_literal(ty_shape)
require_literal(ty_fill_zeros)

ty_array = USMNdArray(
dtype=_ty_parse_dtype(ty_dtype),
Expand All @@ -42,7 +47,7 @@ def _intrinsic_private_array_ctor(
addrspace=AddressSpace.PRIVATE,
)

sig = ty_array(ty_shape, ty_dtype)
sig = ty_array(ty_shape, ty_dtype, ty_fill_zeros)

def codegen(
context: DpexExpKernelTypingContext,
Expand All @@ -52,11 +57,18 @@ def codegen(
):
shape = args[0]
ty_shape = sig.args[0]
ty_fill_zeros = sig.args[-1]
ty_array = sig.return_type

ary = make_spirv_generic_array_on_stack(
context, builder, ty_array, ty_shape, shape
)

if ty_fill_zeros.literal_value:
cgutils.memset(
builder, ary.data, builder.mul(ary.itemsize, ary.nitems), 0
)

return ary._getvalue() # pylint: disable=protected-access

return (
Expand All @@ -74,6 +86,7 @@ def codegen(
def ol_private_array_ctor(
shape,
dtype,
fill_zeros=False,
):
"""Overload of the constructor for the class
class:`numba_dpex.kernel_api.PrivateArray`.
Expand All @@ -88,8 +101,9 @@ def ol_private_array_ctor(
def ol_private_array_ctor_impl(
shape,
dtype,
fill_zeros=False,
):
# pylint: disable=no-value-for-parameter
return _intrinsic_private_array_ctor(shape, dtype)
return _intrinsic_private_array_ctor(shape, dtype, fill_zeros)

return ol_private_array_ctor_impl
9 changes: 6 additions & 3 deletions numba_dpex/kernel_api/private_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
kernel function.
"""

from numpy import ndarray
import numpy as np


class PrivateArray:
Expand All @@ -16,10 +16,13 @@ class PrivateArray:
inside kernel work item.
"""

def __init__(self, shape, dtype) -> None:
def __init__(self, shape, dtype, fill_zeros=False) -> None:
"""Creates a new PrivateArray instance of the given shape and dtype."""

self._data = ndarray(shape=shape, dtype=dtype)
if fill_zeros:
self._data = np.zeros(shape=shape, dtype=dtype)
else:
self._data = np.empty(shape=shape, dtype=dtype)

def __getitem__(self, idx_obj):
"""Returns the value stored at the position represented by idx_obj in
Expand Down
32 changes: 31 additions & 1 deletion numba_dpex/tests/experimental/test_private_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,30 @@ def private_array_kernel(item: Item, a):
a[i] += p[j]


def private_array_kernel_fill_true(item: Item, a):
    """Kernel body that sums a PrivateArray created with fill_zeros=True.

    A 10-element private array with the dtype of ``a`` is allocated with
    ``fill_zeros=True``, each slot ``k`` is overwritten with ``k * k``, and
    the sum of all slots is accumulated into ``a`` at the work item's
    linear id.
    """
    linear_id = item.get_linear_id()
    scratch = PrivateArray(10, a.dtype, fill_zeros=True)

    # Populate every slot of the private array with the square of its index.
    for k in range(10):
        scratch[k] = k * k

    # Reset the output element, then accumulate the private values into it.
    a[linear_id] = 0
    for k in range(10):
        a[linear_id] += scratch[k]


def private_array_kernel_fill_false(item: Item, a):
    """Kernel body that sums a PrivateArray created with fill_zeros=False.

    A 10-element private array with the dtype of ``a`` is allocated with
    ``fill_zeros=False``, each slot ``k`` is overwritten with ``k * k``, and
    the sum of all slots is accumulated into ``a`` at the work item's
    linear id.
    """
    linear_id = item.get_linear_id()
    scratch = PrivateArray(10, a.dtype, fill_zeros=False)

    # Every slot is written before it is read, so the initial contents of
    # the private array do not affect the result.
    for k in range(10):
        scratch[k] = k * k

    # Reset the output element, then accumulate the private values into it.
    a[linear_id] = 0
    for k in range(10):
        a[linear_id] += scratch[k]


def private_2d_array_kernel(item: Item, a):
i = item.get_linear_id()
p = PrivateArray(shape=(5, 2), dtype=a.dtype)
Expand All @@ -36,7 +60,13 @@ def private_2d_array_kernel(item: Item, a):


@pytest.mark.parametrize(
"kernel", [private_array_kernel, private_2d_array_kernel]
"kernel",
[
private_array_kernel,
private_array_kernel_fill_true,
private_array_kernel_fill_false,
private_2d_array_kernel,
],
)
@pytest.mark.parametrize(
"call_kernel, decorator",
Expand Down

0 comments on commit 48595d4

Please sign in to comment.