From 44b81c06b5dd642e8c2e2cea75d000f287164437 Mon Sep 17 00:00:00 2001 From: Yevhenii Havrylko Date: Mon, 17 Apr 2023 10:43:38 -0400 Subject: [PATCH] Add lower function --- numba_dpex/__init__.py | 4 ++-- numba_dpex/dpnp_iface/dpnpimpl.py | 20 +++++++++++++++----- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/numba_dpex/__init__.py b/numba_dpex/__init__.py index d043aea5ce..12f1d76919 100644 --- a/numba_dpex/__init__.py +++ b/numba_dpex/__init__.py @@ -77,9 +77,9 @@ def parse_sem_version(version_string: str) -> Tuple[int, int, int]: numba_version = parse_sem_version(numba.__version__) -if numba_version < (0, 56, 4): +if numba_version < (0, 57, 0): logging.warning( - "numba_dpex needs numba 0.56.4, using " + "numba_dpex needs numba 0.57.0, using " f"numba={numba_version} may cause unexpected behavior" ) diff --git a/numba_dpex/dpnp_iface/dpnpimpl.py b/numba_dpex/dpnp_iface/dpnpimpl.py index 7873bbc5b5..b8c60830a6 100644 --- a/numba_dpex/dpnp_iface/dpnpimpl.py +++ b/numba_dpex/dpnp_iface/dpnpimpl.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import dpnp +from numba.core.imputils import Registry from numba.np import npyimpl from numba_dpex.core.typing.dpnpdecl import _unsupported @@ -10,11 +11,16 @@ def _register_dpnp_ufuncs(): + registry = Registry("npyimpl") + lower = registry.lower + kernels = {} # NOTE: Assuming ufunc implementation for the CPUContext. for ufunc in dpnp_ufunc_db.get_ufuncs(): kernels[ufunc] = npyimpl.register_ufunc_kernel( - ufunc, npyimpl._ufunc_db_function(ufunc) + ufunc, + npyimpl._ufunc_db_function(ufunc), + lower, ) for _op_map in ( @@ -27,9 +33,13 @@ def _register_dpnp_ufuncs(): ufunc = getattr(dpnp, ufunc_name) kernel = kernels[ufunc] if ufunc.nin == 1: - npyimpl.register_unary_operator_kernel(operator, ufunc, kernel) + npyimpl.register_unary_operator_kernel( + operator, ufunc, kernel, lower + ) elif ufunc.nin == 2: - npyimpl.register_binary_operator_kernel(operator, ufunc, kernel) + npyimpl.register_binary_operator_kernel( + operator, ufunc, kernel, lower + ) else: raise RuntimeError( "There shouldn't be any non-unary or binary operators" @@ -43,11 +53,11 @@ def _register_dpnp_ufuncs(): kernel = kernels[ufunc] if ufunc.nin == 1: npyimpl.register_unary_operator_kernel( - operator, ufunc, kernel, inplace=True + operator, ufunc, kernel, lower, inplace=True ) elif ufunc.nin == 2: npyimpl.register_binary_operator_kernel( - operator, ufunc, kernel, inplace=True + operator, ufunc, kernel, lower, inplace=True ) else: raise RuntimeError(