Skip to content

Commit

Permalink
Refactored all dpnp functions overload
Browse files Browse the repository at this point in the history
  • Loading branch information
khaled committed May 9, 2023
1 parent 89e1169 commit 9722585
Show file tree
Hide file tree
Showing 9 changed files with 623 additions and 192 deletions.
40 changes: 29 additions & 11 deletions numba_dpex/dpnp_iface/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import dpnp
from numba import errors, types
from numba.core.types import scalars
from numba.core.types.containers import UniTuple
from numba.core.typing.npydecl import parse_dtype as _ty_parse_dtype
from numba.core.typing.npydecl import parse_shape as _ty_parse_shape
from numba.extending import overload
Expand All @@ -28,7 +29,20 @@
# =========================================================================


def _parse_dtype(dtype, data=None):
def _parse_dim(x1):
if hasattr(x1, "ndim") and x1.ndim:
return x1.ndim
elif isinstance(x1, scalars.Integer):
r = 1
return r
elif isinstance(x1, UniTuple):
r = len(x1)
return r
else:
return 0


def _parse_dtype(dtype):
"""Resolve dtype parameter.
Resolves the dtype parameter based on the given value
Expand All @@ -44,9 +58,8 @@ class for nd-arrays. Defaults to None.
numba.core.types.functions.NumberClass: Resolved numba type
class for number classes.
"""

_dtype = None
if data and isinstance(data, types.Array):
_dtype = data.dtype
if not is_nonelike(dtype):
_dtype = _ty_parse_dtype(dtype)
return _dtype
Expand All @@ -60,6 +73,9 @@ def _parse_layout(layout):
raise errors.NumbaValueError(msg)
return layout_type_str
elif isinstance(layout, str):
if layout not in ["C", "F", "A"]:
msg = f"Invalid layout specified: '{layout}'"
raise errors.NumbaValueError(msg)
return layout
else:
raise TypeError(
Expand Down Expand Up @@ -711,8 +727,8 @@ def ol_dpnp_empty_like(
+ "inside overloaded dpnp.empty_like() function."
)

_ndim = x1.ndim if hasattr(x1, "ndim") and x1.ndim else 0
_dtype = _parse_dtype(dtype, data=x1)
_ndim = _parse_dim(x1)
_dtype = x1.dtype if isinstance(x1, types.Array) else _parse_dtype(dtype)
_order = x1.layout if order is None else order
_usm_type = _parse_usm_type(usm_type) if usm_type else "device"
_device = _parse_device_filter_string(device) if device else None
Expand Down Expand Up @@ -824,8 +840,8 @@ def ol_dpnp_zeros_like(
+ "inside overloaded dpnp.zeros_like() function."
)

_ndim = x1.ndim if hasattr(x1, "ndim") and x1.ndim else 0
_dtype = _parse_dtype(dtype, data=x1)
_ndim = _parse_dim(x1)
_dtype = x1.dtype if isinstance(x1, types.Array) else _parse_dtype(dtype)
_order = x1.layout if order is None else order
_usm_type = _parse_usm_type(usm_type) if usm_type else "device"
_device = _parse_device_filter_string(device) if device else None
Expand Down Expand Up @@ -936,8 +952,8 @@ def ol_dpnp_ones_like(
+ "inside overloaded dpnp.ones_like() function."
)

_ndim = x1.ndim if hasattr(x1, "ndim") and x1.ndim else 0
_dtype = _parse_dtype(dtype, data=x1)
_ndim = _parse_dim(x1)
_dtype = x1.dtype if isinstance(x1, types.Array) else _parse_dtype(dtype)
_order = x1.layout if order is None else order
_usm_type = _parse_usm_type(usm_type) if usm_type else "device"
_device = _parse_device_filter_string(device) if device else None
Expand Down Expand Up @@ -1053,8 +1069,9 @@ def ol_dpnp_full_like(
+ "inside overloaded dpnp.full_like() function."
)

_ndim = x1.ndim if hasattr(x1, "ndim") and x1.ndim else 0
_dtype = _parse_dtype(dtype, data=x1)
_ndim = _parse_dim(x1)
_dtype = x1.dtype if isinstance(x1, types.Array) else _parse_dtype(dtype)
_is_fill_value_float = isinstance(fill_value, scalars.Float)
_order = x1.layout if order is None else order
_usm_type = _parse_usm_type(usm_type) if usm_type else "device"
_device = _parse_device_filter_string(device) if device else None
Expand All @@ -1064,6 +1081,7 @@ def ol_dpnp_full_like(
_ndim,
layout=_order,
dtype=_dtype,
is_fill_value_float=_is_fill_value_float,
usm_type=_usm_type,
device=_device,
sycl_queue=_sycl_queue,
Expand Down
15 changes: 15 additions & 0 deletions numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_empty.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ def func(shape):
assert c.dtype == dummy.dtype
assert c.usm_type == dummy.usm_type
assert c.sycl_device == dummy.sycl_device
assert c.sycl_queue == dummy.sycl_queue
if c.sycl_queue != dummy.sycl_queue:
pytest.xfail(
"Returned queue does not have the queue in the dummy array."
)
assert c.sycl_queue == dpctl._sycl_queue_manager.get_device_cached_queue(
dummy.sycl_device
)


@pytest.mark.parametrize("shape", shapes)
Expand Down Expand Up @@ -67,6 +75,12 @@ def func(shape):
assert c.dtype == dtype
assert c.usm_type == usm_type
assert c.sycl_device.filter_string == device
if c.sycl_queue != dpctl._sycl_queue_manager.get_device_cached_queue(
device
):
pytest.xfail(
"Returned queue does not have the queue cached against the device."
)


@pytest.mark.parametrize("shape", shapes)
Expand Down Expand Up @@ -117,3 +131,4 @@ def func(shape, queue):
func(10, queue)
except Exception as e:
assert isinstance(e, errors.TypingError)
assert "`device` and `sycl_queue` are exclusive keywords" in str(e)
187 changes: 109 additions & 78 deletions numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_empty_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,139 +19,170 @@


@pytest.mark.parametrize("shape", shapes)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("usm_type", usm_types)
def test_dpnp_empty_like_from_device(shape, dtype, usm_type):
device = dpctl.SyclDevice().filter_string
def test_dpnp_empty_like_default(shape):
"""Test dpnp.empty_like() with default parameters inside dpjit."""

@dpjit
def func(a):
c = dpnp.empty_like(a, dtype=dtype, usm_type=usm_type, device=device)
return c

NZ = dpnp.random.random(shape)
def func(x):
y = dpnp.empty_like(x)
return y

try:
c = func(NZ)
a = dpnp.ones(shape)
c = func(a)
except Exception:
pytest.fail("Calling dpnp.empty_like inside dpjit failed.")
pytest.fail("Calling dpnp.empty_like() inside dpjit failed.")

if len(c.shape) == 1:
assert c.shape[0] == NZ.shape[0]
assert c.shape[0] == a.shape[0]
else:
assert c.shape == NZ.shape
assert c.shape == a.shape

assert c.dtype == dtype
assert c.usm_type == usm_type
assert c.sycl_device.filter_string == device
assert c.sycl_queue == dpctl.get_device_cached_queue(device)
dummy = dpnp.empty_like(a)

assert c.dtype == dummy.dtype
assert c.usm_type == dummy.usm_type
assert c.sycl_device == dummy.sycl_device
if c.sycl_queue != dummy.sycl_queue:
pytest.xfail(
"Returned queue does not have the queue in the dummy array."
)
assert c.sycl_queue == dpctl._sycl_queue_manager.get_device_cached_queue(
dummy.sycl_device
)


@pytest.mark.parametrize("shape", shapes)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("usm_type", usm_types)
def test_dpnp_empty_like_from_queue(shape, dtype, usm_type):
@dpjit
def func(a, q):
c = dpnp.empty_like(a, dtype=dtype, usm_type=usm_type, sycl_queue=q)
return c
def test_dpnp_empty_like_from_device(shape, dtype, usm_type):
""" "Use device only in dpnp.emtpy)like() inside dpjit."""
device = dpctl.SyclDevice().filter_string

NZ = dpnp.empty(shape)
queue = dpctl.SyclQueue()
@dpjit
def func(x):
y = dpnp.empty_like(x, dtype=dtype, usm_type=usm_type, device=device)
return y

try:
c = func(NZ, queue)
a = dpnp.ones(shape, dtype=dtype, usm_type=usm_type, device=device)
c = func(a)
except Exception:
pytest.fail("Calling dpnp.empty_like inside dpjit failed.")
pytest.fail("Calling dpnp.empty_like() inside dpjit failed.")

if len(c.shape) == 1:
assert c.shape[0] == NZ.shape[0]
assert c.shape[0] == a.shape[0]
else:
assert c.shape == NZ.shape
assert c.shape == a.shape

assert c.dtype == dtype
assert c.usm_type == usm_type
assert c.sycl_queue == NZ.sycl_queue
assert c.sycl_queue != queue
assert c.dtype == a.dtype
assert c.usm_type == a.usm_type
assert c.sycl_device.filter_string == device
if c.sycl_queue != dpctl._sycl_queue_manager.get_device_cached_queue(
device
):
pytest.xfail(
"Returned queue does not have the queue cached against the device."
)


@pytest.mark.parametrize("shape", shapes)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("usm_type", usm_types)
def test_dpnp_empty_like_default(
shape,
usm_type,
dtype,
):
def test_dpnp_empty_like_from_queue(shape, dtype, usm_type):
""" "Use queue only in dpnp.emtpy_like() inside dpjit."""

@dpjit
def func(arr):
c = dpnp.empty_like(arr)
return c
def func(x, queue):
y = dpnp.empty_like(x, dtype=dtype, usm_type=usm_type, sycl_queue=queue)
return y

queue = dpctl.SyclQueue()

arr = dpnp.empty(shape=shape, usm_type=usm_type, dtype=dtype)
try:
c = func(arr)
a = dpnp.ones(shape, dtype=dtype, usm_type=usm_type, sycl_queue=queue)
c = func(a, queue)
except Exception:
pytest.fail("Calling dpnp.empty_like inside dpjit failed.")
pytest.fail("Calling dpnp.empty_like() inside dpjit failed.")

assert c.shape == arr.shape
assert c.dtype == arr.dtype
if c.usm_type != arr.usm_type:
pytest.xfail("Usm type not correctly populated from passed in array")
assert c.sycl_queue == arr.sycl_queue
if len(c.shape) == 1:
assert c.shape[0] == a.shape[0]
else:
assert c.shape == a.shape

assert c.dtype == a.dtype
assert c.usm_type == a.usm_type
assert c.sycl_device == queue.sycl_device

@pytest.mark.xfail
def test_dpnp_empty_like_from_numpy(shape):
@dpjit
def func(arr):
c = dpnp.empty_like(arr)
return c
if c.sycl_queue != queue:
pytest.xfail(
"Returned queue does not have the queue passed to the dpnp function."
)

arr = numpy.empty(10)
with pytest.raises(Exception):
func(arr)

def test_dpnp_empty_like_exceptions():
"""Test if exception is raised when both queue and device are specified."""

@pytest.mark.xfail
def test_dpnp_empty_like_from_freevar_queue():
queue = dpctl.SyclQueue()
device = dpctl.SyclDevice().filter_string

@dpjit
def func():
c = dpnp.empty_like(10, sycl_queue=queue)
return c
def func1(x, queue):
y = dpnp.empty_like(x, sycl_queue=queue, device=device)
return y

try:
func()
except Exception:
pytest.fail("Calling dpnp.empty_like inside dpjit failed")
queue = dpctl.SyclQueue()

try:
a = dpnp.ones(10)
func1(a, queue)
except Exception as e:
assert isinstance(e, errors.TypingError)
assert "`device` and `sycl_queue` are exclusive keywords" in str(e)

def test_dpnp_empty_like_exceptions():
@dpjit
def func1(a):
c = dpnp.empty_like(a, shape=(3, 3))
return c
def func2(x):
y = dpnp.empty_like(x, shape=(3, 3))
return y

try:
func1(dpnp.empty((5, 5)))
func2(a)
except Exception as e:
assert isinstance(e, errors.TypingError)
assert (
"No implementation of function Function(<function empty_like"
in str(e)
)

queue = dpctl.SyclQueue()

@pytest.mark.xfail
def test_dpnp_empty_like_from_numpy():
"""Test if dpnp works with numpy array (it shouldn't)"""

@dpjit
def func2(a, q):
c = dpnp.empty_like(a, sycl_queue=q, device="cpu")
return c
def func(x):
y = dpnp.empty_like(x)
return y

a = numpy.empty(10)

with pytest.raises(Exception):
func(a)


@pytest.mark.parametrize("shape", shapes)
def test_dpnp_empty_like_from_scalar(shape):
"""Test if works with scalar argument in place of an array"""

@dpjit
def func(shape):
x = dpnp.empty_like(shape)
return x

try:
func2(dpnp.empty((5, 5)), queue)
func(shape)
except Exception as e:
assert isinstance(e, errors.TypingError)
assert "`device` and `sycl_queue` are exclusive keywords" in str(e)
assert (
"No implementation of function Function(<function empty_like"
in str(e)
)
Loading

0 comments on commit 9722585

Please sign in to comment.