Skip to content

Commit

Permalink
Merge pull request #1255 from IntelPython/feature/spirv_numba_error
Browse files Browse the repository at this point in the history
Improve spirv error reporting in numba
  • Loading branch information
Diptorup Deb authored Dec 22, 2023
2 parents ec3ef13 + c75f48e commit 4259e1b
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 17 deletions.
14 changes: 14 additions & 0 deletions numba_dpex/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,3 +398,17 @@ def __init__(self) -> None:
"An IntEnumLiteral can only be initialized from a FlagEnum member"
)
super().__init__(self.message)


class InternalError(Exception):
"""
Exception raise when something in the internal compiler machinery failed.
"""

def __init__(self, extra_msg=None) -> None:
if extra_msg is None:
self.message = "An internal compiler error happened and the "
"compilation failed."
else:
self.message = extra_msg
super().__init__(self.message)
48 changes: 31 additions & 17 deletions numba_dpex/spirv_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,26 @@

import os
import tempfile
from subprocess import CalledProcessError, check_call
from subprocess import STDOUT, CalledProcessError, check_output

from numba_dpex import config
from numba_dpex.core.exceptions import InternalError
from numba_dpex.core.targets.kernel_target import LLVM_SPIRV_ARGS


def _raise_bad_env_path(msg, path, extra=None):
error_message = msg.format(path)
if extra is not None:
error_message += extra
raise ValueError(error_message)


_real_check_call = check_call


def check_call(*args, **kwargs):
return _real_check_call(*args, **kwargs)
def run_cmd(args, error_message=None):
try:
check_output(
args,
stderr=STDOUT,
)
except CalledProcessError as err:
if error_message is None:
error_message = f"Error during call to {args[0]}"
raise InternalError(
f"{error_message}:\n\t"
+ "\t".join(err.output.decode("utf-8").splitlines(True))
)


class CmdLine:
Expand All @@ -36,7 +38,10 @@ def disassemble(self, ipath, opath):
opath: Output file path of the disassembled spirv module.
"""
flags = []
check_call(["spirv-dis", *flags, "-o", opath, ipath])
run_cmd(
["spirv-dis", *flags, "-o", opath, ipath],
error_message="Error during SPIRV disassemble",
)

def validate(self, ipath):
"""
Expand All @@ -46,7 +51,10 @@ def validate(self, ipath):
ipath: Input file path of the spirv module.
"""
flags = []
check_call(["spirv-val", *flags, ipath])
run_cmd(
["spirv-val", *flags, ipath],
error_message="Error during SPIRV validation",
)

def optimize(self, ipath, opath):
"""
Expand All @@ -57,7 +65,10 @@ def optimize(self, ipath, opath):
opath: Output file path of the optimized spirv module.
"""
flags = []
check_call(["spirv-opt", *flags, "-o", opath, ipath])
run_cmd(
["spirv-opt", *flags, "-o", opath, ipath],
error_message="Error during SPIRV optimization",
)

def generate(self, llvm_spirv_args, ipath, opath):
"""
Expand All @@ -79,7 +90,10 @@ def generate(self, llvm_spirv_args, ipath, opath):
if config.DEBUG:
print(f"Use llvm-spirv: {llvm_spirv_tool}")

check_call([llvm_spirv_tool, *llvm_spirv_args, "-o", opath, ipath])
run_cmd(
[llvm_spirv_tool, *llvm_spirv_args, "-o", opath, ipath],
error_message="Error during lowering LLVM IR to SPIRV",
)

@staticmethod
def _llvm_spirv():
Expand Down

0 comments on commit 4259e1b

Please sign in to comment.