Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create dpnp specific ufunc db #1267

Merged
merged 2 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions numba_dpex/core/targets/dpjit_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from numba.core.imputils import Registry
from numba.core.target_extension import CPU, target_registry

from numba_dpex.dpnp_iface import dpnp_ufunc_db


class Dpex(CPU):
pass
Expand Down Expand Up @@ -57,3 +59,7 @@ def load_additional_registries(self):

# loading CPU specific registries
super().load_additional_registries()

# TODO: do we need it?
def get_ufunc_info(self, ufunc_key):
return dpnp_ufunc_db.get_ufunc_info(ufunc_key)
15 changes: 7 additions & 8 deletions numba_dpex/core/targets/kernel_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,21 +251,20 @@ def init(self):
self._target_data = ll.create_target_data(
codegen.SPIR_DATA_LAYOUT[utils.MACHINE_BITS]
)
# Override data model manager to SPIR model
import numba.cpython.unicode

# Override data model manager to SPIR model
self.data_model_manager = _init_data_model_manager()
self.extra_compile_options = dict()

import copy
from numba_dpex.dpnp_iface.dpnp_ufunc_db import _lazy_init_dpnp_db

from numba.np.ufunc_db import _lazy_init_db
_lazy_init_dpnp_db()

_lazy_init_db()
from numba.np.ufunc_db import _ufunc_db as ufunc_db
# we need to import it after, because before init it is None and
# variable is passed by value
from numba_dpex.dpnp_iface.dpnp_ufunc_db import _dpnp_ufunc_db

self.ufunc_db = copy.deepcopy(ufunc_db)
self.cpu_context = cpu_target.target_context
self.ufunc_db = _dpnp_ufunc_db

def create_module(self, name):
return self._internal_codegen._create_empty_module(name)
Expand Down
52 changes: 43 additions & 9 deletions numba_dpex/dpnp_iface/dpnp_ufunc_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# SPDX-License-Identifier: Apache-2.0


import copy

import dpnp
import numpy as np
from numba.core import types
Expand All @@ -11,23 +13,50 @@

from ..ocl import mathimpl

# A global instance of dpnp ufuncs that are supported by numba-dpex
_dpnp_ufunc_db = None


def _lazy_init_dpnp_db():
global _dpnp_ufunc_db

if _dpnp_ufunc_db is None:
_dpnp_ufunc_db = {}
_fill_ufunc_db_with_dpnp_ufuncs(_dpnp_ufunc_db)


def get_ufuncs():
"""obtain a list of supported ufuncs in the db"""
"""Returns the list of supported dpnp ufuncs in the _dpnp_ufunc_db"""

from numba.np.ufunc_db import _lazy_init_db
_lazy_init_dpnp_db()

_lazy_init_db()
from numba.np.ufunc_db import _ufunc_db
return _dpnp_ufunc_db.keys()

_fill_ufunc_db_with_dpnp_ufuncs(_ufunc_db)

return _ufunc_db.keys()
def get_ufunc_info(ufunc_key):
"""get the lowering information for the ufunc with key ufunc_key.

The lowering information is a dictionary that maps from a numpy
loop string (as given by the ufunc types attribute) to a function
that handles code generation for a scalar version of the ufunc
(that is, generates the "per element" operation").

raises a KeyError if the ufunc is not in the ufunc_db
"""
_lazy_init_dpnp_db()
return _dpnp_ufunc_db[ufunc_key]


def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db):
"""Monkey patching dpnp for missing attributes."""
# FIXME: add more docstring
"""Populates the _dpnp_ufunc_db from Numba's NumPy ufunc_db"""

from numba.np.ufunc_db import _lazy_init_db

_lazy_init_db()

# we need to import it after, because before init it is None and
# variable is passed by value
from numba.np.ufunc_db import _ufunc_db

for ufuncop in dpnpdecl.supported_ufuncs:
if ufuncop == "erf":
Expand All @@ -52,7 +81,12 @@ def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db):
op.nargs = npop.nargs
op.types = npop.types
op.is_dpnp_ufunc = True
ufunc_db.update({op: ufunc_db[npop]})
cp = copy.copy(_ufunc_db[npop])
if "'divide'" in str(npop):
# TODO: why do we need to do it only for divide?
# https://github.com/IntelPython/numba-dpex/issues/1270
ufunc_db.update({npop: cp})
ufunc_db.update({op: cp})
for key in list(ufunc_db[op].keys()):
if (
"FF->" in key
Expand Down
40 changes: 40 additions & 0 deletions numba_dpex/tests/test_dpex_use_alongside_numba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# SPDX-FileCopyrightText: 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""
This module contains tests to ensure that numba.njit works with numpy after
importing numba_dpex. Aka lazy testing if we break numba's default behavior.
"""

import numba as nb
import numpy as np

import numba_dpex


@nb.njit
def add1(a):
return a + 1


def add_py(a, b):
return np.add(a, b)


add_jit = nb.njit(add_py)


def test_add1():
a = np.asarray([1j], dtype=np.complex64)
assert np.array_equal(add1(a), np.asarray([1 + 1j], dtype=np.complex64))


def test_add_py():
a = np.ones((10,), dtype=np.complex128)
assert np.array_equal(add_py(a, 1.5), np.full((10,), 2.5, dtype=a.dtype))


def test_add_jit():
a = np.ones((10,), dtype=np.complex128)
assert np.array_equal(add_jit(a, 1.5), np.full((10,), 2.5, dtype=a.dtype))
Loading