Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation for dpnp.full_like() #997

Merged
merged 1 commit into from
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 80 additions & 11 deletions numba_dpex/dpnp_iface/_intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,71 @@ def codegen(context, builder, sig, llargs):
return sig, codegen


@intrinsic
def impl_dpnp_full(
ty_context,
ty_shape,
ty_fill_value,
ty_dtype,
ty_order,
ty_like,
ty_device,
ty_usm_type,
ty_sycl_queue,
ty_retty_ref,
):
"""A numba "intrinsic" function to inject code for dpnp.full().

Args:
ty_context (numba.core.typing.context.Context): The typing context
for the codegen.
ty_shape (numba.core.types.scalars.Integer or
numba.core.types.containers.UniTuple): Numba type for the shape
of the array.
ty_fill_value (numba.core.types.scalars): One of the Numba scalar
types.
ty_dtype (numba.core.types.functions.NumberClass): Numba type for
dtype.
ty_order (numba.core.types.misc.UnicodeType): UnicodeType
from numba for strings.
ty_like (numba.core.types.npytypes.Array): Numba type for array.
ty_device (numba.core.types.misc.UnicodeType): UnicodeType
from numba for strings.
ty_usm_type (numba.core.types.misc.UnicodeType): UnicodeType
from numba for strings.
ty_sycl_queue (numba.core.types.misc.UnicodeType): UnicodeType
from numba for strings.
ty_retty_ref (numba.core.types.abstract.TypeRef): Reference to
a type from numba, used when a type is passed as a value.

Returns:
tuple(numba.core.typing.templates.Signature, function): A tuple of
numba function signature type and a function object.
"""

ty_retty = ty_retty_ref.instance_type
signature = ty_retty(
ty_shape,
ty_fill_value,
ty_dtype,
ty_order,
ty_like,
ty_device,
ty_usm_type,
ty_sycl_queue,
ty_retty_ref,
)

def codegen(context, builder, sig, args):
fill_value = context.get_argument_value(builder, sig.args[1], args[1])
ary, _ = fill_arrayobj(
context, builder, sig, args, fill_value, is_like=False
)
return ary._getvalue()

return signature, codegen


@intrinsic
def impl_dpnp_empty_like(
ty_context,
Expand Down Expand Up @@ -490,33 +555,36 @@ def codegen(context, builder, sig, llargs):


@intrinsic
def impl_dpnp_full(
def impl_dpnp_full_like(
ty_context,
ty_shape,
ty_x1,
ty_fill_value,
ty_dtype,
ty_order,
ty_like,
ty_subok,
ty_shape,
ty_device,
ty_usm_type,
ty_sycl_queue,
ty_retty_ref,
):
"""A numba "intrinsic" function to inject code for dpnp.full().
"""A numba "intrinsic" function to inject code for dpnp.full_like().

Args:
ty_context (numba.core.typing.context.Context): The typing context
for the codegen.
ty_shape (numba.core.types.scalars.Integer or
numba.core.types.containers.UniTuple): Numba type for the shape
of the array.
ty_x1 (numba.core.types.npytypes.Array): Numba type class for ndarray.
ty_fill_value (numba.core.types.scalars): One of the Numba scalar
types.
ty_dtype (numba.core.types.functions.NumberClass): Numba type for
dtype.
ty_order (numba.core.types.misc.UnicodeType): UnicodeType
from numba for strings.
ty_like (numba.core.types.npytypes.Array): Numba type for array.
ty_subok (numba.core.types.scalars.Boolean): Numba type class for
subok.
ty_shape (numba.core.types.scalars.Integer or
numba.core.types.containers.UniTuple): Numba type for the shape
of the array. Not supported.
ty_device (numba.core.types.misc.UnicodeType): UnicodeType
from numba for strings.
ty_usm_type (numba.core.types.misc.UnicodeType): UnicodeType
Expand All @@ -533,11 +601,12 @@ def impl_dpnp_full(

ty_retty = ty_retty_ref.instance_type
signature = ty_retty(
ty_shape,
ty_x1,
ty_fill_value,
ty_dtype,
ty_order,
ty_like,
ty_subok,
ty_shape,
ty_device,
ty_usm_type,
ty_sycl_queue,
Expand All @@ -547,7 +616,7 @@ def impl_dpnp_full(
def codegen(context, builder, sig, args):
fill_value = context.get_argument_value(builder, sig.args[1], args[1])
ary, _ = fill_arrayobj(
context, builder, sig, args, fill_value, is_like=False
context, builder, sig, args, fill_value, is_like=True
)
return ary._getvalue()

Expand Down
113 changes: 113 additions & 0 deletions numba_dpex/dpnp_iface/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
impl_dpnp_empty,
impl_dpnp_empty_like,
impl_dpnp_full,
impl_dpnp_full_like,
impl_dpnp_ones,
impl_dpnp_ones_like,
impl_dpnp_zeros,
Expand Down Expand Up @@ -814,6 +815,118 @@ def impl(
)


@overload(dpnp.full_like, prefer_literal=True)
def ol_dpnp_full_like(
x1,
fill_value,
dtype=None,
order="C",
subok=None,
shape=None,
device=None,
usm_type=None,
sycl_queue=None,
):
"""Creates `usm_ndarray` from USM allocation initialized with values
specified by the `fill_value`.

This is an overloaded function implementation for dpnp.full_like().

Args:
x1 (numba.core.types.npytypes.Array): Input array from which to
derive the output array shape.
fill_value (numba.core.types.scalars): One of the
numba.core.types.scalar types for the value to
be filled.
dtype (numba.core.types.functions.NumberClass, optional):
Data type of the array. Can be typestring, a `numpy.dtype`
object, `numpy` char string, or a numpy scalar type.
Default: None.
order (str, optional): memory layout for the array "C" or "F".
Default: "C".
subok ('numba.core.types.scalars.BooleanLiteral', optional): A
boolean literal type for the `subok` parameter defined in
NumPy. If True, then the newly created array will use the
sub-class type of prototype, otherwise it will be a
base-class array. Defaults to False.
shape (numba.core.types.containers.UniTuple, optional): The shape
to override the shape of the given array. Not supported.
Default: `None`
device (numba.core.types.misc.StringLiteral, optional): array API
concept of device where the output array is created. `device`
can be `None`, a oneAPI filter selector string, an instance of
:class:`dpctl.SyclDevice` corresponding to a non-partitioned
SYCL device, an instance of :class:`dpctl.SyclQueue`, or a
`Device` object returnedby`dpctl.tensor.usm_array.device`.
Default: `None`.
usm_type (numba.core.types.misc.StringLiteral or str, optional):
The type of SYCL USM allocation for the output array.
Allowed values are "device"|"shared"|"host".
Default: `"device"`.
sycl_queue (:class:`dpctl.SyclQueue`, optional): Not supported.

Raises:
errors.TypingError: If couldn't parse input types to dpnp.full_like().
errors.TypingError: If shape is provided.

Returns:
function: Local function `impl_dpnp_full_like()`.
"""

if shape:
raise errors.TypingError(
"The parameter shape is not supported "
+ "inside overloaded dpnp.full_like() function."
)
_ndim = x1.ndim if hasattr(x1, "ndim") and x1.ndim is not None else 0
_dtype = _parse_dtype(dtype, data=x1)
_order = x1.layout if order is None else order
_usm_type = _parse_usm_type(usm_type) if usm_type is not None else "device"
_device = (
_parse_device_filter_string(device) if device is not None else "unknown"
)
ret_ty = build_dpnp_ndarray(
_ndim,
layout=_order,
dtype=_dtype,
usm_type=_usm_type,
device=_device,
queue=sycl_queue,
)
if ret_ty:

def impl(
x1,
fill_value,
dtype=None,
order="C",
subok=None,
shape=None,
device=None,
usm_type=None,
sycl_queue=None,
):
return impl_dpnp_full_like(
x1,
fill_value,
_dtype,
_order,
subok,
shape,
_device,
_usm_type,
sycl_queue,
ret_ty,
)

return impl
else:
raise errors.TypingError(
"Cannot parse input types to "
+ f"function dpnp.full_like({x1}, {fill_value}, {dtype}, ...)."
)


@overload(dpnp.full, prefer_literal=True)
def ol_dpnp_full(
shape,
Expand Down
105 changes: 105 additions & 0 deletions numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_full_like.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""Tests for dpnp ndarray constructors."""

import math

import dpctl
import dpctl.tensor as dpt
import dpnp
import numpy
import pytest
from numba import errors

from numba_dpex import dpjit

shapes = [11, (3, 7)]
dtypes = [dpnp.int32, dpnp.int64, dpnp.float32, dpnp.float64]
usm_types = ["device", "shared", "host"]
devices = ["cpu", "unknown"]
fill_values = [
7,
-7,
7.1,
-7.1,
math.pi,
math.e,
4294967295,
4294967295.0,
3.4028237e38,
]


@pytest.mark.parametrize("shape", shapes)
@pytest.mark.parametrize("fill_value", fill_values)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("usm_type", usm_types)
@pytest.mark.parametrize("device", devices)
def test_dpnp_full_like(shape, fill_value, dtype, usm_type, device):
@dpjit
def func(a, v):
c = dpnp.full_like(a, v, dtype=dtype, usm_type=usm_type, device=device)
return c

if isinstance(shape, int):
NZ = numpy.random.rand(shape)
else:
NZ = numpy.random.rand(*shape)

try:
c = func(NZ, fill_value)
except Exception:
pytest.fail("Calling dpnp.full_like inside dpjit failed")

C = numpy.full_like(NZ, fill_value, dtype=dtype)

if len(c.shape) == 1:
assert c.shape[0] == NZ.shape[0]
else:
assert c.shape == NZ.shape

assert c.dtype == dtype
assert c.usm_type == usm_type
if device != "unknown":
assert (
c.sycl_device.filter_string
== dpctl.SyclDevice(device).filter_string
)
else:
c.sycl_device.filter_string == dpctl.SyclDevice().filter_string

assert numpy.array_equal(dpt.asnumpy(c._array_obj), C)


def test_dpnp_full_like_exceptions():
@dpjit
def func1(a):
c = dpnp.full_like(a, shape=(3, 3))
return c

try:
func1(numpy.random.rand(5, 5))
except Exception as e:
assert isinstance(e, errors.TypingError)
assert (
"No implementation of function Function(<function full_like"
in str(e)
)

queue = dpctl.SyclQueue()

@dpjit
def func2(a):
c = dpnp.full_like(a, sycl_queue=queue)
return c

try:
func2(numpy.random.rand(5, 5))
except Exception as e:
assert isinstance(e, errors.TypingError)
assert (
"No implementation of function Function(<function full_like"
in str(e)
)