From a455e46d6a126d3c57954a3359d4b78d0c6285a4 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sat, 30 Mar 2024 20:47:50 -0500 Subject: [PATCH] Move printimpl into kernel_api_impl.spirv --- .../{ => kernel_api_impl/spirv}/printimpl.py | 26 ++++++++++++++----- numba_dpex/kernel_api_impl/spirv/target.py | 2 +- 2 files changed, 21 insertions(+), 7 deletions(-) rename numba_dpex/{ => kernel_api_impl/spirv}/printimpl.py (77%) diff --git a/numba_dpex/printimpl.py b/numba_dpex/kernel_api_impl/spirv/printimpl.py similarity index 77% rename from numba_dpex/printimpl.py rename to numba_dpex/kernel_api_impl/spirv/printimpl.py index 2bb53cfd81..486a385eb8 100644 --- a/numba_dpex/printimpl.py +++ b/numba_dpex/kernel_api_impl/spirv/printimpl.py @@ -2,6 +2,10 @@ # # SPDX-License-Identifier: Apache-2.0 +""" +An implementation of ``print`` for use in a kernel for the SPIRVKernelTarget. +""" + from functools import singledispatch import llvmlite.ir as llvmir @@ -14,7 +18,16 @@ lower = registry.lower -def declare_print(lmod): +def declare_print(lmod: llvmir.Module): + """Inserts declaration for C printf into the given LLVM module + + Args: + lmod (llvmir.Module): LLVM module into which the function declaration + needs to be inserted. + + Returns: + An LLVM IR Function object for the inserted C printf function. + """ voidptrty = llvmir.PointerType( llvmir.IntType(8), addrspace=address_space.GENERIC.value ) @@ -32,33 +45,34 @@ def print_item(ty, context, builder, val): A (format string, [list of arguments]) is returned that will allow forming the final printf()-like call. """ - raise NotImplementedError( - "printing unimplemented for values of type %s" % (ty,) - ) + raise NotImplementedError(f"printing unimplemented for values of type {ty}") @print_item.register(types.Integer) @print_item.register(types.IntegerLiteral) def int_print_impl(ty, context, builder, val): + """Implements printing an integer value.""" if ty in types.unsigned_domain: rawfmt = "%llu" dsttype = types.uint64 else: rawfmt = "%lld" dsttype = types.int64 - fmt = context.insert_const_string(builder.module, rawfmt) # noqa + context.insert_const_string(builder.module, rawfmt) lld = context.cast(builder, val, ty, dsttype) return rawfmt, [lld] @print_item.register(types.Float) def real_print_impl(ty, context, builder, val): + """Implements printing a real number value.""" lld = context.cast(builder, val, ty, types.float64) return "%f", [lld] @print_item.register(types.StringLiteral) def const_print_impl(ty, context, builder, sigval): + """Implements printing a string value.""" pyval = ty.literal_value assert isinstance(pyval, str) # Ensured by lowering rawfmt = "%s" @@ -76,7 +90,7 @@ def print_varargs(context, builder, sig, args): values = [] only_str = True - for i, (argtype, argval) in enumerate(zip(sig.args, args)): + for _, (argtype, argval) in enumerate(zip(sig.args, args)): argfmt, argvals = print_item(argtype, context, builder, argval) formats.append(argfmt) values.extend(argvals) diff --git a/numba_dpex/kernel_api_impl/spirv/target.py b/numba_dpex/kernel_api_impl/spirv/target.py index 6ef86e5f91..689fc45514 100644 --- a/numba_dpex/kernel_api_impl/spirv/target.py +++ b/numba_dpex/kernel_api_impl/spirv/target.py @@ -291,9 +291,9 @@ def load_additional_registries(self): """ # pylint: disable=import-outside-toplevel - from numba_dpex import printimpl from numba_dpex.dpctl_iface import dpctlimpl from numba_dpex.dpnp_iface import dpnpimpl + from numba_dpex.kernel_api_impl.spirv import printimpl from numba_dpex.ocl import mathimpl self.insert_func_defn(mathimpl.registry.functions)