Skip to content

Commit

Permalink
Create and use dpnp specific ufunc db
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb authored and ZzEeKkAa committed Jan 5, 2024
1 parent 71fd572 commit acfaa9d
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 17 deletions.
6 changes: 6 additions & 0 deletions numba_dpex/core/targets/dpjit_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from numba.core.imputils import Registry
from numba.core.target_extension import CPU, target_registry

from numba_dpex.dpnp_iface import dpnp_ufunc_db


class Dpex(CPU):
pass
Expand Down Expand Up @@ -57,3 +59,7 @@ def load_additional_registries(self):

# loading CPU specific registries
super().load_additional_registries()

# TODO: do we need it?
def get_ufunc_info(self, ufunc_key):
return dpnp_ufunc_db.get_ufunc_info(ufunc_key)
15 changes: 7 additions & 8 deletions numba_dpex/core/targets/kernel_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,21 +251,20 @@ def init(self):
self._target_data = ll.create_target_data(
codegen.SPIR_DATA_LAYOUT[utils.MACHINE_BITS]
)
# Override data model manager to SPIR model
import numba.cpython.unicode

# Override data model manager to SPIR model
self.data_model_manager = _init_data_model_manager()
self.extra_compile_options = dict()

import copy
from numba_dpex.dpnp_iface.dpnp_ufunc_db import _lazy_init_dpnp_db

from numba.np.ufunc_db import _lazy_init_db
_lazy_init_dpnp_db()

_lazy_init_db()
from numba.np.ufunc_db import _ufunc_db as ufunc_db
# we need to import it after, because before init it is None and
# variable is passed by value
from numba_dpex.dpnp_iface.dpnp_ufunc_db import _dpnp_ufunc_db

self.ufunc_db = copy.deepcopy(ufunc_db)
self.cpu_context = cpu_target.target_context
self.ufunc_db = _dpnp_ufunc_db

def create_module(self, name):
return self._internal_codegen._create_empty_module(name)
Expand Down
52 changes: 43 additions & 9 deletions numba_dpex/dpnp_iface/dpnp_ufunc_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# SPDX-License-Identifier: Apache-2.0


import copy

import dpnp
import numpy as np
from numba.core import types
Expand All @@ -11,23 +13,50 @@

from ..ocl import mathimpl

# A global instance of dpnp ufuncs that are supported by numba-dpex
_dpnp_ufunc_db = None


def _lazy_init_dpnp_db():
global _dpnp_ufunc_db

if _dpnp_ufunc_db is None:
_dpnp_ufunc_db = {}
_fill_ufunc_db_with_dpnp_ufuncs(_dpnp_ufunc_db)


def get_ufuncs():
"""obtain a list of supported ufuncs in the db"""
"""Returns the list of supported dpnp ufuncs in the _dpnp_ufunc_db"""

from numba.np.ufunc_db import _lazy_init_db
_lazy_init_dpnp_db()

_lazy_init_db()
from numba.np.ufunc_db import _ufunc_db
return _dpnp_ufunc_db.keys()

_fill_ufunc_db_with_dpnp_ufuncs(_ufunc_db)

return _ufunc_db.keys()
def get_ufunc_info(ufunc_key):
"""get the lowering information for the ufunc with key ufunc_key.
The lowering information is a dictionary that maps from a numpy
loop string (as given by the ufunc types attribute) to a function
that handles code generation for a scalar version of the ufunc
(that is, generates the "per element" operation").
raises a KeyError if the ufunc is not in the ufunc_db
"""
_lazy_init_dpnp_db()
return _dpnp_ufunc_db[ufunc_key]


def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db):
"""Monkey patching dpnp for missing attributes."""
# FIXME: add more docstring
"""Populates the _dpnp_ufunc_db from Numba's NumPy ufunc_db"""

from numba.np.ufunc_db import _lazy_init_db

_lazy_init_db()

# we need to import it after, because before init it is None and
# variable is passed by value
from numba.np.ufunc_db import _ufunc_db

for ufuncop in dpnpdecl.supported_ufuncs:
if ufuncop == "erf":
Expand All @@ -52,7 +81,12 @@ def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db):
op.nargs = npop.nargs
op.types = npop.types
op.is_dpnp_ufunc = True
ufunc_db.update({op: ufunc_db[npop]})
cp = copy.copy(_ufunc_db[npop])
if "'divide'" in str(npop):
# TODO: why do we need to do it only for divide?
# https://github.com/IntelPython/numba-dpex/issues/1270
ufunc_db.update({npop: cp})
ufunc_db.update({op: cp})
for key in list(ufunc_db[op].keys()):
if (
"FF->" in key
Expand Down

0 comments on commit acfaa9d

Please sign in to comment.