diff --git a/numba_dpex/core/targets/dpjit_target.py b/numba_dpex/core/targets/dpjit_target.py index 883a00bc5e..874b66f1b3 100644 --- a/numba_dpex/core/targets/dpjit_target.py +++ b/numba_dpex/core/targets/dpjit_target.py @@ -14,6 +14,8 @@ from numba.core.imputils import Registry from numba.core.target_extension import CPU, target_registry +from numba_dpex.dpnp_iface import dpnp_ufunc_db + class Dpex(CPU): pass @@ -57,3 +59,7 @@ def load_additional_registries(self): # loading CPU specific registries super().load_additional_registries() + + # TODO: do we need it? + def get_ufunc_info(self, ufunc_key): + return dpnp_ufunc_db.get_ufunc_info(ufunc_key) diff --git a/numba_dpex/core/targets/kernel_target.py b/numba_dpex/core/targets/kernel_target.py index dacde5c772..4968e8c86c 100644 --- a/numba_dpex/core/targets/kernel_target.py +++ b/numba_dpex/core/targets/kernel_target.py @@ -251,21 +251,20 @@ def init(self): self._target_data = ll.create_target_data( codegen.SPIR_DATA_LAYOUT[utils.MACHINE_BITS] ) - # Override data model manager to SPIR model - import numba.cpython.unicode + # Override data model manager to SPIR model self.data_model_manager = _init_data_model_manager() self.extra_compile_options = dict() - import copy + from numba_dpex.dpnp_iface.dpnp_ufunc_db import _lazy_init_dpnp_db - from numba.np.ufunc_db import _lazy_init_db + _lazy_init_dpnp_db() - _lazy_init_db() - from numba.np.ufunc_db import _ufunc_db as ufunc_db + # we need to import it after, because before init it is None and + # variable is passed by value + from numba_dpex.dpnp_iface.dpnp_ufunc_db import _dpnp_ufunc_db - self.ufunc_db = copy.deepcopy(ufunc_db) - self.cpu_context = cpu_target.target_context + self.ufunc_db = _dpnp_ufunc_db def create_module(self, name): return self._internal_codegen._create_empty_module(name) diff --git a/numba_dpex/dpnp_iface/dpnp_ufunc_db.py b/numba_dpex/dpnp_iface/dpnp_ufunc_db.py index 6a6d3d6d50..80a246fe86 100644 --- a/numba_dpex/dpnp_iface/dpnp_ufunc_db.py +++ b/numba_dpex/dpnp_iface/dpnp_ufunc_db.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 +import copy + import dpnp import numpy as np from numba.core import types @@ -11,23 +13,50 @@ from ..ocl import mathimpl +# A global instance of dpnp ufuncs that are supported by numba-dpex +_dpnp_ufunc_db = None + + +def _lazy_init_dpnp_db(): + global _dpnp_ufunc_db + + if _dpnp_ufunc_db is None: + _dpnp_ufunc_db = {} + _fill_ufunc_db_with_dpnp_ufuncs(_dpnp_ufunc_db) + def get_ufuncs(): - """obtain a list of supported ufuncs in the db""" + """Returns the list of supported dpnp ufuncs in the _dpnp_ufunc_db""" - from numba.np.ufunc_db import _lazy_init_db + _lazy_init_dpnp_db() - _lazy_init_db() - from numba.np.ufunc_db import _ufunc_db + return _dpnp_ufunc_db.keys() - _fill_ufunc_db_with_dpnp_ufuncs(_ufunc_db) - return _ufunc_db.keys() +def get_ufunc_info(ufunc_key): + """get the lowering information for the ufunc with key ufunc_key. + + The lowering information is a dictionary that maps from a numpy + loop string (as given by the ufunc types attribute) to a function + that handles code generation for a scalar version of the ufunc + (that is, generates the "per element" operation"). + + raises a KeyError if the ufunc is not in the ufunc_db + """ + _lazy_init_dpnp_db() + return _dpnp_ufunc_db[ufunc_key] def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db): - """Monkey patching dpnp for missing attributes.""" - # FIXME: add more docstring + """Populates the _dpnp_ufunc_db from Numba's NumPy ufunc_db""" + + from numba.np.ufunc_db import _lazy_init_db + + _lazy_init_db() + + # we need to import it after, because before init it is None and + # variable is passed by value + from numba.np.ufunc_db import _ufunc_db for ufuncop in dpnpdecl.supported_ufuncs: if ufuncop == "erf": @@ -52,7 +81,12 @@ def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db): op.nargs = npop.nargs op.types = npop.types op.is_dpnp_ufunc = True - ufunc_db.update({op: ufunc_db[npop]}) + cp = copy.copy(_ufunc_db[npop]) + if "'divide'" in str(npop): + # TODO: why do we need to do it only for divide? + # https://github.com/IntelPython/numba-dpex/issues/1270 + ufunc_db.update({npop: cp}) + ufunc_db.update({op: cp}) for key in list(ufunc_db[op].keys()): if ( "FF->" in key