Skip to content

Commit

Permalink
Adds support for specializing a device_func.
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb committed Mar 21, 2024
1 parent 800cd00 commit ec5d238
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 81 deletions.
21 changes: 18 additions & 3 deletions numba_dpex/experimental/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _kernel_dispatcher(pyfunc):
"Argument passed to the kernel decorator is neither a "
"function object, nor a signature. If you are trying to "
"specialize the kernel that takes a single argument, specify "
"the return type as void explicitly."
"the return type as None explicitly."
)
return _kernel_dispatcher(func)

Expand Down Expand Up @@ -132,13 +132,28 @@ def device_func(func_or_sig=None, **options):
)
options["_compilation_mode"] = CompilationMode.DEVICE_FUNC

func, sigs = _parse_func_or_sig(func_or_sig)
for sig in sigs:
if isinstance(sig, str):
raise NotImplementedError(
"Specifying signatures as string is not yet supported"
)

def _kernel_dispatcher(pyfunc):
return dispatcher(
disp: SPIRVKernelDispatcher = dispatcher(
pyfunc=pyfunc,
targetoptions=options,
)

if func_or_sig is None:
if len(sigs) > 0:
with typeinfer.register_dispatcher(disp):
for sig in sigs:
disp.compile(sig)
disp.disable_compile()

return disp

if func is None:
return _kernel_dispatcher

return _kernel_dispatcher(func_or_sig)
Expand Down
139 changes: 61 additions & 78 deletions numba_dpex/tests/kernel_tests/test_func_specialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,105 +4,88 @@

import dpnp
import numpy as np
import pytest
from numba import int32, int64

import numba_dpex as dpex
from numba_dpex import float32, int32
import numba_dpex.experimental as dpex

single_signature = dpex.func(int32(int32))
list_signature = dpex.func([int32(int32), float32(float32)])
i32_signature = dpex.device_func(int32(int32))
i32i64_signature = dpex.device_func([int32(int32), int64(int64)])

# Array size
N = 10
N = 1024


def increment(a):
return a + dpnp.float32(1)
return a + 1


def test_basic():
"""Basic test with device func"""
fi32 = i32_signature(increment)
fi32i64 = i32i64_signature(increment)

f = dpex.func(increment)

def kernel_function(a, b):
"""Kernel function that applies f() in parallel"""
i = dpex.get_global_id(0)
b[i] = f(a[i])
@dpex.kernel
def kernel_function(item, a, b):
"""Kernel function that calls fi32()"""
i = item.get_id(0)
b[i] = fi32(a[i])

k = dpex.kernel(kernel_function)

a = dpnp.ones(N)
b = dpnp.ones(N)
@dpex.kernel
def kernel_function2(item, a, b):
"""Kernel function that calls fi32i64()"""
i = item.get_id(0)
b[i] = fi32i64(a[i])

dpex.call_kernel(k, dpex.Range(N), a, b)

assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1)


def test_single_signature():
"""Basic test with single signature"""

fi32 = single_signature(increment)

def kernel_function(a, b):
"""Kernel function that applies fi32() in parallel"""
i = dpex.get_global_id(0)
b[i] = fi32(a[i])

k = dpex.kernel(kernel_function)

# Test with int32, should work
a = dpnp.ones(N, dtype=dpnp.int32)
b = dpnp.ones(N, dtype=dpnp.int32)

dpex.call_kernel(k, dpex.Range(N), a, b)

assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1)

# Test with int64, should fail
a = dpnp.ones(N, dtype=dpnp.int64)
b = dpnp.ones(N, dtype=dpnp.int64)

with pytest.raises(Exception) as e:
dpex.call_kernel(k, dpex.Range(N), a, b)

assert " >>> <unknown function>(int64)" in e.value.args[0]


def test_list_signature():
"""Basic test with list signature"""

fi32f32 = list_signature(increment)

def kernel_function(a, b):
"""Kernel function that applies fi32f32() in parallel"""
i = dpex.get_global_id(0)
b[i] = fi32f32(a[i])

k = dpex.kernel(kernel_function)

# Test with int32, should work
def test_calling_specialized_device_func():
"""Tests if a specialized device_func gets called as expected from kernel"""
a = dpnp.ones(N, dtype=dpnp.int32)
b = dpnp.ones(N, dtype=dpnp.int32)
b = dpnp.zeros(N, dtype=dpnp.int32)

dpex.call_kernel(k, dpex.Range(N), a, b)
dpex.call_kernel(kernel_function, dpex.Range(N), a, b)

assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1)

# Test with float32, should work
a = dpnp.ones(N, dtype=dpnp.float32)
b = dpnp.ones(N, dtype=dpnp.float32)

dpex.call_kernel(k, dpex.Range(N), a, b)
def test_calling_specialized_device_func_wrong_signature():
"""Tests that calling specialized signature with wrong signature does not
trigger recompilation.
assert np.array_equal(dpnp.asnumpy(b), dpnp.asnumpy(a) + 1)
Tests kernel_function with float32. Numba will downcast float32 to int32
and call the specialized function. The implicit casting is a problem, but
for the purpose of this test case, all we care is to check if the
specialized function was called and we did not recompiled the device_func.
Refer: https://github.com/numba/numba/issues/9506
"""
# Test with int64, should fail
a = dpnp.ones(N, dtype=dpnp.int64)
b = dpnp.ones(N, dtype=dpnp.int64)

with pytest.raises(Exception) as e:
dpex.call_kernel(k, dpex.Range(N), a, b)

assert " >>> <unknown function>(int64)" in e.value.args[0]
a = dpnp.full(N, 1.5, dtype=dpnp.float32)
b = dpnp.zeros(N, dtype=dpnp.float32)

dpex.call_kernel(kernel_function, dpex.Range(N), a, b)

# Since Numba is calling the i32 specialization of increment, the values in
# `a` are first down converted to int32, *i.e.*, 1.5 to 1 and then
# incremented. Thus, the output is 2 instead of 2.5.
# The implicit down casting is a dangerous thing for Numba to do, but we use
# to our advantage to test if re compilation did not happen for a
# specialized device function.
assert np.all(dpnp.asnumpy(b) == 2)
assert not np.all(dpnp.asnumpy(b) == 2.5)


def test_multi_specialized_device_func():
"""Tests if a device_func with multiple specialization can be called
in a kernel
"""
# Test with int32, i64 should work
ai32 = dpnp.ones(N, dtype=dpnp.int32)
bi32 = dpnp.ones(N, dtype=dpnp.int32)
ai64 = dpnp.ones(N, dtype=dpnp.int64)
bi64 = dpnp.ones(N, dtype=dpnp.int64)

dpex.call_kernel(kernel_function2, dpex.Range(N), ai32, bi32)
dpex.call_kernel(kernel_function2, dpex.Range(N), ai64, bi64)

assert np.array_equal(dpnp.asnumpy(bi32), dpnp.asnumpy(ai32) + 1)
assert np.array_equal(dpnp.asnumpy(bi64), dpnp.asnumpy(ai64) + 1)

0 comments on commit ec5d238

Please sign in to comment.