Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sgl-kernel adapt tensorrt llm custom allreduce #2481

Merged
merged 9 commits into from
Dec 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 47 additions & 19 deletions sgl-kernel/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,47 +1,75 @@
cmake_minimum_required(VERSION 3.18)
project(sgl-kernel LANGUAGES CXX CUDA)

# Basic settings
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)

# Set CUDA architectures (keep in sync with the -gencode list in setup.py)
set(CMAKE_CUDA_ARCHITECTURES "75;80;86;89;90")
message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")

# FindPython3 replaces the deprecated FindPythonInterp/FindPythonLibs modules
# and provides the Python3::Python imported target linked below.
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)

# Locate the PyTorch CMake package through the active interpreter so the build
# links against the same torch installation the modules will be imported with.
execute_process(
  COMMAND ${Python3_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)"
  OUTPUT_VARIABLE TORCH_CMAKE_PATH
  OUTPUT_STRIP_TRAILING_WHITESPACE
)

message(STATUS "Python3_EXECUTABLE: ${Python3_EXECUTABLE}")
message(STATUS "TORCH_CMAKE_PATH: ${TORCH_CMAKE_PATH}")

list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PATH}")

find_package(Torch REQUIRED)

# Warp Reduce library
add_library(warp_reduce SHARED
  src/sgl-kernel/csrc/warp_reduce.cc
  src/sgl-kernel/csrc/warp_reduce_kernel.cu
)

target_include_directories(warp_reduce
  PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc
    ${CUDA_INCLUDE_DIRS}
    ${TORCH_INCLUDE_DIRS}
)

target_link_libraries(warp_reduce
  PRIVATE
    ${TORCH_LIBRARIES}
    Python3::Python
)

# TRT Reduce library (TensorRT-LLM style custom allreduce)
add_library(trt_reduce SHARED
  src/sgl-kernel/csrc/trt_reduce.cc
  src/sgl-kernel/csrc/trt_reduce_internal.cu
  src/sgl-kernel/csrc/trt_reduce_kernel.cu
)

target_include_directories(trt_reduce
  PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc
    ${CUDA_INCLUDE_DIRS}
    ${TORCH_INCLUDE_DIRS}
)

target_link_libraries(trt_reduce
  PRIVATE
    ${TORCH_LIBRARIES}
    Python3::Python
)

# Set common properties for both libraries.
# Empty PREFIX and a fixed ".so" SUFFIX make the outputs loadable as plain
# Python extension modules ("warp_reduce.so", not "libwarp_reduce.so").
foreach(target warp_reduce trt_reduce)
  set_target_properties(${target} PROPERTIES
    CUDA_SEPARABLE_COMPILATION ON
    POSITION_INDEPENDENT_CODE ON
    CUDA_RESOLVE_DEVICE_SYMBOLS ON
    PREFIX ""
    SUFFIX ".so"
  )
endforeach()
4 changes: 2 additions & 2 deletions sgl-kernel/Makefile
Original file line number Diff line number Diff line change
install:
	@pip install -e .

# $$(nproc) is escaped so the shell runs nproc; a plain $(nproc) would be
# expanded by make as an undefined make variable and leave MAX_JOBS empty.
build:
	@export MAX_JOBS=$$(nproc) && python3 setup.py bdist_wheel

clean:
	@rm -rf build dist *.egg-info

test:
	@pytest tests/

format:
	@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black
2 changes: 1 addition & 1 deletion sgl-kernel/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "sgl-kernel"
# post5: adds the tensorrt-llm custom allreduce kernel (single version key;
# the duplicate post4 line was a diff artifact and would be invalid TOML).
version = "0.0.2.post5"
description = "Kernel Library for SGLang"
readme = "README.md"
requires-python = ">=3.8"
Expand Down
26 changes: 25 additions & 1 deletion sgl-kernel/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,31 @@ def update_wheel_platform_tag():
},
libraries=["c10", "torch", "torch_python"],
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
)
),
# Extension for the TensorRT-LLM style custom allreduce kernels.
CUDAExtension(
    "sgl_kernel.ops.custom_reduce_cuda",
    [
        "src/sgl-kernel/csrc/trt_reduce_internal.cu",
        "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
        "src/sgl-kernel/csrc/trt_reduce.cc",
    ],
    extra_compile_args={
        "nvcc": [
            "-O3",
            "-Xcompiler",
            "-fPIC",
            "-gencode=arch=compute_75,code=sm_75",
            "-gencode=arch=compute_80,code=sm_80",
            # sm_86 added for consistency with CMakeLists.txt, which sets
            # CMAKE_CUDA_ARCHITECTURES "75;80;86;89;90".
            "-gencode=arch=compute_86,code=sm_86",
            "-gencode=arch=compute_89,code=sm_89",
            "-gencode=arch=compute_90,code=sm_90",
            # Re-enable half-precision operators that torch headers disable
            # by default in extension builds.
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF2_OPERATORS__",
        ],
        "cxx": ["-O3"],
    },
    libraries=["c10", "torch", "torch_python"],
    # Resolve libtorch at runtime relative to the installed module location.
    extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
),
],
cmdclass={"build_ext": BuildExtension},
install_requires=["torch"],
Expand Down
9 changes: 7 additions & 2 deletions sgl-kernel/src/sgl-kernel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from .ops import custom_dispose, custom_reduce, init_custom_reduce, warp_reduce

# Public API: the warp-reduce kernel plus the custom allreduce lifecycle
# (init_custom_reduce -> custom_reduce -> custom_dispose).
__all__ = [
    "warp_reduce",
    "init_custom_reduce",
    "custom_dispose",
    "custom_reduce",
]
13 changes: 13 additions & 0 deletions sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#include <torch/extension.h>

// Pointer-sized opaque handle passed across the Python/C++ boundary; it holds
// the allreduce context created by init_custom_ar.
using fptr_t = int64_t;
// Creates the custom allreduce state for this rank. buffers / barrier_in /
// barrier_out are pointer-sized integers — presumably device buffer addresses
// shared across ranks (TODO confirm against the Python caller). Returns the
// handle used by all_reduce/dispose.
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out);
// Releases the state previously returned by init_custom_ar.
void dispose(fptr_t _fa);
// Runs the custom all-reduce over `inp`, writing the result into `out`.
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);

// pybind11 module definition; TORCH_EXTENSION_NAME is supplied by the build.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
m.def("dispose", &dispose, "dispose custom allreduce meta");
m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
}
Loading
Loading