Skip to content

Commit

Permalink
Merge pull request #1377 from IntelPython/fix/private_array
Browse files Browse the repository at this point in the history
Fix/private array
  • Loading branch information
ZzEeKkAa authored Mar 15, 2024
2 parents 2831862 + 0131838 commit 200bf8d
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _intrinsic_spirv_global_index_const(
sig = types.int64(types.int32)

def _intrinsic_spirv_global_index_const_gen(
context: SPIRVTargetContext,
context: SPIRVTargetContext, # pylint: disable=unused-argument
builder: llvmir.IRBuilder,
sig, # pylint: disable=unused-argument
args,
Expand All @@ -79,7 +79,16 @@ def _intrinsic_spirv_global_index_const_gen(
dim,
)

return context.cast(builder, res, types.uintp, types.intp)
# Generating same check as sycl does. Did they add it to avoid pointer
# bitcast on special constant?
max_int32 = llvmir.Constant(res.type, 2147483648)
cmp = builder.icmp_unsigned("<", res, max_int32)

inst = builder.assume(cmp)
# TODO: tail does not always work
inst.tail = "tail"

return res

return sig, _intrinsic_spirv_global_index_const_gen

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@

import llvmlite.ir as llvmir
from llvmlite.ir.builder import IRBuilder
from numba.core import cgutils, types
from numba.core.typing.npydecl import parse_dtype as _ty_parse_dtype
from numba.core.typing.npydecl import parse_shape as _ty_parse_shape
from numba.core.typing.templates import Signature
from numba.extending import intrinsic, overload
from numba.extending import type_callable

from numba_dpex.core.types import USMNdArray
from numba_dpex.experimental.target import DpexExpKernelTypingContext
Expand All @@ -23,55 +24,12 @@
)
from numba_dpex.utils import address_space as AddressSpace

from ..target import DPEX_KERNEL_EXP_TARGET_NAME
from ._registry import lower


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_private_array_ctor(
ty_context, ty_shape, ty_dtype # pylint: disable=unused-argument
):
require_literal(ty_shape)

ty_array = USMNdArray(
dtype=_ty_parse_dtype(ty_dtype),
ndim=_ty_parse_shape(ty_shape),
layout="C",
addrspace=AddressSpace.PRIVATE,
)

sig = ty_array(ty_shape, ty_dtype)

def codegen(
context: DpexExpKernelTypingContext,
builder: IRBuilder,
sig: Signature,
args: list[llvmir.Value],
):
shape = args[0]
ty_shape = sig.args[0]
ty_array = sig.return_type

ary = make_spirv_generic_array_on_stack(
context, builder, ty_array, ty_shape, shape
)
return ary._getvalue() # pylint: disable=protected-access

return (
sig,
codegen,
)


@overload(
PrivateArray,
prefer_literal=True,
target=DPEX_KERNEL_EXP_TARGET_NAME,
)
def ol_private_array_ctor(
shape,
dtype,
):
"""Overload of the constructor for the class
@type_callable(PrivateArray)
def type_interval(context): # pylint: disable=unused-argument
"""Sets type of the constructor for the class
class:`numba_dpex.kernel_api.PrivateArray`.
Raises:
Expand All @@ -81,11 +39,48 @@ def ol_private_array_ctor(
type.
"""

def ol_private_array_ctor_impl(
shape,
dtype,
):
# pylint: disable=no-value-for-parameter
return _intrinsic_private_array_ctor(shape, dtype)
def typer(shape, dtype, fill_zeros=types.BooleanLiteral(False)):
require_literal(shape)
require_literal(fill_zeros)

return USMNdArray(
dtype=_ty_parse_dtype(dtype),
ndim=_ty_parse_shape(shape),
layout="C",
addrspace=AddressSpace.PRIVATE,
)

return typer


@lower(PrivateArray, types.IntegerLiteral, types.Any, types.BooleanLiteral)
@lower(PrivateArray, types.Tuple, types.Any, types.BooleanLiteral)
@lower(PrivateArray, types.UniTuple, types.Any, types.BooleanLiteral)
@lower(PrivateArray, types.IntegerLiteral, types.Any)
@lower(PrivateArray, types.Tuple, types.Any)
@lower(PrivateArray, types.UniTuple, types.Any)
def dpex_private_array_lower(
context: DpexExpKernelTypingContext,
builder: IRBuilder,
sig: Signature,
args: list[llvmir.Value],
):
"""Implements lower for the class:`numba_dpex.kernel_api.PrivateArray`"""
shape = args[0]
ty_shape = sig.args[0]
if len(sig.args) == 3:
fill_zeros = sig.args[-1].literal_value
else:
fill_zeros = False
ty_array = sig.return_type

ary = make_spirv_generic_array_on_stack(
context, builder, ty_array, ty_shape, shape
)

if fill_zeros:
cgutils.memset(
builder, ary.data, builder.mul(ary.itemsize, ary.nitems), 0
)

return ol_private_array_ctor_impl
return ary._getvalue() # pylint: disable=protected-access
12 changes: 12 additions & 0 deletions numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# SPDX-FileCopyrightText: 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""
Implements the SPIR-V overloads for the kernel_api.PrivateArray class.
"""

from numba.core.imputils import Registry

registry = Registry()
lower = registry.lower
9 changes: 6 additions & 3 deletions numba_dpex/kernel_api/private_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
kernel function.
"""

from numpy import ndarray
import numpy as np


class PrivateArray:
Expand All @@ -16,10 +16,13 @@ class PrivateArray:
inside kernel work item.
"""

def __init__(self, shape, dtype) -> None:
def __init__(self, shape, dtype, fill_zeros=False) -> None:
"""Creates a new PrivateArray instance of the given shape and dtype."""

self._data = ndarray(shape=shape, dtype=dtype)
if fill_zeros:
self._data = np.zeros(shape=shape, dtype=dtype)
else:
self._data = np.empty(shape=shape, dtype=dtype)

def __getitem__(self, idx_obj):
"""Returns the value stored at the position represented by idx_obj in
Expand Down
4 changes: 3 additions & 1 deletion numba_dpex/kernel_api_impl/spirv/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def require_literal(literal_type: types.Type):

for i, _ in enumerate(literal_type):
if not isinstance(literal_type[i], types.Literal):
raise errors.TypingError("requires literal type")
raise errors.TypingError(
"requires each element of tuple literal type"
)


def make_spirv_array( # pylint: disable=too-many-arguments
Expand Down
26 changes: 18 additions & 8 deletions numba_dpex/kernel_api_impl/spirv/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""Implements a new numba dispatcher class and a compiler class to compile and
call numba_dpex.kernel decorated function.
"""
import hashlib
from collections import namedtuple
from contextlib import ExitStack
from typing import Tuple
Expand Down Expand Up @@ -181,6 +182,9 @@ def _compile_to_spirv(
# all linking libraries getting linked together and final optimization
# including inlining of functions if an inlining level is specified.
kernel_library.finalize()

if config.DUMP_KERNEL_LLVM:
self._dump_kernel(kernel_fndesc, kernel_library)
# Compiled the LLVM IR to SPIR-V
kernel_spirv_module = spirv_generator.llvm_to_spirv(
kernel_targetctx,
Expand Down Expand Up @@ -268,20 +272,26 @@ def _compile_cached(

kcres_attrs.append(kernel_device_ir_module)

if config.DUMP_KERNEL_LLVM:
with open(
cres.fndesc.llvm_func_name + ".ll",
"w",
encoding="UTF-8",
) as fptr:
fptr.write(str(cres.library.final_module))

except errors.TypingError as err:
self._failed_cache[key] = err
return False, err

return True, _SPIRVKernelCompileResult(*kcres_attrs)

def _dump_kernel(self, fndesc, library):
"""Dump kernel into file."""
name = fndesc.llvm_func_name
if len(name) > 200:
sha256 = hashlib.sha256(name.encode("utf-8")).hexdigest()
name = name[:150] + "_" + sha256

with open(
name + ".ll",
"w",
encoding="UTF-8",
) as fptr:
fptr.write(str(library.final_module))


class SPIRVKernelDispatcher(Dispatcher):
"""Dispatcher class designed to compile kernel decorated functions. The
Expand Down
1 change: 1 addition & 0 deletions numba_dpex/kernel_api_impl/spirv/spirv_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def finalize(self):
llvm_spirv_args = [
"--spirv-ext=+SPV_EXT_shader_atomic_float_add",
"--spirv-ext=+SPV_EXT_shader_atomic_float_min_max",
"--spirv-ext=+SPV_INTEL_arbitrary_precision_integers",
]
for key in list(self.context.extra_compile_options.keys()):
if key == LLVM_SPIRV_ARGS:
Expand Down
4 changes: 4 additions & 0 deletions numba_dpex/kernel_api_impl/spirv/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,12 +383,16 @@ def load_additional_registries(self):
# pylint: disable=import-outside-toplevel
from numba_dpex import printimpl
from numba_dpex.dpnp_iface import dpnpimpl
from numba_dpex.experimental._kernel_dpcpp_spirv_overloads._registry import (
registry as spirv_registry,
)
from numba_dpex.ocl import mathimpl, oclimpl

self.insert_func_defn(oclimpl.registry.functions)
self.insert_func_defn(mathimpl.registry.functions)
self.insert_func_defn(dpnpimpl.registry.functions)
self.install_registry(printimpl.registry)
self.install_registry(spirv_registry)
# Replace dpnp math functions with their OpenCL versions.
self.replace_dpnp_ufunc_with_ocl_intrinsics()

Expand Down
32 changes: 31 additions & 1 deletion numba_dpex/tests/experimental/test_private_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,30 @@ def private_array_kernel(item: Item, a):
a[i] += p[j]


def private_array_kernel_fill_true(item: Item, a):
i = item.get_linear_id()
p = PrivateArray(10, a.dtype, fill_zeros=True)

for j in range(10):
p[j] = j * j

a[i] = 0
for j in range(10):
a[i] += p[j]


def private_array_kernel_fill_false(item: Item, a):
i = item.get_linear_id()
p = PrivateArray(10, a.dtype, fill_zeros=False)

for j in range(10):
p[j] = j * j

a[i] = 0
for j in range(10):
a[i] += p[j]


def private_2d_array_kernel(item: Item, a):
i = item.get_linear_id()
p = PrivateArray(shape=(5, 2), dtype=a.dtype)
Expand All @@ -36,7 +60,13 @@ def private_2d_array_kernel(item: Item, a):


@pytest.mark.parametrize(
"kernel", [private_array_kernel, private_2d_array_kernel]
"kernel",
[
private_array_kernel,
private_array_kernel_fill_true,
private_array_kernel_fill_false,
private_2d_array_kernel,
],
)
@pytest.mark.parametrize(
"call_kernel, decorator",
Expand Down

0 comments on commit 200bf8d

Please sign in to comment.