-
Notifications
You must be signed in to change notification settings - Fork 710
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
adapt tensorrt llm custom all reduce to sgl-kernel (#2481)
Co-authored-by: Yineng Zhang <[email protected]>
- Loading branch information
1 parent
5f2595b
commit e04d3f2
Showing 13 changed files with 872 additions and 32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,47 +1,75 @@ | ||
# Build configuration for the sgl-kernel CUDA extension modules
# (warp_reduce and the TensorRT-LLM-derived trt_reduce all-reduce).
cmake_minimum_required(VERSION 3.18)
project(sgl-kernel LANGUAGES CXX CUDA)

# Basic settings
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)

# Set CUDA architectures (Turing through Hopper).
set(CMAKE_CUDA_ARCHITECTURES "75;80;86;89;90")
message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")

# Modern FindPython3 module: provides Python3_EXECUTABLE and the
# Python3::Python imported target (replaces deprecated PythonInterp/PythonLibs).
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)

# Find PyTorch: ask the active Python's torch install where its CMake
# package files live, then add that to the prefix path.
execute_process(
    COMMAND ${Python3_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)"
    OUTPUT_VARIABLE TORCH_CMAKE_PATH
    OUTPUT_STRIP_TRAILING_WHITESPACE
)
message(STATUS "TORCH_CMAKE_PATH: ${TORCH_CMAKE_PATH}")

list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PATH}")

find_package(Torch REQUIRED)

# Warp Reduce library
add_library(warp_reduce SHARED
    src/sgl-kernel/csrc/warp_reduce.cc
    src/sgl-kernel/csrc/warp_reduce_kernel.cu
)

target_include_directories(warp_reduce
    PRIVATE
        ${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc
        ${CUDA_INCLUDE_DIRS}
        ${TORCH_INCLUDE_DIRS}
)

target_link_libraries(warp_reduce
    PRIVATE
        ${TORCH_LIBRARIES}
        Python3::Python
)

# TRT Reduce library (custom all-reduce adapted from TensorRT-LLM).
add_library(trt_reduce SHARED
    src/sgl-kernel/csrc/trt_reduce.cc
    src/sgl-kernel/csrc/trt_reduce_internal.cu
    src/sgl-kernel/csrc/trt_reduce_kernel.cu
)

target_include_directories(trt_reduce
    PRIVATE
        ${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc
        ${CUDA_INCLUDE_DIRS}
        ${TORCH_INCLUDE_DIRS}
)

target_link_libraries(trt_reduce
    PRIVATE
        ${TORCH_LIBRARIES}
        Python3::Python
)

# Set common properties for both libraries.
# PREFIX ""/SUFFIX ".so" makes the output name importable as a Python
# extension module (no "lib" prefix, plain .so even on macOS).
foreach(target warp_reduce trt_reduce)
    set_target_properties(${target} PROPERTIES
        CUDA_SEPARABLE_COMPILATION ON
        POSITION_INDEPENDENT_CODE ON
        CUDA_RESOLVE_DEVICE_SYMBOLS ON
        PREFIX ""
        SUFFIX ".so"
    )
endforeach()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,8 @@ | ||
# Public API of the sgl-kernel package: re-export the reduction ops
# (warp_reduce plus the custom all-reduce lifecycle: init / reduce / dispose)
# implemented in the compiled .ops extension modules.
from .ops import custom_dispose, custom_reduce, init_custom_reduce, warp_reduce

# Names exported by `from sgl_kernel import *`.
__all__ = [
    "warp_reduce",
    "init_custom_reduce",
    "custom_dispose",
    "custom_reduce",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
// Python bindings for the custom all-reduce adapted from TensorRT-LLM.
// The CUDA implementations of these three functions live in the companion
// trt_reduce_internal.cu / trt_reduce_kernel.cu translation units; this file
// only declares them and exposes them to Python via pybind11.

#include <torch/extension.h>

// Explicit includes for std::vector / int64_t rather than relying on
// torch/extension.h pulling them in transitively.
#include <cstdint>
#include <vector>

// Opaque handle: native pointers are round-tripped through Python as
// plain 64-bit integers.
using fptr_t = int64_t;

// Create per-rank custom all-reduce state and return an opaque handle.
// NOTE(review): buffers/barrier_in/barrier_out appear to be per-rank
// pointer values registered by the caller — confirm against the .cu
// implementation.
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
                      const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out);
// Release the state owned by a handle returned from init_custom_ar.
void dispose(fptr_t _fa);
// Perform the custom all-reduce on `inp`, writing the result into `out`.
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
  m.def("dispose", &dispose, "dispose custom allreduce meta");
  m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
}
Oops, something went wrong.