Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements types property for elementwise functions #1361

Merged
merged 2 commits into from
Aug 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions dpctl/tensor/_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
from ._type_utils import (
_acceptance_fn_default,
_all_data_types,
_find_buf_dtype,
_find_buf_dtype2,
_to_device_supported_dtype,
Expand All @@ -44,6 +45,7 @@ def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
self.__name__ = "UnaryElementwiseFunc"
self.name_ = name
self.result_type_resolver_fn_ = result_type_resolver_fn
self.types_ = None
self.unary_fn_ = unary_dp_impl_fn
self.__doc__ = docs

Expand All @@ -53,6 +55,18 @@ def __str__(self):
def __repr__(self):
return f"<{self.__name__} '{self.name_}'>"

@property
def types(self):
types = self.types_
if not types:
types = []
for dt1 in _all_data_types(True, True):
dt2 = self.result_type_resolver_fn_(dt1)
if dt2:
types.append(f"{dt1.char}->{dt2.char}")
self.types_ = types
return types

def __call__(self, x, out=None, order="K"):
if not isinstance(x, dpt.usm_ndarray):
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
Expand Down Expand Up @@ -363,6 +377,7 @@ def __init__(
self.__name__ = "BinaryElementwiseFunc"
self.name_ = name
self.result_type_resolver_fn_ = result_type_resolver_fn
self.types_ = None
self.binary_fn_ = binary_dp_impl_fn
self.binary_inplace_fn_ = binary_inplace_fn
self.__doc__ = docs
Expand All @@ -377,6 +392,20 @@ def __str__(self):
def __repr__(self):
return f"<{self.__name__} '{self.name_}'>"

@property
def types(self):
types = self.types_
if not types:
types = []
_all_dtypes = _all_data_types(True, True)
for dt1 in _all_dtypes:
for dt2 in _all_dtypes:
dt3 = self.result_type_resolver_fn_(dt1, dt2)
if dt3:
types.append(f"{dt1.char}{dt2.char}->{dt3.char}")
self.types_ = types
return types

def __call__(self, o1, o2, out=None, order="K"):
if order not in ["K", "C", "F", "A"]:
order = "K"
Expand Down
9 changes: 9 additions & 0 deletions dpctl/tests/elementwise/test_abs.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ def test_abs_usm_type(usm_type):
assert np.allclose(dpt.asnumpy(Y), expected_Y)


def test_abs_types_prop():
types = dpt.abs.types_
assert types is None
types = dpt.abs.types
assert isinstance(types, list)
assert len(types) > 0
assert types == dpt.abs.types_


@pytest.mark.parametrize("dtype", _all_dtypes[1:])
def test_abs_order(dtype):
q = get_queue_or_skip()
Expand Down
9 changes: 9 additions & 0 deletions dpctl/tests/elementwise/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,15 @@ def __sycl_usm_array_interface__(self):
dpt.add(a, c)


def test_add_types_property():
types = dpt.add.types_
assert types is None
types = dpt.add.types
assert isinstance(types, list)
assert len(types) > 0
assert types == dpt.add.types_


def test_add_errors():
get_queue_or_skip()
try:
Expand Down