Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Output LLVM IR if NUMBA_DPEX_DEBUG environment variable is set #924

Merged
merged 4 commits into from
Feb 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions numba_dpex/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ def __getattr__(name):
"NUMBA_DPEX_DEBUGINFO", int, config.DEBUGINFO_DEFAULT
)

# Emit LLVM assembly language format(.ll)
DUMP_KERNEL_LLVM = _readenv(
"NUMBA_DPEX_DUMP_KERNEL_LLVM", int, config.DUMP_OPTIMIZED
)

# configs for caching
# To see the debug messages for the caching.
# Execute like:
Expand Down
16 changes: 15 additions & 1 deletion numba_dpex/core/kernel_interface/spirv_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from numba.core import ir

from numba_dpex import spirv_generator
from numba_dpex import config, spirv_generator
from numba_dpex.core.compiler import compile_with_dpex
from numba_dpex.core.exceptions import UncompiledKernelError, UnreachableError

Expand Down Expand Up @@ -139,6 +139,20 @@ def compile(
self._llvm_module = kernel.module.__str__()
self._module_name = kernel.name

# Dump LLVM IR if DEBUG flag is set.
if config.DUMP_KERNEL_LLVM:
import hashlib

# Embed hash of module name in the output file name
# so that different kernels are written to separate files
with open(
"llvm_kernel_"
diptorupd marked this conversation as resolved.
Show resolved Hide resolved
+ hashlib.sha256(self._module_name.encode()).hexdigest()
+ ".ll",
"w",
) as f:
f.write(self._llvm_module)

# FIXME: There is no need to serialize the bitcode. It can be passed to
# llvm-spirv directly via stdin.

Expand Down
84 changes: 84 additions & 0 deletions numba_dpex/tests/test_dump_kernel_llvm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#! /usr/bin/env python

# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

import hashlib
import os

import numba_dpex as dpex
from numba_dpex import config, float32, usm_ndarray
from numba_dpex.core.descriptor import dpex_kernel_target

f32arrty = usm_ndarray(ndim=1, dtype=float32, layout="C")


def _get_kernel_llvm(fn, sig, debug=False):
kernel = dpex.core.kernel_interface.spirv_kernel.SpirvKernel(
fn, fn.__name__
)
kernel.compile(
args=sig,
target_ctx=dpex_kernel_target.target_context,
typing_ctx=dpex_kernel_target.typing_context,
debug=debug,
compile_flags=None,
)
return kernel.module_name, kernel.llvm_module


def test_dump_file_on_dump_kernel_llvm_flag_on():
"""
Test functionality of DUMP_KERNEL_LLVM config variable.
Check llvm source is dumped in .ll file in current directory
and compare with llvm source stored in SprivKernel.
"""

def data_parallel_sum(var_a, var_b, var_c):
i = dpex.get_global_id(0)
var_c[i] = var_a[i] + var_b[i]

sig = (f32arrty, f32arrty, f32arrty)

config.DUMP_KERNEL_LLVM = True

llvm_module_name, llvm_module_str = _get_kernel_llvm(data_parallel_sum, sig)

dump_file_name = (
"llvm_kernel_"
+ hashlib.sha256(llvm_module_name.encode()).hexdigest()
+ ".ll"
)

with open(dump_file_name, "r") as f:
llvm_dump = f.read()

assert llvm_module_str == llvm_dump

os.remove(dump_file_name)


def test_no_dump_file_on_dump_kernel_llvm_flag_off():
"""
Test functionality of DUMP_KERNEL_LLVM config variable.
Check llvm source is not dumped in .ll file in current directory.
"""

def data_parallel_sum(var_a, var_b, var_c):
i = dpex.get_global_id(0)
var_c[i] = var_a[i] + var_b[i]

sig = (f32arrty, f32arrty, f32arrty)

config.DUMP_KERNEL_LLVM = False

llvm_module_name, llvm_module_str = _get_kernel_llvm(data_parallel_sum, sig)

dump_file_name = (
"llvm_kernel_"
+ hashlib.sha256(llvm_module_name.encode()).hexdigest()
+ ".ll"
)

assert not os.path.isfile(dump_file_name)