Skip to content

Commit

Permalink
Fixes dpnp ufincs
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb committed May 19, 2023
1 parent ffbd423 commit 0bf1547
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 46 deletions.
74 changes: 37 additions & 37 deletions numba_dpex/core/targets/kernel_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
from functools import cached_property

import numpy as np
import dpnp
from llvmlite import binding as ll
from llvmlite import ir as llvmir
from numba import typeof
Expand Down Expand Up @@ -298,37 +298,36 @@ def init(self):
def create_module(self, name):
return self._internal_codegen._create_empty_module(name)

def replace_numpy_ufunc_with_opencl_supported_functions(self):
def replace_dpnp_ufunc_with_ocl_intrinsics(self):
from numba_dpex.ocl.mathimpl import lower_ocl_impl, sig_mapper

ufuncs = [
("fabs", np.fabs),
("exp", np.exp),
("log", np.log),
("log10", np.log10),
("expm1", np.expm1),
("log1p", np.log1p),
("sqrt", np.sqrt),
("sin", np.sin),
("cos", np.cos),
("tan", np.tan),
("asin", np.arcsin),
("acos", np.arccos),
("atan", np.arctan),
("atan2", np.arctan2),
("sinh", np.sinh),
("cosh", np.cosh),
("tanh", np.tanh),
("asinh", np.arcsinh),
("acosh", np.arccosh),
("atanh", np.arctanh),
("ldexp", np.ldexp),
("floor", np.floor),
("ceil", np.ceil),
("trunc", np.trunc),
("hypot", np.hypot),
("exp2", np.exp2),
("log2", np.log2),
("fabs", dpnp.fabs),
("exp", dpnp.exp),
("log", dpnp.log),
("log10", dpnp.log10),
("expm1", dpnp.expm1),
("log1p", dpnp.log1p),
("sqrt", dpnp.sqrt),
("sin", dpnp.sin),
("cos", dpnp.cos),
("tan", dpnp.tan),
("asin", dpnp.arcsin),
("acos", dpnp.arccos),
("atan", dpnp.arctan),
("atan2", dpnp.arctan2),
("sinh", dpnp.sinh),
("cosh", dpnp.cosh),
("tanh", dpnp.tanh),
("asinh", dpnp.arcsinh),
("acosh", dpnp.arccosh),
("atanh", dpnp.arctanh),
("floor", dpnp.floor),
("ceil", dpnp.ceil),
("trunc", dpnp.trunc),
("hypot", dpnp.hypot),
("exp2", dpnp.exp2),
("log2", dpnp.log2),
]

for name, ufunc in ufuncs:
Expand All @@ -344,23 +343,24 @@ def replace_numpy_ufunc_with_opencl_supported_functions(self):
def load_additional_registries(self):
"""Register OpenCL functions into numba_depx's target context.
To make sure we are calling supported OpenCL math functions, we
replace some of NUMBA's NumPy ufunc with OpenCL versions of those
functions. The replacement is done after the OpenCL functions have
been registered into the target context.
To make sure we are calling supported OpenCL math functions, we replace
the dpnp functions that default to NUMBA's NumPy ufunc with OpenCL
intrinsics that are equivalent to those functions. The replacement is
done after the OpenCL functions have been registered into the
target context.
"""
from numba.np import npyimpl
from numba_dpex.dpnp_iface import dpnpimpl

from ... import printimpl
from ...ocl import mathimpl, oclimpl

self.insert_func_defn(oclimpl.registry.functions)
self.insert_func_defn(mathimpl.registry.functions)
self.insert_func_defn(npyimpl.registry.functions)
self.insert_func_defn(dpnpimpl.registry.functions)
self.install_registry(printimpl.registry)
# Replace NumPy functions with their OpenCL versions.
self.replace_numpy_ufunc_with_opencl_supported_functions()
# Replace dpnp math functions with their OpenCL versions.
self.replace_dpnp_ufunc_with_ocl_intrinsics()

@cached_property
def call_conv(self):
Expand Down
31 changes: 22 additions & 9 deletions numba_dpex/dpnp_iface/dpnpimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,37 @@
#
# SPDX-License-Identifier: Apache-2.0

import copy

import dpnp
from numba.core.imputils import Registry
from numba.np import npyimpl

from numba_dpex.core.typing.dpnpdecl import _unsupported
from numba_dpex.dpnp_iface import dpnp_ufunc_db

registry = Registry("dpnpimpl")


def _register_dpnp_ufuncs():
registry = Registry("npyimpl")
lower = registry.lower
"""Adds dpnp ufuncs to the dpnpimpl.registry.
The npyimpl.registry is searched for all registered ufuncs and we copy the
implementations and register them in a dpnp-specific registry defined in the
current module. The numpy ufuncs are deep copied so as to not mutate the
original functions by changes we introduce in the DpexKernelTarget.
Raises:
RuntimeError: If the signature of the ufunc takes more than two input
args.
"""
kernels = {}
# NOTE: Assuming ufunc implementation for the CPUContext.

for ufunc in dpnp_ufunc_db.get_ufuncs():
kernels[ufunc] = npyimpl.register_ufunc_kernel(
ufunc,
npyimpl._ufunc_db_function(ufunc),
lower,
copy.copy(npyimpl._ufunc_db_function(ufunc)),
registry.lower,
)

for _op_map in (
Expand All @@ -34,11 +46,11 @@ def _register_dpnp_ufuncs():
kernel = kernels[ufunc]
if ufunc.nin == 1:
npyimpl.register_unary_operator_kernel(
operator, ufunc, kernel, lower
operator, ufunc, kernel, registry.lower
)
elif ufunc.nin == 2:
npyimpl.register_binary_operator_kernel(
operator, ufunc, kernel, lower
operator, ufunc, kernel, registry.lower
)
else:
raise RuntimeError(
Expand All @@ -53,16 +65,17 @@ def _register_dpnp_ufuncs():
kernel = kernels[ufunc]
if ufunc.nin == 1:
npyimpl.register_unary_operator_kernel(
operator, ufunc, kernel, lower, inplace=True
operator, ufunc, kernel, registry.lower, inplace=True
)
elif ufunc.nin == 2:
npyimpl.register_binary_operator_kernel(
operator, ufunc, kernel, lower, inplace=True
operator, ufunc, kernel, registry.lower, inplace=True
)
else:
raise RuntimeError(
"There shouldn't be any non-unary or binary operators"
)


# Initialize the registry that stores the dpnp ufuncs
_register_dpnp_ufuncs()

0 comments on commit 0bf1547

Please sign in to comment.