Skip to content

Commit

Permalink
adapt tensorrt llm custom all reduce to sgl-kernel (#2481)
Browse files Browse the repository at this point in the history
Co-authored-by: Yineng Zhang <[email protected]>
  • Loading branch information
yizhang2077 and zhyncs authored Dec 15, 2024
1 parent 5f2595b commit e04d3f2
Show file tree
Hide file tree
Showing 13 changed files with 872 additions and 32 deletions.
66 changes: 47 additions & 19 deletions sgl-kernel/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,47 +1,75 @@
cmake_minimum_required(VERSION 3.18)
project(sgl-kernel LANGUAGES CXX CUDA)

# Basic settings
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)

find_package(PythonInterp 3 REQUIRED)
find_package(PythonLibs 3 REQUIRED)
# Set CUDA architectures
# CMAKE_CUDA_ARCHITECTURES initializes the CUDA_ARCHITECTURES property of every
# target created after this point. Each list entry is a compute capability
# written without the dot (75 means 7.5). This plain set() replaces whatever
# value the variable already holds in this scope.
set(CMAKE_CUDA_ARCHITECTURES "75;80;86;89;90")
message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")

find_package(Python3 COMPONENTS Interpreter Development REQUIRED)

# Find PyTorch
execute_process(
COMMAND ${PYTHON_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)"
COMMAND ${Python3_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)"
OUTPUT_VARIABLE TORCH_CMAKE_PATH
OUTPUT_STRIP_TRAILING_WHITESPACE
)

message(STATUS "PYTHON_EXECUTABLE: ${PYTHON_EXECUTABLE}")
message(STATUS "TORCH_CMAKE_PATH: ${TORCH_CMAKE_PATH}")

list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PATH}")

find_package(Torch REQUIRED)

include_directories(${PYTHON_INCLUDE_DIRS})

# Warp Reduce library
add_library(warp_reduce SHARED
src/sgl-kernel/csrc/warp_reduce.cc
src/sgl-kernel/csrc/warp_reduce_kernel.cu
)

target_include_directories(warp_reduce PRIVATE
${CUDA_INCLUDE_DIRS}
${TORCH_INCLUDE_DIRS}
target_include_directories(warp_reduce
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc
${CUDA_INCLUDE_DIRS}
${TORCH_INCLUDE_DIRS}
)

target_link_libraries(warp_reduce PRIVATE
${TORCH_LIBRARIES}
${PYTHON_LIBRARIES}
target_link_libraries(warp_reduce
PRIVATE
${TORCH_LIBRARIES}
Python3::Python
)

set_target_properties(warp_reduce PROPERTIES
CUDA_SEPARABLE_COMPILATION ON
# TRT Reduce library
# Shared library assembled from one C++ source and two CUDA sources. The target
# is declared first and its sources are attached with target_sources(), so all
# per-target settings share the same "target / PRIVATE / items" layout.
add_library(trt_reduce SHARED)

target_sources(trt_reduce
    PRIVATE
        src/sgl-kernel/csrc/trt_reduce.cc
        src/sgl-kernel/csrc/trt_reduce_internal.cu
        src/sgl-kernel/csrc/trt_reduce_kernel.cu
)

# Header search path: the csrc directory first, then the CUDA and Torch
# include directories. The variables are deliberately left unquoted so that an
# empty value contributes no argument at all.
target_include_directories(trt_reduce
    PRIVATE
        ${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc
        ${CUDA_INCLUDE_DIRS}
        ${TORCH_INCLUDE_DIRS}
)

# Link against Torch and the Python library imported target.
target_link_libraries(trt_reduce
    PRIVATE
        ${TORCH_LIBRARIES}
        Python3::Python
)

# Set common properties for both libraries
# set_target_properties() accepts several target names before PROPERTIES, so a
# single call applies the same settings to warp_reduce and trt_reduce.
#   PREFIX ""    - no "lib" prefix on the output file name
#   SUFFIX ".so" - output file name always ends in ".so"
set_target_properties(warp_reduce trt_reduce
    PROPERTIES
        CUDA_SEPARABLE_COMPILATION ON
        POSITION_INDEPENDENT_CODE ON
        CUDA_RESOLVE_DEVICE_SYMBOLS ON
        PREFIX ""
        SUFFIX ".so"
)
4 changes: 2 additions & 2 deletions sgl-kernel/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ install:
@pip install -e .

build:
@python3 setup.py bdist_wheel
# The doubled dollar sign makes make pass a literal "$(nproc)" to the shell,
# which then runs nproc. A single "$" would be expanded by make itself as an
# (undefined) make variable named nproc, leaving MAX_JOBS empty.
@export MAX_JOBS=$$(nproc) && python3 setup.py bdist_wheel

clean:
@rm -rf build dist *.egg-info
Expand All @@ -19,4 +19,4 @@ test:
@pytest tests/

format:
@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black
@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black
2 changes: 1 addition & 1 deletion sgl-kernel/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "sgl-kernel"
version = "0.0.2.post4"
version = "0.0.2.post5"
description = "Kernel Library for SGLang"
readme = "README.md"
requires-python = ">=3.8"
Expand Down
26 changes: 25 additions & 1 deletion sgl-kernel/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,31 @@ def update_wheel_platform_tag():
},
libraries=["c10", "torch", "torch_python"],
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
)
),
CUDAExtension(
"sgl_kernel.ops.custom_reduce_cuda",
[
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
"src/sgl-kernel/csrc/trt_reduce.cc",
],
extra_compile_args={
"nvcc": [
"-O3",
"-Xcompiler",
"-fPIC",
"-gencode=arch=compute_75,code=sm_75",
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_89,code=sm_89",
"-gencode=arch=compute_90,code=sm_90",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
],
"cxx": ["-O3"],
},
libraries=["c10", "torch", "torch_python"],
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
),
],
cmdclass={"build_ext": BuildExtension},
install_requires=["torch"],
Expand Down
9 changes: 7 additions & 2 deletions sgl-kernel/src/sgl-kernel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from .ops import warp_reduce
from .ops import custom_dispose, custom_reduce, init_custom_reduce, warp_reduce

__all__ = ["warp_reduce"]
__all__ = [
"warp_reduce",
"init_custom_reduce",
"custom_dispose",
"custom_reduce",
]
13 changes: 13 additions & 0 deletions sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#include <torch/extension.h>

// Integer type used for the handle values exchanged with Python.
// NOTE(review): presumably carries a native pointer as an integer; confirm
// against the definitions of the functions declared below.
typedef int64_t fptr_t;

// Forward declarations only: the definitions are not in this file and must be
// supplied by another translation unit linked into the same extension.
fptr_t init_custom_ar(int64_t rank_id,
                      int64_t world_size,
                      const std::vector<fptr_t>& buffers,
                      const std::vector<fptr_t>& barrier_in,
                      const std::vector<fptr_t>& barrier_out);

void dispose(fptr_t handle);

void all_reduce(fptr_t handle, torch::Tensor& input, torch::Tensor& output);

// Python module definition: exposes the three functions above under their C++
// names, each with a short description string.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
  module.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
  module.def("dispose", &dispose, "dispose custom allreduce meta");
  module.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
}
Loading

0 comments on commit e04d3f2

Please sign in to comment.