diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index c635b75c348..adb81fa2b0f 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -1,47 +1,75 @@ cmake_minimum_required(VERSION 3.18) project(sgl-kernel LANGUAGES CXX CUDA) +# Basic settings set(CMAKE_EXPORT_COMPILE_COMMANDS ON) - set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) - set(CMAKE_CUDA_STANDARD 17) set(CMAKE_CUDA_STANDARD_REQUIRED ON) -find_package(PythonInterp 3 REQUIRED) -find_package(PythonLibs 3 REQUIRED) +# Set CUDA architectures +set(CMAKE_CUDA_ARCHITECTURES "75;80;86;89;90") +message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") + +find_package(Python3 COMPONENTS Interpreter Development REQUIRED) +# Find PyTorch execute_process( - COMMAND ${PYTHON_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)" + COMMAND ${Python3_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)" OUTPUT_VARIABLE TORCH_CMAKE_PATH OUTPUT_STRIP_TRAILING_WHITESPACE ) - -message(STATUS "PYTHON_EXECUTABLE: ${PYTHON_EXECUTABLE}") -message(STATUS "TORCH_CMAKE_PATH: ${TORCH_CMAKE_PATH}") - list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PATH}") find_package(Torch REQUIRED) -include_directories(${PYTHON_INCLUDE_DIRS}) - +# Warp Reduce library add_library(warp_reduce SHARED src/sgl-kernel/csrc/warp_reduce.cc src/sgl-kernel/csrc/warp_reduce_kernel.cu ) -target_include_directories(warp_reduce PRIVATE - ${CUDA_INCLUDE_DIRS} - ${TORCH_INCLUDE_DIRS} +target_include_directories(warp_reduce + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc + ${CUDA_INCLUDE_DIRS} + ${TORCH_INCLUDE_DIRS} ) -target_link_libraries(warp_reduce PRIVATE - ${TORCH_LIBRARIES} - ${PYTHON_LIBRARIES} +target_link_libraries(warp_reduce + PRIVATE + ${TORCH_LIBRARIES} + Python3::Python ) -set_target_properties(warp_reduce PROPERTIES - CUDA_SEPARABLE_COMPILATION ON +# TRT Reduce library +add_library(trt_reduce SHARED + src/sgl-kernel/csrc/trt_reduce.cc + src/sgl-kernel/csrc/trt_reduce_internal.cu + src/sgl-kernel/csrc/trt_reduce_kernel.cu ) + +target_include_directories(trt_reduce + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc + ${CUDA_INCLUDE_DIRS} + ${TORCH_INCLUDE_DIRS} +) + +target_link_libraries(trt_reduce + PRIVATE + ${TORCH_LIBRARIES} + Python3::Python +) + +# Set common properties for both libraries +foreach(target warp_reduce trt_reduce) + set_target_properties(${target} PROPERTIES + CUDA_SEPARABLE_COMPILATION ON + POSITION_INDEPENDENT_CODE ON + CUDA_RESOLVE_DEVICE_SYMBOLS ON + PREFIX "" + SUFFIX ".so" + ) +endforeach() diff --git a/sgl-kernel/Makefile b/sgl-kernel/Makefile index ec86d684877..7a041b1ed40 100644 --- a/sgl-kernel/Makefile +++ b/sgl-kernel/Makefile @@ -10,7 +10,7 @@ install: @pip install -e . 
build: - @python3 setup.py bdist_wheel + @export MAX_JOBS=$(nproc) && python3 setup.py bdist_wheel clean: @rm -rf build dist *.egg-info @@ -19,4 +19,4 @@ test: @pytest tests/ format: - @find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black + @find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index 4fbdd9dae23..a9111119807 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sgl-kernel" -version = "0.0.2.post4" +version = "0.0.2.post5" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.8" diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 724c634bc25..d81e9da00e8 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -84,7 +84,31 @@ def update_wheel_platform_tag(): }, libraries=["c10", "torch", "torch_python"], extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"], - ) + ), + CUDAExtension( + "sgl_kernel.ops.custom_reduce_cuda", + [ + "src/sgl-kernel/csrc/trt_reduce_internal.cu", + "src/sgl-kernel/csrc/trt_reduce_kernel.cu", + "src/sgl-kernel/csrc/trt_reduce.cc", + ], + extra_compile_args={ + "nvcc": [ + "-O3", + "-Xcompiler", + "-fPIC", + "-gencode=arch=compute_75,code=sm_75", + "-gencode=arch=compute_80,code=sm_80", + "-gencode=arch=compute_89,code=sm_89", + "-gencode=arch=compute_90,code=sm_90", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + ], + "cxx": ["-O3"], + }, + libraries=["c10", "torch", "torch_python"], + extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"], + ), ], cmdclass={"build_ext": BuildExtension}, install_requires=["torch"], diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py index edf3921db79..9876d4e5ee2 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/src/sgl-kernel/__init__.py @@ -1,3 +1,8 @@ -from .ops import warp_reduce +from .ops import custom_dispose, custom_reduce, init_custom_reduce, warp_reduce -__all__ = ["warp_reduce"] +__all__ = [ + "warp_reduce", + "init_custom_reduce", + "custom_dispose", + "custom_reduce", +] diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc new file mode 100644 index 00000000000..4d8f732af3e --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc @@ -0,0 +1,13 @@ +#include + +using fptr_t = int64_t; +fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector& buffers, + const std::vector& barrier_in, const std::vector& barrier_out); +void dispose(fptr_t _fa); +void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)"); + m.def("dispose", &dispose, "dispose custom allreduce meta"); + m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)"); +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu new file mode 100644 index 00000000000..04393c8e716 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu @@ -0,0 +1,282 @@ +// reference: +// 
https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu
+/*
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <tuple>
+
+#include "trt_reduce_internal.cuh"
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+static inline __device__ void st_flag_release(uint32_t const& flag, uint32_t* flag_addr) {
+  asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr) {
+  uint32_t flag;
+  asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
+  return flag;
+}
+
+namespace trt_llm {
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Type converter that packs a data format into a 128-bit data type
+//
+using PackedFloat = union {
+  int4 packed;
+  float unpacked[4];
+};
+
+using PackedHalf = union {
+  int4 packed;
+  half2 unpacked[4];
+};
+
+template <typename T>
+struct PackedOn16Bytes {};
+
+template <>
+struct PackedOn16Bytes<float> {
+  using Type = PackedFloat;
+};
+
+template <>
+struct PackedOn16Bytes<half> {
+  using Type = PackedHalf;
+};
+
+#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
+using PackedBFloat16 = union {
+  int4 packed;
+  __nv_bfloat162 unpacked[4];
+};
+
+template <>
+struct PackedOn16Bytes<__nv_bfloat16> {
+  using Type = PackedBFloat16;
+};
+#endif
+
+// add two 128-bit values, element by element
+template <typename T>
+inline __device__ int4 add128b(T& a, T& b) {
+  T c;
+  c.unpacked[0] = a.unpacked[0] + b.unpacked[0];
+  c.unpacked[1] = a.unpacked[1] + b.unpacked[1];
+  c.unpacked[2] = a.unpacked[2] + b.unpacked[2];
+  c.unpacked[3] = a.unpacked[3] + b.unpacked[3];
+  return c.packed;
+}
+
+__inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
+                                             size_t const world_size, int const tidx, int const bidx) {
+  // After this function, at least one block in each GPU has reached the barrier
+  if (tidx < world_size) {
+    // we can think of signals as having the shape [world_size, world_size]
+    // Dimension 0 is the "listening" dimension, dimension 1 is the "emitting" dimension
+
+    // Block 0 broadcasts its flag (local_rank on the emitting dimension) to all receivers
+    size_t offset = (flag % 2) ? world_size : 0;
+
+    if (bidx == 0) {
+      st_flag_release(flag, signals[tidx] + offset + local_rank);
+    }
+
+    // All blocks check that the corresponding block 0 on the other GPUs has set the flag
+    // No deadlock because block #0 is always the first block started
+    uint32_t* peer_barrier_d = signals[local_rank] + offset + tidx;
+    while (ld_flag_acquire(peer_barrier_d) != flag) {
+    }
+  }
+
+  __syncthreads();
+}
+
+template <typename T, int RANKS_PER_NODE> /* COPY_INPUT = false, PUSH_MODE = false */
+static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
+  // Suppose that two GPUs participate in the AR exchange, and we start four blocks.
+  // The message is partitioned into chunks as detailed below:
+  //               message
+  //       |-------------------|
+  // GPU 0 | B0 | B1 | B2 | B3 |
+  // GPU 1 | B0 | B1 | B2 | B3 |
+  //
+  // Here is the step-by-step behavior of one block:
+  // 1. B0 copies the chunk it is responsible for, from local_input to the shareable buffer
+  // 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier)
+  // 3. B0 on GPU 0 pulls and sums the chunk from GPU 1, and writes the result to local_output
+  //
+  // With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2.
+  // We only need to know that the other GPU has arrived at the AR kernel; that means the data is ready.
+  //
+  // With PUSH_MODE, we consider that the shared buffer is of size:
+  // params.peer_comm_buffer_ptrs: [world_size, world_size, message_size]
+  //
+  // Here is the step-by-step behavior of one block:
+  // 1. B0 pushes the chunk it is responsible for into all other GPUs:
+  //    params.peer_comm_buffer_ptrs[:, local_gpu, B0 slice]
+  // 2. block sync so the pushed chunk is visible to the other GPUs
+  // 3. Reduce along the second dimension: params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]
+
+  int const bidx = blockIdx.x;
+  int const tidx = threadIdx.x;
+
+  // The number of elements packed into one 16-byte access for comms
+  static constexpr int NUM_ELTS = 16 / sizeof(T);
+
+  // Packed data type for comms
+  using PackedStruct = typename PackedOn16Bytes<T>::Type;
+
+  // The source pointers. Distributed round-robin for the different warps.
+  T const* buffers[RANKS_PER_NODE];
+
+  // Start and end offsets of the thread
+  size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS;
+  size_t chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
+#pragma unroll
+  for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
+    int rank = (params.local_rank + ii) % RANKS_PER_NODE;
+    buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
+  }
+
+  multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
+
+  // Each block accumulates the values from the different GPUs on the same node.
+  for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
+    // Iterate over the different ranks/devices on the node to load the values.
+    PackedStruct vals[RANKS_PER_NODE];
+#pragma unroll
+    for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
+      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][iter_offset]);
+    }
+
+    // Sum the values from the different ranks.
+    PackedStruct sums;
+    sums.packed = {0, 0, 0, 0};
+#pragma unroll
+    for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
+      // Always reduce from rank 0 to ensure a stable reduce order.
+      int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE;
+      sums.packed = add128b(sums, vals[ii]);
+    }
+
+    // Store to the destination buffer.
+    *reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline int divUp(int a, int b) {
+  return (a + b - 1) / b;
+}
+
+inline int roundUp(int a, int n) {
+  return divUp(a, n) * n;
+}
+
+std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReduceParams& params, size_t elts_per_thread) {
+  int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE;
+  switch (algo) {
+    case AllReduceStrategyType::ONESHOT: {
+      assert(params.elts_total % elts_per_thread == 0);
+      size_t const total_threads = roundUp(params.elts_total / elts_per_thread, WARP_SIZE);
+      threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
+      blocks_per_grid = std::min(static_cast<int>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
+      params.elts_per_block = roundUp(divUp(params.elts_total, blocks_per_grid), elts_per_thread);
+      params.elts_per_rank = params.elts_total;
+      break;
+    }
+    default:
+      assert(false && "Algorithm not supported here.");
+  }
+
+  return std::make_tuple(blocks_per_grid, threads_per_block);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename T, int RANKS_PER_NODE>
+void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block,
+                       cudaStream_t stream) {
+  oneShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
+}
+
+template <typename T>
+void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
+  void* buffer = reinterpret_cast<void*>(param.peer_comm_buffer_ptrs[param.rank]);
+  void* local_inp_buffer = param.local_input_buffer_ptr;
+  CHECK_CUDA_SUCCESS(
+      cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size, cudaMemcpyDeviceToDevice, stream));
+
+  assert(strat == AllReduceStrategyType::ONESHOT && "Custom allreduce only supports oneshot");
+  CHECK_CUDA_SUCCESS(cudaGetLastError());
+
+  size_t elts_per_thread = 16 / sizeof(T);
+  auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
+  switch (param.ranks_per_node) {
+    case 2:
+      dispatchARKernels<T, 2>(strat, param, blocks_per_grid, threads_per_block, stream);
+      break;
+    case 4:
+      dispatchARKernels<T, 4>(strat, param, blocks_per_grid, threads_per_block, stream);
+      break;
+    case 6:
+      dispatchARKernels<T, 6>(strat, param, blocks_per_grid, threads_per_block, stream);
+      break;
+    case 8:
+      dispatchARKernels<T, 8>(strat, param, blocks_per_grid, threads_per_block, stream);
+      break;
+    default:
+      break;
+  }
+  CHECK_CUDA_SUCCESS(cudaGetLastError());
+}
+
+void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat,
+                        cudaStream_t stream) {
+  if (params.elts_total == 0) {
+    return;
+  }
+
+  switch (data_type) {
+    case at::ScalarType::Float:
+      invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);
+      break;
+    case at::ScalarType::Half:
+      invokeOneOrTwoShotAllReduceKernel<half>(params, strat, stream);
+      break;
+#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
+    case at::ScalarType::BFloat16:
+      invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream);
+      break;
+#endif
+    default:
+      assert(false && "Unsupported data type");
+  }
+}
+}  // namespace trt_llm
diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
new file mode 100644
index 00000000000..01652a22a86
--- /dev/null
+++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
@@ -0,0 +1,91 @@
+// reference:
+// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp
+/*
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include <cassert>
+#include <cstdint>
+#include <torch/extension.h>
+
+#include "utils.hpp"
+
+namespace trt_llm {
+constexpr size_t WARP_SIZE = 32;
+constexpr size_t MAX_ALL_REDUCE_BLOCKS = 24;
+constexpr size_t MAX_RANKS_PER_NODE = 8;
+constexpr size_t DEFAULT_BLOCK_SIZE = 1024;
+
+enum class AllReduceStrategyType : int8_t {
+  RING = 0,
+  ONESHOT = 1,
+  TWOSHOT = 2,
+  AUTO = 3,
+};
+
+struct AllReduceParams {
+  size_t elts_size;
+  size_t elts_total;
+  size_t elts_per_rank;
+  size_t elts_per_block;
+  size_t rank_offset;
+  size_t ranks_per_node, rank, local_rank;
+  uint32_t barrier_flag;
+  uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE];
+  uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
+  void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE];
+  void* local_input_buffer_ptr;
+  void* local_output_buffer_ptr;
+};
+
+inline size_t GetMaxRequiredWorkspaceSize(int world_size) {
+  if (world_size <= 2) {
+    return 16 * 1000 * 1000;
+  }
+  return 8 * 1000 * 1000;
+}
+
+inline AllReduceStrategyType SelectImplementation(size_t message_size, int world_size) {
+  const size_t maxWorkspaceSize = GetMaxRequiredWorkspaceSize(world_size);
+
+  if (message_size > maxWorkspaceSize) {
+    assert(false && "Custom allreduce does not support the ring strategy currently");
+    return AllReduceStrategyType::RING;
+  }
+
+  if (world_size <= 2) {
+    return AllReduceStrategyType::ONESHOT;
+  }
+
+  if (world_size <= 4) {
+    if (message_size < 1 * 1000 * 1000) {
+      return AllReduceStrategyType::ONESHOT;
+    }
+    assert(false && "Custom allreduce does not support the twoshot strategy currently");
+    return AllReduceStrategyType::TWOSHOT;
+  }
+
+  if (message_size < 500 * 1000) {
+    return AllReduceStrategyType::ONESHOT;
+  }
+  assert(false && "Custom allreduce does not support the twoshot strategy currently");
+  return AllReduceStrategyType::TWOSHOT;
+}
+
+void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat,
+                        cudaStream_t stream);
+
+}  // namespace trt_llm
diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu
new file mode 100644
index 00000000000..2a2dcebc89e
--- /dev/null
+++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu
@@ -0,0 +1,102 @@
+// reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.h
+
+#include <torch/extension.h>
+
+#include <c10/cuda/CUDAStream.h>
+#include <cassert>
+#include <cstdint>
+#include <vector>
+
+#include "trt_reduce_internal.cuh"
+
+using namespace trt_llm;
+
+using fptr_t = int64_t;
+
+class AllReduceMeta {
+ public:
+  AllReduceMeta(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
+                const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out) {
+    this->rank_id = (int)rank_id;
+    this->world_size = (int)world_size;
+    this->buffers = buffers;
+    this->barrier_in = barrier_in;
+    this->barrier_out = barrier_out;
+  }
+
+ public:
+  int world_size;
+  int rank_id;
+  std::vector<fptr_t> buffers;
+  std::vector<fptr_t> barrier_in;
+  std::vector<fptr_t> barrier_out;
+  int barrier_flag = 1;
+};
+
+// Get the number of bits for a given data type.
+inline int get_bits(at::ScalarType dtype) {
+  switch (dtype) {
+    case at::ScalarType::Float:
+      return 32;
+    case at::ScalarType::Half:
+    case at::ScalarType::BFloat16:
+      return 16;
+    default:
+      assert(false && "Unsupported data type");
+  }
+}
+
+// Check if the customized all-reduce kernels can be applied.
+inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype) {
+  // The customized all-reduce kernel has the following requirement(s).
+  return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0;
+}
+
+fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
+                      const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out) {
+  auto m = new AllReduceMeta(rank_id, world_size, buffers, barrier_in, barrier_out);
+  return (fptr_t)m;
+}
+
+void dispose(fptr_t _fa) {
+  auto fa = reinterpret_cast<AllReduceMeta*>(_fa);
+  delete fa;
+}
+
+void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
+  AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
+  auto stream = c10::cuda::getCurrentCUDAStream().stream();
+  auto num_elements = inp.numel();
+  auto dtype = inp.scalar_type();
+  AllReduceStrategyType strategy = SelectImplementation(num_elements * ((get_bits(dtype) + 7) / 8), m->world_size);
+
+  // These conditions should be guaranteed by the Python caller.
+  assert(strategy == AllReduceStrategyType::ONESHOT);
+  assert(CanApplyCustomAllReduce(num_elements, dtype));
+
+  // Initialize the all-reduce kernel arguments.
+  int world_size = m->world_size;
+
+  AllReduceParams params;
+  params.ranks_per_node = world_size;
+  params.rank = m->rank_id;
+  params.local_rank = m->rank_id;
+  params.local_input_buffer_ptr = inp.data_ptr();
+  params.local_output_buffer_ptr = out.data_ptr();
+  params.elts_total = inp.numel();
+  params.elts_size = inp.element_size();
+  params.barrier_flag = ++(m->barrier_flag);
+
+  for (int i = 0; i < world_size; ++i) {
+    params.peer_comm_buffer_ptrs[i] = reinterpret_cast<void*>(m->buffers[i]);
+  }
+  for (int i = 0; i < world_size; ++i) {
+    params.peer_barrier_ptrs_in[i] = reinterpret_cast<uint32_t*>(m->barrier_in[i]);
+  }
+  for (int i = 0; i < world_size; ++i) {
+    params.peer_barrier_ptrs_out[i] = reinterpret_cast<uint32_t*>(m->barrier_out[i]);
+  }
+
+  auto data_type = out.scalar_type();
+  trtCustomAllReduce(params, data_type, strategy, stream);
+}
diff --git a/sgl-kernel/src/sgl-kernel/csrc/utils.hpp b/sgl-kernel/src/sgl-kernel/csrc/utils.hpp
new file mode 100644
index 00000000000..bbdc6311be9
--- /dev/null
+++ b/sgl-kernel/src/sgl-kernel/csrc/utils.hpp
@@ -0,0 +1,36 @@
+#pragma once
+#include <cuda_runtime.h>
+
+#include <sstream>
+
+struct cuda_error : public std::runtime_error {
+  /**
+   * @brief Constructs a `cuda_error` object with the given `message`.
+   *
+   * @param message The error char array used to construct `cuda_error`
+   */
+  cuda_error(const char* message) : std::runtime_error(message) {}
+  /**
+   * @brief Constructs a `cuda_error` object with the given `message` string.
+ * + * @param message The `std::string` used to construct `cuda_error` + */ + cuda_error(std::string const& message) : cuda_error{message.c_str()} {} +}; + +#define CHECK_CUDA_SUCCESS(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + std::stringstream _message; \ + auto s = cudaGetErrorString(e); \ + _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \ + throw cuda_error(_message.str()); \ + } \ + } while (0) + +#define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_CUDA_INPUT(x) \ + CHECK_IS_CUDA(x); \ + CHECK_IS_CONTIGUOUS(x) diff --git a/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc b/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc index efc2f0cd951..379b4cc15bf 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc +++ b/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc @@ -1,15 +1,11 @@ #include -torch::Tensor warp_reduce_cuda(torch::Tensor input); +#include "utils.hpp" -#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) +torch::Tensor warp_reduce_cuda(torch::Tensor input); torch::Tensor warp_reduce(torch::Tensor input) { - CHECK_INPUT(input); + CHECK_CUDA_INPUT(input); return warp_reduce_cuda(input); } diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py index 21870032e5a..7a3ceb2bd53 100644 --- a/sgl-kernel/src/sgl-kernel/ops/__init__.py +++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py @@ -1,5 +1,20 @@ +from .custom_reduce_cuda import all_reduce as _all_reduce +from .custom_reduce_cuda import dispose as _dispose +from .custom_reduce_cuda import init_custom_ar as _init_custom_ar from .warp_reduce_cuda import reduce as _reduce def warp_reduce(input_tensor): return _reduce(input_tensor) + + +def init_custom_reduce(rank_id, num_devices, buffers, barrier_in, barrier_out): + return _init_custom_ar(rank_id, num_devices, buffers, barrier_in, barrier_out) + + +def custom_dispose(fa): + _dispose(fa) + + +def custom_reduce(fa, inp, out): + _all_reduce(fa, inp, out) diff --git a/sgl-kernel/tests/test_trt_reduce.py b/sgl-kernel/tests/test_trt_reduce.py new file mode 100644 index 00000000000..09a11908018 --- /dev/null +++ b/sgl-kernel/tests/test_trt_reduce.py @@ -0,0 +1,248 @@ +import ctypes +import logging +import os +import random +import socket +import time +import unittest +from typing import Any, List, Optional, Union + +import ray +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from vllm import _custom_ops as vllm_ops + +from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary + +logger = logging.getLogger(__name__) + + +def get_open_port() -> int: + # try ipv4 + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + except OSError: + # try ipv6 + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def multi_process_parallel( + world_size: int, + cls: Any, + test_target: Any, +) -> None: + + # Using ray helps debugging the error when it failed + # as compared to multiprocessing. 
+ # NOTE: We need to set working_dir for distributed tests, + # otherwise we may get import errors on ray workers + ray.init(log_to_driver=True) + + distributed_init_port = get_open_port() + refs = [] + for rank in range(world_size): + refs.append(test_target.remote(cls, world_size, rank, distributed_init_port)) + ray.get(refs) + + ray.shutdown() + + +class TestCustomAllReduce(unittest.TestCase): + @classmethod + def setUpClass(cls): + random.seed(42) + cls.test_sizes = { + 2: [512, 4096, 32768, 262144, 2097152], + 4: [512, 4096, 32768, 131072], + 6: [512, 4096, 32768, 65536], + 8: [512, 4096, 32768, 65536], + } + cls.world_sizes = [2, 4, 6, 8] + + @staticmethod + def create_shared_buffer( + size_in_bytes: int, group: Optional[ProcessGroup] = None + ) -> List[int]: + """ + Creates a shared buffer and returns a list of pointers + representing the buffer on all processes in the group. + """ + lib = CudaRTLibrary() + pointer = lib.cudaMalloc(size_in_bytes) + handle = lib.cudaIpcGetMemHandle(pointer) + world_size = dist.get_world_size(group=group) + rank = dist.get_rank(group=group) + handles = [None] * world_size + dist.all_gather_object(handles, handle, group=group) + + pointers: List[int] = [] + for i, h in enumerate(handles): + if i == rank: + pointers.append(pointer.value) # type: ignore + else: + pointers.append(lib.cudaIpcOpenMemHandle(h).value) # type: ignore + + return pointers + + @staticmethod + def free_shared_buffer( + pointers: List[int], group: Optional[ProcessGroup] = None + ) -> None: + rank = dist.get_rank(group=group) + lib = CudaRTLibrary() + lib.cudaFree(ctypes.c_void_p(pointers[rank])) + + def test_correctness(self): + for world_size in self.world_sizes: + if world_size > torch.cuda.device_count(): + continue + multi_process_parallel(world_size, self, self.correctness) + + def test_performance(self): + for world_size in self.world_sizes: + if world_size > torch.cuda.device_count(): + continue + multi_process_parallel(world_size, self, self.performance) + + def init_custom_allreduce(self, rank, world_size, group): + import sgl_kernel + + buffer_max_size = 8 * 1024 * 1024 + barrier_max_size = 8 * (24 + 2) * 8 + + self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group) + self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group) + self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group) + + self.custom_ptr = sgl_kernel.ops.init_custom_reduce( + rank, + world_size, + self.buffer_ptrs, + self.barrier_in_ptrs, + self.barrier_out_ptrs, + ) + + def custom_allreduce(self, inp, out): + import sgl_kernel + + sgl_kernel.ops.custom_reduce(self.custom_ptr, inp, out) + + def free_custom_allreduce(self, group): + import sgl_kernel + + self.free_shared_buffer(self.buffer_ptrs, group) + self.free_shared_buffer(self.barrier_in_ptrs, group) + self.free_shared_buffer(self.barrier_out_ptrs, group) + sgl_kernel.ops.custom_dispose(self.custom_ptr) + + def init_vllm_allreduce(self, rank, group): + self.vllm_rank = rank + self.vllm_max_size = 8 * 1024 * 1024 + self.vllm_meta_ptrs = self.create_shared_buffer( + vllm_ops.meta_size() + self.vllm_max_size, group=group + ) + self.vllm_buffer_ptrs = self.create_shared_buffer( + self.vllm_max_size, group=group + ) + self.vllm_rank_data = torch.empty( + 8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}") + ) + self.vllm_ptr = vllm_ops.init_custom_ar( + self.vllm_meta_ptrs, self.vllm_rank_data, rank, True + ) + vllm_ops.register_buffer(self.vllm_ptr, self.vllm_buffer_ptrs) 
+
+    def vllm_allreduce(self, inp, out):
+        vllm_ops.all_reduce(
+            self.vllm_ptr,
+            inp,
+            out,
+            self.vllm_buffer_ptrs[self.vllm_rank],
+            self.vllm_max_size,
+        )
+
+    def free_vllm_allreduce(self, group):
+        vllm_ops.dispose(self.vllm_ptr)
+        self.free_shared_buffer(self.vllm_meta_ptrs, group)
+        self.free_shared_buffer(self.vllm_buffer_ptrs, group)
+
+    @staticmethod
+    def init_distributed_env(world_size, rank, distributed_init_port):
+        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
+        device = torch.device(f"cuda:{rank}")
+        torch.cuda.set_device(device)
+        ranks = [i for i in range(world_size)]
+        distributed_init_method = f"tcp://localhost:{distributed_init_port}"
+        dist.init_process_group(
+            backend="nccl",
+            init_method=distributed_init_method,
+            rank=rank,
+            world_size=world_size,
+        )
+        group = torch.distributed.new_group(ranks, backend="gloo")
+        return group
+
+    # compare result with torch.distributed
+    @ray.remote(num_gpus=1, max_calls=1)
+    def correctness(self, world_size, rank, distributed_init_port):
+        group = self.init_distributed_env(world_size, rank, distributed_init_port)
+
+        self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
+
+        test_loop = 10
+        for sz in self.test_sizes[world_size]:
+            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
+                for _ in range(test_loop):
+                    inp1 = torch.randint(
+                        1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
+                    )
+                    out1 = torch.empty_like(inp1)
+                    self.custom_allreduce(inp1, out1)
+
+                    dist.all_reduce(inp1, group=group)
+                    torch.testing.assert_close(out1, inp1)
+
+        self.free_custom_allreduce(group)
+
+    # compare performance with vllm
+    @ray.remote(num_gpus=1, max_calls=1)
+    def performance(self, world_size, rank, distributed_init_port):
+        group = self.init_distributed_env(world_size, rank, distributed_init_port)
+
+        self.init_vllm_allreduce(rank, group)
+        self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
+
+        for sz in self.test_sizes[world_size]:
+            inp1 = torch.randint(
+                1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device()
+            )
+            out1 = torch.empty_like(inp1)
+            test_loop = 5000
+            start = time.time()
+            for _ in range(test_loop):
+                self.custom_allreduce(inp1, out1)
+            elapse_custom = time.time() - start
+
+            start = time.time()
+            for _ in range(test_loop):
+                self.vllm_allreduce(inp1, out1)
+            elapse_vllm = time.time() - start
+
+            if rank == 0:
+                logger.warning(
+                    f"test_size = {sz}, world_size = {world_size}, "
+                    f"vllm time = {elapse_vllm * 1000 / test_loop:.4f}ms, "
+                    f"custom time = {elapse_custom * 1000 / test_loop:.4f}ms"
+                )
+
+        self.free_custom_allreduce(group)
+        self.free_vllm_allreduce(group)
+
+
+if __name__ == "__main__":
+    unittest.main()
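
Reviewer note: below is a minimal sketch of how the new Python-facing ops are expected to be wired together, mirroring init_custom_allreduce / custom_allreduce / free_custom_allreduce in the test above. It assumes one process per GPU with torch.distributed already initialized, and an ipc_buffers helper equivalent to TestCustomAllReduce.create_shared_buffer (a hypothetical name, not part of this PR); the 8 MB data buffer and the barrier sizing are copied from the test rather than documented elsewhere.

# Hypothetical wiring of the custom all-reduce ops added by this PR.
# Assumes: one process per GPU, an initialized process group, and
# ipc_buffers(size_in_bytes, group) -> List[int] that shares cudaMalloc'd
# pointers across ranks via CUDA IPC (see create_shared_buffer in the test).
import torch
import torch.distributed as dist

import sgl_kernel


def make_custom_allreduce(group, ipc_buffers):
    rank = dist.get_rank(group=group)
    world_size = dist.get_world_size(group=group)

    buffer_max_size = 8 * 1024 * 1024    # data buffer size used by the test
    barrier_max_size = 8 * (24 + 2) * 8  # barrier buffer sizing used by the test

    buffer_ptrs = ipc_buffers(buffer_max_size, group)
    barrier_in_ptrs = ipc_buffers(barrier_max_size, group)
    barrier_out_ptrs = ipc_buffers(barrier_max_size, group)

    # Returns an opaque handle (pointer to AllReduceMeta on the C++ side).
    handle = sgl_kernel.ops.init_custom_reduce(
        rank, world_size, buffer_ptrs, barrier_in_ptrs, barrier_out_ptrs
    )

    def all_reduce(inp: torch.Tensor, out: torch.Tensor) -> None:
        # Only the one-shot strategy is implemented: the caller must keep the
        # message within buffer_max_size and satisfy CanApplyCustomAllReduce
        # (the C++ side only asserts these conditions).
        sgl_kernel.ops.custom_reduce(handle, inp, out)

    def dispose() -> None:
        sgl_kernel.ops.custom_dispose(handle)

    return all_reduce, dispose

The caller is expected to fall back to NCCL for messages that exceed the workspace limits in SelectImplementation, since the ring and two-shot paths are not implemented in this kernel.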