Skip to content

Commit

Permalink
Done refactoring empty, zeros, ones, full
Browse files Browse the repository at this point in the history
  • Loading branch information
khaled committed May 9, 2023
1 parent 6ad13ea commit 89e1169
Show file tree
Hide file tree
Showing 7 changed files with 344 additions and 85 deletions.
20 changes: 17 additions & 3 deletions numba_dpex/core/types/usm_ndarray_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
ndim,
layout="C",
dtype=None,
is_fill_value_float=False,
usm_type="device",
device=None,
queue=None,
Expand Down Expand Up @@ -66,9 +67,22 @@ def __init__(
self.device = self.queue.sycl_device.filter_string

if not dtype:
dummy_tensor = dpctl.tensor.empty(
1, order=layout, usm_type=usm_type, sycl_queue=self.queue
)
if is_fill_value_float:
dummy_tensor = dpctl.tensor.empty(
1,
dtype=dpctl.tensor.float64,
order=layout,
usm_type=usm_type,
sycl_queue=self.queue,
)
else:
dummy_tensor = dpctl.tensor.empty(
1,
dtype=dpctl.tensor.int64,
order=layout,
usm_type=usm_type,
sycl_queue=self.queue,
)
# convert dpnp type to numba/numpy type
_dtype = dummy_tensor.dtype
self.dtype = from_dtype(_dtype)
Expand Down
13 changes: 13 additions & 0 deletions numba_dpex/dpnp_iface/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import dpnp
from numba import errors, types
from numba.core.types import scalars
from numba.core.typing.npydecl import parse_dtype as _ty_parse_dtype
from numba.core.typing.npydecl import parse_shape as _ty_parse_shape
from numba.extending import overload
Expand Down Expand Up @@ -93,6 +94,9 @@ def _parse_usm_type(usm_type):
raise errors.NumbaValueError(msg)
return usm_type_str
elif isinstance(usm_type, str):
if usm_type not in ["shared", "device", "host"]:
msg = f"Invalid usm_type specified: '{usm_type}'"
raise errors.NumbaValueError(msg)
return usm_type
else:
raise TypeError(
Expand Down Expand Up @@ -149,6 +153,7 @@ def build_dpnp_ndarray(
ndim,
layout="C",
dtype=None,
is_fill_value_float=False,
usm_type="device",
device=None,
sycl_queue=None,
Expand All @@ -162,6 +167,8 @@ def build_dpnp_ndarray(
Data type of the array. Can be typestring, a `numpy.dtype`
object, `numpy` char string, or a numpy scalar type.
Default: None.
is_fill_value_float (bool): Specify if the fill value is floating
point.
usm_type (numba.core.types.misc.StringLiteral, optional):
The type of SYCL USM allocation for the output array.
Allowed values are "device"|"shared"|"host".
Expand Down Expand Up @@ -197,6 +204,7 @@ def build_dpnp_ndarray(
ndim=ndim,
layout=layout,
dtype=dtype,
is_fill_value_float=is_fill_value_float,
usm_type=usm_type,
device=device,
queue=sycl_queue,
Expand Down Expand Up @@ -279,6 +287,7 @@ def ol_dpnp_empty(
_ndim,
layout=_layout,
dtype=_dtype,
is_fill_value_float=True,
usm_type=_usm_type,
device=_device,
sycl_queue=_sycl_queue,
Expand Down Expand Up @@ -383,6 +392,7 @@ def ol_dpnp_zeros(
_ndim,
layout=_layout,
dtype=_dtype,
is_fill_value_float=True,
usm_type=_usm_type,
device=_device,
sycl_queue=_sycl_queue,
Expand Down Expand Up @@ -487,6 +497,7 @@ def ol_dpnp_ones(
_ndim,
layout=_layout,
dtype=_dtype,
is_fill_value_float=True,
usm_type=_usm_type,
device=_device,
sycl_queue=_sycl_queue,
Expand Down Expand Up @@ -585,6 +596,7 @@ def ol_dpnp_full(

_ndim = _ty_parse_shape(shape)
_dtype = _parse_dtype(dtype)
_is_fill_value_float = isinstance(fill_value, scalars.Float)
_layout = _parse_layout(order)
_usm_type = _parse_usm_type(usm_type) if usm_type else "device"
_device = _parse_device_filter_string(device) if device else None
Expand All @@ -595,6 +607,7 @@ def ol_dpnp_full(
_ndim,
layout=_layout,
dtype=_dtype,
is_fill_value_float=_is_fill_value_float,
usm_type=_usm_type,
device=_device,
sycl_queue=_sycl_queue,
Expand Down
74 changes: 42 additions & 32 deletions numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_empty.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,37 @@
usm_types = ["device", "shared", "host"]


@pytest.mark.parametrize("shape", shapes)
def test_dpnp_empty_default(shape):
"""Test dpnp.empty() with default parameters inside dpjit."""

@dpjit
def func(shape):
c = dpnp.empty(shape)
return c

try:
c = func(shape)
except Exception:
pytest.fail("Calling dpnp.empty() inside dpjit failed.")

if len(c.shape) == 1:
assert c.shape[0] == shape
else:
assert c.shape == shape

dummy = dpnp.empty(shape)

assert c.dtype == dummy.dtype
assert c.usm_type == dummy.usm_type
assert c.sycl_device == dummy.sycl_device


@pytest.mark.parametrize("shape", shapes)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("usm_type", usm_types)
def test_dpnp_empty_from_device(shape, dtype, usm_type):
""" "Use device only in dpnp.emtpy() inside dpjit."""
device = dpctl.SyclDevice().filter_string

@dpjit
Expand All @@ -30,7 +57,7 @@ def func(shape):
try:
c = func(shape)
except Exception:
pytest.fail("Calling dpnp.empty inside dpjit failed")
pytest.fail("Calling dpnp.empty() inside dpjit failed.")

if len(c.shape) == 1:
assert c.shape[0] == shape
Expand All @@ -46,6 +73,8 @@ def func(shape):
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("usm_type", usm_types)
def test_dpnp_empty_from_queue(shape, dtype, usm_type):
""" "Use queue only in dpnp.emtpy() inside dpjit."""

@dpjit
def func(shape, queue):
c = dpnp.empty(shape, dtype=dtype, usm_type=usm_type, sycl_queue=queue)
Expand All @@ -56,7 +85,7 @@ def func(shape, queue):
try:
c = func(shape, queue)
except Exception:
pytest.fail("Calling dpnp.empty inside dpjit failed")
pytest.fail("Calling dpnp.empty() inside dpjit failed.")

if len(c.shape) == 1:
assert c.shape[0] == shape
Expand All @@ -65,45 +94,26 @@ def func(shape, queue):

assert c.dtype == dtype
assert c.usm_type == usm_type
assert c.sycl_device.filter_string == queue.sycl_device.filter_string
assert c.sycl_device == queue.sycl_device

if c.sycl_queue != queue:
pytest.xfail("Returned queue does not have the queue passed to empty.")


@pytest.mark.parametrize("shape", shapes)
def test_dpnp_empty_default(shape):
@dpjit
def func(shape):
c = dpnp.empty(shape)
return c

try:
c = func(shape)
except Exception:
pytest.fail("Calling dpnp.empty inside dpjit failed")

if len(c.shape) == 1:
assert c.shape[0] == shape
else:
assert c.shape == shape

dummy_tensor = dpctl.tensor.empty(shape)

assert c.dtype == dummy_tensor.dtype
assert c.usm_type == dummy_tensor.usm_type
assert c.sycl_device == dummy_tensor.sycl_device
pytest.xfail(
"Returned queue does not have the queue passed to the dpnp function."
)


def test_dpnp_empty_exceptions():
queue = dpctl.SyclQueue()
"""Test if exception is raised when both queue and device are specified."""
device = dpctl.SyclDevice().filter_string

@dpjit
def func2(q):
c = dpnp.empty(10, sycl_queue=q, device="cpu")
def func(shape, queue):
c = dpnp.empty(shape, sycl_queue=queue, device=device)
return c

queue = dpctl.SyclQueue()

try:
func2(queue)
func(10, queue)
except Exception as e:
assert isinstance(e, errors.TypingError)
8 changes: 4 additions & 4 deletions numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_empty_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ def func(a):
c = dpnp.empty_like(a, dtype=dtype, usm_type=usm_type, device=device)
return c

NZ = dpnp.empty(shape)
NZ = dpnp.random.random(shape)

try:
c = func(NZ)
except Exception:
pytest.fail("Calling dpnp.empty_like inside dpjit failed")
pytest.fail("Calling dpnp.empty_like inside dpjit failed.")

if len(c.shape) == 1:
assert c.shape[0] == NZ.shape[0]
Expand Down Expand Up @@ -62,7 +62,7 @@ def func(a, q):
try:
c = func(NZ, queue)
except Exception:
pytest.fail("Calling dpnp.empty_like inside dpjit failed")
pytest.fail("Calling dpnp.empty_like inside dpjit failed.")

if len(c.shape) == 1:
assert c.shape[0] == NZ.shape[0]
Expand Down Expand Up @@ -92,7 +92,7 @@ def func(arr):
try:
c = func(arr)
except Exception:
pytest.fail("Calling dpnp.empty_like inside dpjit failed")
pytest.fail("Calling dpnp.empty_like inside dpjit failed.")

assert c.shape == arr.shape
assert c.dtype == arr.dtype
Expand Down
Loading

0 comments on commit 89e1169

Please sign in to comment.