Skip to content

Commit

Permalink
Merge pull request #1285 from chudur-budur/fix/test-strided-kernel-dp…
Browse files Browse the repository at this point in the history
…np-array

Refactor all tests for strided dpnp arrays in kernel with different layouts
  • Loading branch information
Diptorup Deb authored Jan 22, 2024
2 parents 12ebcd3 + 52f4ed1 commit 6b483fd
Showing 1 changed file with 228 additions and 55 deletions.
283 changes: 228 additions & 55 deletions numba_dpex/tests/experimental/test_strided_dpnp_array_in_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,108 +2,281 @@
#
# SPDX-License-Identifier: Apache-2.0

import math

import dpnp
import numpy as np
import pytest

import numba_dpex
import numba_dpex.experimental as exp_dpex
from numba_dpex import NdRange, Range
from numba_dpex import Range


def get_order(a):
"""Get order of an array.
Args:
a (numpy.ndarray, dpnp.ndarray): Input array.
Raises:
Exception: _description_
Returns:
str: 'C' if c-contiguous, 'F' if f-contiguous or 'A' if aligned.
"""
if a.flags.c_contiguous and not a.flags.f_contiguous:
return "C"
elif not a.flags.c_contiguous and a.flags.f_contiguous:
return "F"
elif a.flags.c_contiguous and a.flags.f_contiguous:
return "A"
else:
raise Exception("Unknown order/layout")


@exp_dpex.kernel
def change_values_1d(x, v):
"""Assign values in a 1d dpnp.ndarray
Args:
x (dpnp.ndarray): Input array.
v (int): Value to be assigned.
"""
i = numba_dpex.get_global_id(0)
p = x[i] # getitem
p = v
x[i] = p # setitem
x[i] = v


def change_values_1d_func(a, p):
"""Assign values in a 1d numpy.ndarray
Args:
a (numpy.ndarray): Input array.
p (int): Value to be assigned.
"""
for i in range(a.shape[0]):
a[i] = p


@exp_dpex.kernel
def change_values_2d(x, v):
"""Assign values in a 2d dpnp.ndarray
Args:
x (dpnp.ndarray): Input array.
v (int): Value to be assigned.
"""
i = numba_dpex.get_global_id(0)
j = numba_dpex.get_global_id(1)
p = x[i, j] # getitem
p = v
x[i, j] = p # setitem
x[i, j] = v


def change_values_2d_func(a, p):
"""Assign values in a 2d numpy.ndarray
Args:
a (numpy.ndarray): Input array.
p (int): Value to be assigned.
"""
for i in range(a.shape[0]):
for j in range(a.shape[1]):
a[i, j] = p


@exp_dpex.kernel
def change_values_3d(x, v):
"""Assign values in a 3d dpnp.ndarray
Args:
x (dpnp.ndarray): Input array.
v (int): Value to be assigned.
"""
i = numba_dpex.get_global_id(0)
j = numba_dpex.get_global_id(1)
k = numba_dpex.get_global_id(2)
p = x[i, j, k] # getitem
p = v
x[i, j, k] = p # setitem
x[i, j, k] = v


def test_strided_dpnp_array_in_kernel():
def change_values_3d_func(a, p):
"""Assign values in a 3d numpy.ndarray
Args:
a (numpy.ndarray): Input array.
p (int): Value to be assigned.
"""
for i in range(a.shape[0]):
for j in range(a.shape[1]):
for k in range(a.shape[2]):
a[i, j, k] = p


@pytest.mark.parametrize("s", [1, 2, 3, 4, 5, 6, 7])
def test_1d_strided_dpnp_array_in_kernel(s):
"""
Tests if we can correctly handle a strided 1d dpnp array
inside dpex kernel.
"""
N = 1024
out = dpnp.arange(0, N * 2, dtype=dpnp.int64)
b = out[::2]
N = 256
k = -3

t = np.arange(0, N, dtype=dpnp.int64)
u = dpnp.asarray(t)

v = u[::s]
exp_dpex.call_kernel(change_values_1d, Range(v.shape[0]), v, k)

r = Range(N)
v = -3
exp_dpex.call_kernel(change_values_1d, r, b, v)
x = t[::s]
change_values_1d_func(x, k)

assert (dpnp.asnumpy(b) == v).all()
# check the value of the array view
assert np.all(dpnp.asnumpy(v) == x)
# check the value of the original arrays
assert np.all(dpnp.asnumpy(u) == t)


def test_multievel_strided_dpnp_array_in_kernel():
@pytest.mark.parametrize("s", [2, 3, 4, 5])
def test_multievel_1d_strided_dpnp_array_in_kernel(s):
"""
Tests if we can correctly handle a multilevel strided 1d dpnp array
inside dpex kernel.
"""
N = 128
out = dpnp.arange(0, N * 2, dtype=dpnp.int64)
v = -3
N = 256
k = -3

t = dpnp.arange(0, N, dtype=dpnp.int64)
u = dpnp.asarray(t)

v, x = u, t
while v.shape[0] > 1:
v = v[::s]
exp_dpex.call_kernel(change_values_1d, Range(v.shape[0]), v, k)

x = x[::s]
change_values_1d_func(x, k)

b = out
n = N
K = 7
for _ in range(K):
b = b[::2]
exp_dpex.call_kernel(change_values_1d, Range(n), b, v)
assert (dpnp.asnumpy(b) == v).all()
n = int(n / 2)
# check the value of the array view
assert np.all(dpnp.asnumpy(v) == x)
# check the value of the original arrays
assert np.all(dpnp.asnumpy(u) == t)


def test_multilevel_2d_strided_dpnp_array_in_kernel():
@pytest.mark.parametrize("s1", [2, 4, 6, 8])
@pytest.mark.parametrize("s2", [1, 3, 5, 7])
@pytest.mark.parametrize("order", ["C", "F"])
def test_2d_strided_dpnp_array_in_kernel(s1, s2, order):
"""
Tests if we can correctly handle a strided 2d dpnp array
inside dpex kernel.
"""
M, N = 13, 31
k = -3

t = np.arange(0, M * N, dtype=np.int64).reshape(M, N, order=order)
u = dpnp.asarray(t)

# check order, sanity check
assert get_order(u) == order

v = u[::s1, ::s2]
exp_dpex.call_kernel(change_values_2d, Range(*v.shape), v, k)

x = t[::s1, ::s2]
change_values_2d_func(x, k)

# check the value of the array view
assert np.all(dpnp.asnumpy(v) == x)
# check the value of the original arrays
assert np.all(dpnp.asnumpy(u) == t)


@pytest.mark.parametrize("s1", [2, 4, 6, 8])
@pytest.mark.parametrize("s2", [3, 5, 7, 9])
@pytest.mark.parametrize("order", ["C", "F"])
def test_multilevel_2d_strided_dpnp_array_in_kernel(s1, s2, order):
"""
Tests if we can correctly handle a multilevel strided 2d dpnp array
inside dpex kernel.
"""
N = 128
out, _ = dpnp.mgrid[0 : N * 2, 0 : N * 2] # noqa: E203
v = -3
M, N = 13, 31
k = -3

t = np.arange(0, M * N, dtype=np.int64).reshape(M, N, order=order)
u = dpnp.asarray(t)

# check order, sanity check
assert get_order(u) == order

b = out
n = N
K = 7
for _ in range(K):
b = b[::2, ::2]
exp_dpex.call_kernel(change_values_2d, Range(n, n), b, v)
assert (dpnp.asnumpy(b) == v).all()
n = int(n / 2)
v, x = u, t
while v.shape[0] > 1 and v.shape[1] > 1:
v = v[::s1, ::s2]
exp_dpex.call_kernel(change_values_2d, Range(*v.shape), v, k)

x = x[::s1, ::s2]
change_values_2d_func(x, k)

def test_multilevel_3d_strided_dpnp_array_in_kernel():
# check the value of the array view
assert np.all(dpnp.asnumpy(v) == x)
# check the value of the original arrays
assert np.all(dpnp.asnumpy(u) == t)


@pytest.mark.parametrize("s1", [1, 2, 3])
@pytest.mark.parametrize("s2", [2, 3, 4])
@pytest.mark.parametrize("s3", [3, 4, 5])
@pytest.mark.parametrize("order", ["C", "F"])
def test_3d_strided_dpnp_array_in_kernel(s1, s2, s3, order):
"""
Tests if we can correctly handle a strided 3d dpnp array
inside dpex kernel.
"""
M, N, K = 13, 31, 11
k = -3

t = np.arange(0, M * N * K, dtype=np.int64).reshape((M, N, K), order=order)
u = dpnp.asarray(t)

# check order, sanity check
assert get_order(u) == order

v = u[::s1, ::s2, ::s3]
exp_dpex.call_kernel(change_values_3d, Range(*v.shape), v, k)

x = t[::s1, ::s2, ::s3]
change_values_3d_func(x, k)

# check the value of the array view
assert np.all(dpnp.asnumpy(v) == x)
# check the value of the original arrays
assert np.all(dpnp.asnumpy(u) == t)


@pytest.mark.parametrize("s1", [2, 3, 4])
@pytest.mark.parametrize("s2", [3, 4, 5])
@pytest.mark.parametrize("s3", [4, 5, 6])
@pytest.mark.parametrize("order", ["C", "F"])
def test_multilevel_3d_strided_dpnp_array_in_kernel(s1, s2, s3, order):
"""
Tests if we can correctly handle a multilevel strided 3d dpnp array
inside dpex kernel.
"""
N = 128
out, _, _ = dpnp.mgrid[0 : N * 2, 0 : N * 2, 0 : N * 2] # noqa: E203
v = -3

b = out
n = N
K = 7
for _ in range(K):
b = b[::2, ::2, ::2]
exp_dpex.call_kernel(change_values_3d, Range(n, n, n), b, v)
assert (dpnp.asnumpy(b) == v).all()
n = int(n / 2)
M, N, K = 13, 31, 11
k = -3

t = np.arange(0, M * N * K, dtype=np.int64).reshape((M, N, K), order=order)
u = dpnp.asarray(t)

# check order, sanity check
assert get_order(u) == order

v, x = u, t
while v.shape[0] > 1 and v.shape[1] > 1 and v.shape[2] > 1:
v = v[::s1, ::s2, ::s3]
exp_dpex.call_kernel(change_values_3d, Range(*v.shape), v, k)

x = x[::s1, ::s2, ::s3]
change_values_3d_func(x, k)

# check the value of the array view
assert np.all(dpnp.asnumpy(v) == x)
# check the value of the original arrays
assert np.all(dpnp.asnumpy(u) == t)

0 comments on commit 6b483fd

Please sign in to comment.