diff --git a/setup.py b/setup.py index 61d2d710aa20e..ba6953dbdc174 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,4 @@ +import ctypes import importlib.util import logging import os @@ -13,7 +14,7 @@ from setuptools import Extension, find_packages, setup from setuptools.command.build_ext import build_ext from setuptools_scm import get_version -from torch.utils.cpp_extension import CUDA_HOME +from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME def load_module_from_path(module_name, path): @@ -379,25 +380,31 @@ def _build_custom_ops() -> bool: return _is_cuda() or _is_hip() or _is_cpu() -def get_hipcc_rocm_version(): - # Run the hipcc --version command - result = subprocess.run(['hipcc', '--version'], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True) +def get_rocm_version(): + # Get the Rocm version from the ROCM_HOME/bin/librocm-core.so + # see https://github.com/ROCm/rocm-core/blob/d11f5c20d500f729c393680a01fa902ebf92094b/rocm_version.cpp#L21 + try: + librocm_core_file = Path(ROCM_HOME) / "lib" / "librocm-core.so" + if not librocm_core_file.is_file(): + return None + librocm_core = ctypes.CDLL(librocm_core_file) + VerErrors = ctypes.c_uint32 + get_rocm_core_version = librocm_core.getROCmVersion + get_rocm_core_version.restype = VerErrors + get_rocm_core_version.argtypes = [ + ctypes.POINTER(ctypes.c_uint32), + ctypes.POINTER(ctypes.c_uint32), + ctypes.POINTER(ctypes.c_uint32), + ] + major = ctypes.c_uint32() + minor = ctypes.c_uint32() + patch = ctypes.c_uint32() - # Check if the command was executed successfully - if result.returncode != 0: - print("Error running 'hipcc --version'") + if (get_rocm_core_version(ctypes.byref(major), ctypes.byref(minor), + ctypes.byref(patch)) == 0): + return "%d.%d.%d" % (major.value, minor.value, patch.value) return None - - # Extract the version using a regular expression - match = re.search(r'HIP version: (\S+)', result.stdout) - if match: - # Return the version string - return match.group(1) - else: - print("Could not find HIP version in the output") + except Exception: return None @@ -479,11 +486,10 @@ def get_vllm_version() -> str: if "sdist" not in sys.argv: version += f"{sep}cu{cuda_version_str}" elif _is_hip(): - # Get the HIP version - hipcc_version = get_hipcc_rocm_version() - if hipcc_version != MAIN_CUDA_VERSION: - rocm_version_str = hipcc_version.replace(".", "")[:3] - version += f"{sep}rocm{rocm_version_str}" + # Get the Rocm Version + rocm_version = get_rocm_version() or torch.version.hip + if rocm_version and rocm_version != MAIN_CUDA_VERSION: + version += f"{sep}rocm{rocm_version.replace('.', '')[:3]}" elif _is_neuron(): # Get the Neuron version neuron_version = str(get_neuronxcc_version())