Skip to content

Commit

Permalink
Remove spirv-tool call generators from spirv_generator.py.
Browse files Browse the repository at this point in the history
    - The spirv-tool package provides various utilities to
      disassemble, validate and potentially optimize a
      SPIR-V binary file. Numba-dpex no longer has spirv-tools
      as a dependency and it is up to the user if they want to
      try out the funtionalities provided by spirv-tools.

      Having the option to call spirv-tool utilities from
      spirv_generator is unused and potentially confusing. The
      commit removes all such options and renames the helper
      class from CmdLine to _SprivGenerator.

    - Also removed the config option to run spirv validator.
  • Loading branch information
Diptorup Deb committed Jan 23, 2024
1 parent 92a7a7d commit d9eb596
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 70 deletions.
4 changes: 0 additions & 4 deletions numba_dpex/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import logging
import os

import dpctl
from numba.core import config


Expand Down Expand Up @@ -54,9 +53,6 @@ def __getattr__(name):
# To save intermediate files generated by th compiler
SAVE_IR_FILES = _readenv("NUMBA_DPEX_SAVE_IR_FILES", int, 0)

# Turn SPIRV-VALIDATION ON/OFF switch
SPIRV_VAL = _readenv("NUMBA_DPEX_SPIRV_VAL", int, 0)

# Dump offload diagnostics
OFFLOAD_DIAGNOSTICS = _readenv("NUMBA_DPEX_OFFLOAD_DIAGNOSTICS", int, 0)

Expand Down
75 changes: 9 additions & 66 deletions numba_dpex/spirv_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
#
# SPDX-License-Identifier: Apache-2.0

"""A wrapper to connect to the SPIR-V binaries (Tools, Translator)."""
"""
A wrapper to call dpcpp's llvm-spirv tool to generate a SPIR-V binary from
a numba-dpex generated LLVM IR module.
"""

import os
import tempfile
Expand All @@ -28,47 +31,8 @@ def run_cmd(args, error_message=None):
)


class CmdLine:
def disassemble(self, ipath, opath):
"""
Disassemble a spirv module.
Args:
ipath: Input file path of the spirv module.
opath: Output file path of the disassembled spirv module.
"""
flags = []
run_cmd(
["spirv-dis", *flags, "-o", opath, ipath],
error_message="Error during SPIRV disassemble",
)

def validate(self, ipath):
"""
Validate a spirv module.
Args:
ipath: Input file path of the spirv module.
"""
flags = []
run_cmd(
["spirv-val", *flags, ipath],
error_message="Error during SPIRV validation",
)

def optimize(self, ipath, opath):
"""
Optimize a spirv module.
Args:
ipath: Input file path of the spirv module.
opath: Output file path of the optimized spirv module.
"""
flags = []
run_cmd(
["spirv-opt", *flags, "-o", opath, ipath],
error_message="Error during SPIRV optimization",
)
class _SpirvGenerator:
"""Generates a SPIR-V binary from supplied LLVM IR."""

def generate(self, llvm_spirv_args, ipath, opath):
"""
Expand Down Expand Up @@ -115,7 +79,7 @@ def __init__(self, context, llvmir, llvmbc):
"""
self._tmpdir = tempfile.mkdtemp()
self._tempfiles = []
self._cmd = CmdLine()
self._generator = _SpirvGenerator()
self._finalized = False
self.context = context

Expand Down Expand Up @@ -182,7 +146,7 @@ def finalize(self):
print("generated_llvm.bc")
print("".center(80, "="))

self._cmd.generate(
self._generator.generate(
llvm_spirv_args=llvm_spirv_args,
ipath=self._llvmfile,
opath=spirv_path,
Expand All @@ -199,28 +163,7 @@ def finalize(self):
print("generated_spirv.spir")
print("".center(80, "="))

# Validate the SPIR-V code
if config.SPIRV_VAL == 1:
try:
self._cmd.validate(ipath=spirv_path)
except CalledProcessError:
print("SPIR-V Validation failed...")
pass
else:
# Optimize SPIR-V code
opt_path = self._track_temp_file("optimized-spirv")
self._cmd.optimize(ipath=spirv_path, opath=opt_path)

if config.DUMP_ASSEMBLY:
# Disassemble optimized SPIR-V code
dis_path = self._track_temp_file("disassembled-spirv")
self._cmd.disassemble(ipath=opt_path, opath=dis_path)
with open(dis_path, "rb") as fin_opt:
print("ASSEMBLY".center(80, "-"))
print(fin_opt.read())
print("".center(80, "="))

# Read and return final SPIR-V (not optimized!)
# Read and return final SPIR-V
with open(spirv_path, "rb") as fin:
spirv = fin.read()

Expand Down

0 comments on commit d9eb596

Please sign in to comment.