Skip to content

Commit

Permalink
Fix debug flag propagation for func
Browse files Browse the repository at this point in the history
  • Loading branch information
chudur-budur committed Jan 20, 2023
1 parent dd6c1d3 commit be77154
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 18 deletions.
20 changes: 10 additions & 10 deletions numba_dpex/core/kernel_interface/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ class DpexFunction(object):
the function is invoked.
"""

def __init__(self, pyfunc, debug=None):
def __init__(self, pyfunc, debug=False):
"""Constructor for `DpexFunction`
Args:
pyfunc (`function`): A python function to be compiled.
debug (`object`, optional): Debug option for compilation.
Defaults to `None`.
debug (`bool`, optional): Debug option for compilation.
Defaults to `False`.
"""
self._pyfunc = pyfunc
self._debug = debug
Expand Down Expand Up @@ -77,13 +77,13 @@ class DpexFunctionTemplate(object):
calling convention.
"""

def __init__(self, pyfunc, debug=None, enable_cache=True):
def __init__(self, pyfunc, debug=False, enable_cache=True):
"""Constructor for `DpexFunctionTemplate`
Args:
pyfunc (function): A python function to be compiled.
debug (object, optional): Debug option for compilation.
Defaults to `None`.
debug (bool, optional): Debug option for compilation.
Defaults to `False`.
enable_cache (bool, optional): Flag to turn on/off caching.
Defaults to `True`.
"""
Expand Down Expand Up @@ -159,7 +159,7 @@ def compile(self, args):
return cres.signature


def compile_func(pyfunc, signature, debug=None):
def compile_func(pyfunc, signature, debug=False):
"""Compiles a specialized `numba_dpex.func`
Compiles a specialized `numba_dpex.func` decorated function to native
Expand All @@ -169,7 +169,7 @@ def compile_func(pyfunc, signature, debug=None):
Args:
pyfunc (`function`): A python function to be compiled.
signature (`list`): A list of `numba.core.typing.templates.Signature`'s
debug (`object`, optional): Debug options. Defaults to `None`.
debug (`bool`, optional): Debug options. Defaults to `False`.
Returns:
`numba_dpex.core.kernel_interface.func.DpexFunction`: A `DpexFunction` object
Expand Down Expand Up @@ -205,7 +205,7 @@ class _function_template(ConcreteTemplate):
return devfn


def compile_func_template(pyfunc, debug=None, enable_cache=True):
def compile_func_template(pyfunc, debug=False, enable_cache=True):
"""Converts a `numba_dpex.func` function to an `AbstractTemplate`
Converts a `numba_dpex.func` decorated function to a Numba
Expand All @@ -219,7 +219,7 @@ def compile_func_template(pyfunc, debug=None, enable_cache=True):
Args:
pyfunc (`function`): A python function to be compiled.
debug (`object`, optional): Debug options. Defaults to `None`.
debug (`bool`, optional): Debug options. Defaults to `False`.
Raises:
`AssertionError`: Raised if keyword arguments are supplied in
Expand Down
13 changes: 5 additions & 8 deletions numba_dpex/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
def kernel(
func_or_sig=None,
access_types=None,
debug=None,
debug=False,
enable_cache=True,
):
"""A decorator to define a kernel function.
Expand Down Expand Up @@ -93,7 +93,7 @@ def _specialized_kernel_dispatcher(pyfunc):
return _kernel_dispatcher(func)


def func(func_or_sig=None, debug=None, enable_cache=True):
def func(func_or_sig=None, debug=False, enable_cache=True):
"""A decorator to define a kernel device function.
Device functions are functions that can be only invoked from a kernel
Expand All @@ -105,16 +105,13 @@ def func(func_or_sig=None, debug=None, enable_cache=True):
normal functions.
"""

def _func_autojit(pyfunc, debug=None):
def _func_autojit(pyfunc):
return compile_func_template(
pyfunc, debug=debug, enable_cache=enable_cache
)

def _func_autojit_wrapper(debug=None):
return _func_autojit

if func_or_sig is None:
return _func_autojit_wrapper(debug=debug)
return _func_autojit
elif isinstance(func_or_sig, str):
raise NotImplementedError(
"Specifying signatures as string is not yet supported by numba-dpex"
Expand Down Expand Up @@ -143,4 +140,4 @@ def _wrapped(pyfunc):
else:
# no signature
func = func_or_sig
return _func_autojit(func, debug=debug)
return _func_autojit(func)

0 comments on commit be77154

Please sign in to comment.