Skip to content

Commit

Permalink
Add unit test cases for dpnp slicing and stride calc
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb committed Apr 28, 2023
1 parent 7a71627 commit 4f72382
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"""

import dpctl
import pytest

from numba_dpex import dpjit

Expand Down
27 changes: 0 additions & 27 deletions numba_dpex/tests/core/types/DpnpNdArray/test_boxing.py

This file was deleted.

47 changes: 47 additions & 0 deletions numba_dpex/tests/core/types/DpnpNdArray/test_boxing_unboxing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""
Tests for boxing for dpnp.ndarray
"""

import dpnp

from numba_dpex import dpjit


def test_boxing_unboxing():
"""Tests basic boxing and unboxing of a dpnp.ndarray object.
Checks if we can pass in and return a dpctl.ndarray object to and
from a dpjit decorated function.
"""

@dpjit
def func(a):
return a

a = dpnp.empty(10)
try:
b = func(a)
except:
assert False, "Failure during unbox/box of dpnp.ndarray"

assert a.shape == b.shape
assert a.device == b.device
assert a.strides == b.strides
assert a.dtype == b.dtype


def test_stride_calc_at_unboxing():
"""Tests if strides were correctly computed during unboxing."""

def _tester(a):
return a.strides

b = dpnp.empty((4, 16, 4))
strides = dpjit(_tester)(b)

# Numba computes strides as bytes
assert list(strides) == [512, 32, 8]
62 changes: 62 additions & 0 deletions numba_dpex/tests/dpjit_tests/test_slicing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""
Tests for boxing for dpnp.ndarray
"""

import dpnp
import numpy

from numba_dpex import dpjit


def test_1d_slicing():
"""Tests if dpjit properly computes strides and returns them to Python."""

def _tester(a):
return a[1:5]

a = dpnp.arange(10)
b = dpnp.asnumpy(dpjit(_tester)(a))

na = numpy.arange(10)
nb = _tester(na)

assert (b == nb).all()


def test_1d_slicing2():
"""Tests if dpjit properly computes strides and returns them to Python."""

def _tester(a):
b = a[1:4]
a[6:9] = b

a = dpnp.arange(10)
b = dpnp.asnumpy(dpjit(_tester)(a))

na = numpy.arange(10)
nb = _tester(na)

assert (b == nb).all()


def test_multidim_slicing():
"""Tests if dpjit properly slices strides and returns them to Python."""

def _tester(a, b):
b[:, :, 0] = a

a = dpnp.arange(64)
a = a.reshape(4, 16)
b = dpnp.empty((4, 16, 4))
dpjit(_tester)(a, b)

na = numpy.arange(64)
na = na.reshape(4, 16)
nb = numpy.empty((4, 16, 4))
_tester(na, nb)

assert numpy.allclose(nb, dpnp.asnumpy(b))

0 comments on commit 4f72382

Please sign in to comment.