Skip to content

Commit

Permalink
Move printimpl into kernel_api_impl.spirv
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb committed Apr 4, 2024
1 parent e0d0772 commit a455e46
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
#
# SPDX-License-Identifier: Apache-2.0

"""
An implementation of ``print`` for use in a kernel for the SPIRVKernelTarget.
"""

from functools import singledispatch

import llvmlite.ir as llvmir
Expand All @@ -14,7 +18,16 @@
lower = registry.lower


def declare_print(lmod):
def declare_print(lmod: llvmir.Module):
"""Inserts declaration for C printf into the given LLVM module
Args:
lmod (llvmir.Module): LLVM module into which the function declaration
needs to be inserted.
Returns:
An LLVM IR Function object for the inserted C printf function.
"""
voidptrty = llvmir.PointerType(
llvmir.IntType(8), addrspace=address_space.GENERIC.value
)
Expand All @@ -32,33 +45,34 @@ def print_item(ty, context, builder, val):
A (format string, [list of arguments]) is returned that will allow
forming the final printf()-like call.
"""
raise NotImplementedError(
"printing unimplemented for values of type %s" % (ty,)
)
raise NotImplementedError(f"printing unimplemented for values of type {ty}")


@print_item.register(types.Integer)
@print_item.register(types.IntegerLiteral)
def int_print_impl(ty, context, builder, val):
"""Implements printing an integer value."""
if ty in types.unsigned_domain:
rawfmt = "%llu"
dsttype = types.uint64
else:
rawfmt = "%lld"
dsttype = types.int64
fmt = context.insert_const_string(builder.module, rawfmt) # noqa
context.insert_const_string(builder.module, rawfmt)
lld = context.cast(builder, val, ty, dsttype)
return rawfmt, [lld]


@print_item.register(types.Float)
def real_print_impl(ty, context, builder, val):
"""Implements printing a real number value."""
lld = context.cast(builder, val, ty, types.float64)
return "%f", [lld]


@print_item.register(types.StringLiteral)
def const_print_impl(ty, context, builder, sigval):
"""Implements printing a string value."""
pyval = ty.literal_value
assert isinstance(pyval, str) # Ensured by lowering
rawfmt = "%s"
Expand All @@ -76,7 +90,7 @@ def print_varargs(context, builder, sig, args):
values = []

only_str = True
for i, (argtype, argval) in enumerate(zip(sig.args, args)):
for _, (argtype, argval) in enumerate(zip(sig.args, args)):
argfmt, argvals = print_item(argtype, context, builder, argval)
formats.append(argfmt)
values.extend(argvals)
Expand Down
2 changes: 1 addition & 1 deletion numba_dpex/kernel_api_impl/spirv/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,9 @@ def load_additional_registries(self):
"""
# pylint: disable=import-outside-toplevel
from numba_dpex import printimpl
from numba_dpex.dpctl_iface import dpctlimpl
from numba_dpex.dpnp_iface import dpnpimpl
from numba_dpex.kernel_api_impl.spirv import printimpl
from numba_dpex.ocl import mathimpl

self.insert_func_defn(mathimpl.registry.functions)
Expand Down

0 comments on commit a455e46

Please sign in to comment.