From 1f18dd3154e4d4e7087f73d5462613862f8a1c5e Mon Sep 17 00:00:00 2001
From: Zhangyi <1109276519@qq.com>
Date: Sat, 14 Dec 2024 23:34:48 +0800
Subject: [PATCH 1/8] sgl-kernel support tensorrt llm custom allreduce

---
 sgl-kernel/pyproject.toml                     |   2 +-
 sgl-kernel/setup.py                           |  26 +-
 sgl-kernel/src/sgl-kernel/__init__.py         |   9 +-
 sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc  |  15 +
 .../sgl-kernel/csrc/trt_reduce_internal.cu    | 326 ++++++++++++++++++
 .../sgl-kernel/csrc/trt_reduce_internal.cuh   | 111 ++++++
 .../src/sgl-kernel/csrc/trt_reduce_kernel.cu  | 107 ++++++
 sgl-kernel/src/sgl-kernel/ops/__init__.py     |  15 +
 8 files changed, 607 insertions(+), 4 deletions(-)
 create mode 100644 sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc
 create mode 100644 sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu
 create mode 100644 sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
 create mode 100644 sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu

diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml
index 4fbdd9dae23..ae20bcd8baa 100644
--- a/sgl-kernel/pyproject.toml
+++ b/sgl-kernel/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sgl-kernel"
-version = "0.0.2.post4"
+version = "0.0.3.post4"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.8"
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index 724c634bc25..d81e9da00e8 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -84,7 +84,31 @@ def update_wheel_platform_tag():
             },
             libraries=["c10", "torch", "torch_python"],
             extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
-        )
+        ),
+        CUDAExtension(
+            "sgl_kernel.ops.custom_reduce_cuda",
+            [
+                "src/sgl-kernel/csrc/trt_reduce_internal.cu",
+                "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
+                "src/sgl-kernel/csrc/trt_reduce.cc",
+            ],
+            extra_compile_args={
+                "nvcc": [
+                    "-O3",
+                    "-Xcompiler",
+                    "-fPIC",
+                    "-gencode=arch=compute_75,code=sm_75",
+                    "-gencode=arch=compute_80,code=sm_80",
+                    "-gencode=arch=compute_89,code=sm_89",
+                    "-gencode=arch=compute_90,code=sm_90",
+                    "-U__CUDA_NO_HALF_OPERATORS__",
+                    "-U__CUDA_NO_HALF2_OPERATORS__",
+                ],
+                "cxx": ["-O3"],
+            },
+            libraries=["c10", "torch", "torch_python"],
+            extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
+        ),
     ],
     cmdclass={"build_ext": BuildExtension},
     install_requires=["torch"],
diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py
index edf3921db79..9876d4e5ee2 100644
--- a/sgl-kernel/src/sgl-kernel/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/__init__.py
@@ -1,3 +1,8 @@
-from .ops import warp_reduce
+from .ops import custom_dispose, custom_reduce, init_custom_reduce, warp_reduce
 
-__all__ = ["warp_reduce"]
+__all__ = [
+    "warp_reduce",
+    "init_custom_reduce",
+    "custom_dispose",
+    "custom_reduce",
+]
diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc
new file mode 100644
index 00000000000..8387a2e5ad7
--- /dev/null
+++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc
@@ -0,0 +1,15 @@
+#include <torch/extension.h>
+
+using fptr_t = int64_t;
+fptr_t init_custom_ar(int64_t rank_id, int64_t world_size,
+                      const std::vector<fptr_t>& buffers,
+                      const std::vector<fptr_t>& barrier_in,
+                      const std::vector<fptr_t>& barrier_out);
+void dispose(fptr_t _fa);
+void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
+  m.def("dispose", &dispose, "dispose custom allreduce meta");
+
m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)"); +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu new file mode 100644 index 00000000000..cedd1c5268b --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu @@ -0,0 +1,326 @@ +// reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "trt_reduce_internal.cuh" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ void st_flag_release(uint32_t const& flag, uint32_t* flag_addr) +{ +#if __CUDA_ARCH__ >= 700 + asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); +#else + __threadfence_system(); + asm volatile("st.global.volatile.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr) +{ + uint32_t flag; +#if __CUDA_ARCH__ >= 700 + asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr)); +#else + asm volatile("ld.global.volatile.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr)); +#endif + return flag; +} + +namespace trt_llm { +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Type Converter that packs data format to 128 bits data type +// +using PackedFloat = union { + int4 packed; + float unpacked[4]; +}; + +using PackedHalf = union { + int4 packed; + half2 unpacked[4]; +}; + +template +struct PackedOn16Bytes {}; + +template <> +struct PackedOn16Bytes { + using Type = PackedFloat; +}; + +template <> +struct PackedOn16Bytes { + using Type = PackedHalf; +}; + +#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) +using PackedBFloat16 = union { + int4 packed; + __nv_bfloat162 unpacked[4]; +}; + +template <> +struct PackedOn16Bytes<__nv_bfloat16> { + using Type = PackedBFloat16; +}; +#endif + +// add two 128b data +template +inline __device__ int4 add128b(T &a, T &b) { + T c; + c.unpacked[0] = a.unpacked[0] + b.unpacked[0]; + c.unpacked[1] = a.unpacked[1] + b.unpacked[1]; + c.unpacked[2] = a.unpacked[2] + b.unpacked[2]; + c.unpacked[3] = a.unpacked[3] + b.unpacked[3]; + return c.packed; +} + +__inline__ __device__ void multi_gpu_barrier(uint32_t** signals, + uint32_t const flag, + size_t const local_rank, + size_t const world_size, + int const tidx, + int const bidx) +{ + // After this function, at least one block in each GPU has reached the barrier + if (tidx < world_size) + { + // we can think of signals having the shape [world_size, world_size] + // Dimension 0 is the "listening" dimension, dimension 1 is 
"emitting" dimension + + // Block 0 broadcasts its flag (local_rank on emitting dimension) to all receivers + size_t offset = (flag % 2) ? world_size : 0; + + if (bidx == 0) + { + st_flag_release(flag, signals[tidx] + offset + local_rank); + } + + // All blocks check that corresponding block 0 on other GPUs have set the flag + // No deadlock because block #0 is always the first block started + uint32_t* peer_barrier_d = signals[local_rank] + offset + tidx; + while (ld_flag_acquire(peer_barrier_d) != flag) + { + } + } + + __syncthreads(); +} + + +template /* COPY_INPUT = false, PUSH_MODE = false */ +static __global__ void oneShotAllReduceKernel(AllReduceParams params) { + // Suppose that two GPUs participate in the AR exchange, and we start four blocks. + // The message is partitioned into chunks as detailed below: + // message + // |-------------------| + // GPU 0 | B0 | B1 | B2 | B3 | + // GPU 1 | B0 | B1 | B2 | B3 | + // + // Here the step-by-step behavior of one block: + // 1. B0 copies the chunk it is responsible for, from local_input to shareable buffer + // 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier) + // 3. B0 on GPU 0 pull and sum the chunk from GPU 1, writes the result to local_output + // + // With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2. + // We only to know if the other GPU as arrived at the AR kernel, that would mean that data is ready + // + // With PUSH_MODE, we consider that the shared buffer is of size: + // params.peer_comm_buffer_ptrs: [world_size, world_size, message_size] + // + // Here the step-by-step behavior of one block: + // 1. B0 push the chunk is it responsible for into all other GPUs: + // params.peer_comm_buffer_ptrs[:, local_gpu, B0 slice] + // 2. block sync so the block is shared by other GPUs + // 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice] + + int const bidx = blockIdx.x; + int const tidx = threadIdx.x; + + // The number of elements packed into one for comms + static constexpr int NUM_ELTS = 16 / sizeof(T); + + // Packed data type for comms + using PackedStruct = typename PackedOn16Bytes::Type; + + // The source pointers. Distributed round-robin for the different warps. + T const *buffers[RANKS_PER_NODE]; + + // Start and end offsets of the thread + size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS; + size_t chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank); +#pragma unroll + for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { + int rank = (params.local_rank + ii) % RANKS_PER_NODE; + buffers[ii] = reinterpret_cast(params.peer_comm_buffer_ptrs[rank]); + } + + multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx); + + // Each block accumulates the values from the different GPUs on the same node. + for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) { + // Iterate over the different ranks/devices on the node to load the values. + PackedStruct vals[RANKS_PER_NODE]; +#pragma unroll + for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { + vals[ii].packed = *reinterpret_cast(&buffers[ii][iter_offset]); + } + + // Sum the values from the different ranks. + PackedStruct sums; + sums.packed = {0, 0, 0, 0}; +#pragma unroll + for (int rank = 0; rank < RANKS_PER_NODE; ++rank) { + // Always reduce from rank 0 to ensure stable reduce order. 
+ int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE; + sums.packed = add128b(sums, vals[ii]); + } + + // Store to the destination buffer. + *reinterpret_cast(&reinterpret_cast( + params.local_output_buffer_ptr)[iter_offset]) = sums.packed; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline int divUp(int a, int b) { + return (a + b - 1) / b; +} + +inline int roundUp(int a, int n) { + return divUp(a, n) * n; +} + +std::tuple kernelLaunchConfig(AllReduceStrategyType algo, + AllReduceParams ¶ms, + size_t elts_per_thread) { + int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE; + switch (algo) { + case AllReduceStrategyType::ONESHOT: { + assert(params.elts_total % elts_per_thread == 0); + size_t const total_threads = roundUp(params.elts_total / elts_per_thread, WARP_SIZE); + threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads); + blocks_per_grid = std::min(static_cast(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block)); + params.elts_per_block = roundUp(divUp(params.elts_total, blocks_per_grid), elts_per_thread); + params.elts_per_rank = params.elts_total; + break; + } + default: + assert(false && "Algorithm not supported here."); + } + + return std::make_tuple(blocks_per_grid, threads_per_block); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void dispatchARKernels(AllReduceStrategyType algo, + AllReduceParams ¶m, + int blocks_per_grid, + int threads_per_block, + cudaStream_t stream) { + oneShotAllReduceKernel + <<>>(param); +} + +template +void invokeOneOrTwoShotAllReduceKernel(AllReduceParams ¶m, + AllReduceStrategyType strat, + cudaStream_t stream) { + + void* buffer = reinterpret_cast(param.peer_comm_buffer_ptrs[param.rank]); + void* local_inp_buffer = param.local_input_buffer_ptr; + CHECKCUDA(cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size, + cudaMemcpyDeviceToDevice, stream)); + + assert(strat == AllReduceStrategyType::ONESHOT && "Custom allreduce only support oneshot"); + auto last_error = cudaGetLastError(); + if (last_error != cudaSuccess) { + printf("cuda error: %s\n", cudaGetErrorString(last_error)); + assert(false && "Error before launching the kernel"); + } + + size_t elts_per_thread = 16 / sizeof(T); + auto [blocks_per_grid, threads_per_block] = + kernelLaunchConfig(strat, param, elts_per_thread); + switch (param.ranks_per_node) { + case 2: + dispatchARKernels( + strat, param, blocks_per_grid, threads_per_block, stream); + break; + case 4: + dispatchARKernels( + strat, param, blocks_per_grid, threads_per_block, stream); + break; + case 6: + dispatchARKernels( + strat, param, blocks_per_grid, threads_per_block, stream); + break; + case 8: + dispatchARKernels( + strat, param, blocks_per_grid, threads_per_block, stream); + break; + default: + break; + } + last_error = cudaGetLastError(); + if (last_error != cudaSuccess) { + printf("cuda error: %s\n", cudaGetErrorString(last_error)); + assert(false && "Error after launching the kernel"); + } +} + +void trtCustomAllReduce(AllReduceParams ¶ms, + at::ScalarType data_type, + AllReduceStrategyType strat, + cudaStream_t stream) { + if (params.elts_total == 0) { + return; + } + + switch (data_type) + { + case at::ScalarType::Float: + invokeOneOrTwoShotAllReduceKernel(params, strat, stream); + break; + case at::ScalarType::Half: + invokeOneOrTwoShotAllReduceKernel(params, strat, stream); + break; +#if (__CUDA_ARCH__ 
>= 800 || !defined(__CUDA_ARCH__)) + case at::ScalarType::BFloat16: + invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream); + break; +#endif + default: + assert(false && "Unsupported data type"); + } +} +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh new file mode 100644 index 00000000000..5c937143b0f --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh @@ -0,0 +1,111 @@ +// reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include + +#define FatalError(s) \ + do { \ + std::stringstream _where, _message; \ + _where << __FILE__ << ':' << __LINE__; \ + _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \ + std::cerr << _message.str() << "\nAborting...\n"; \ + assert(false); \ + exit(1); \ + } while (0) + +#define CHECKCUDA(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ + cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +namespace trt_llm { +constexpr size_t WARP_SIZE = 32; +constexpr size_t MAX_ALL_REDUCE_BLOCKS = 24; +constexpr size_t MAX_RANKS_PER_NODE = 8; +constexpr size_t DEFAULT_BLOCK_SIZE = 1024; + +enum class AllReduceStrategyType : int8_t { + RING = 0, + ONESHOT = 1, + TWOSHOT = 2, + AUTO = 3, +}; + +struct AllReduceParams { + size_t elts_size; + size_t elts_total; + size_t elts_per_rank; + size_t elts_per_block; + size_t rank_offset; + size_t ranks_per_node, rank, local_rank; + uint32_t barrier_flag; + uint32_t *peer_barrier_ptrs_in[MAX_RANKS_PER_NODE]; + uint32_t *peer_barrier_ptrs_out[MAX_RANKS_PER_NODE]; + void *peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE]; + void *local_input_buffer_ptr; + void *local_output_buffer_ptr; +}; + +inline size_t GetMaxRequiredWorkspaceSize(int world_size) { + if (world_size <= 2) { + return 16 * 1000 * 1000; + } + return 8 * 1000 * 1000; +} + +inline AllReduceStrategyType SelectImplementation(size_t message_size, + int world_size) { + const size_t maxWorkspaceSize = GetMaxRequiredWorkspaceSize(world_size); + + if (message_size > maxWorkspaceSize) { + assert(false && "Custom allreduce do not ring currently"); + return AllReduceStrategyType::RING; + } + + if (world_size <= 2) { + return AllReduceStrategyType::ONESHOT; + } + + if (world_size <= 4) { + if (message_size < 1 * 1000 * 1000) { + return AllReduceStrategyType::ONESHOT; + } + assert(false && "Custom allreduce do not twoshot currently"); + return AllReduceStrategyType::TWOSHOT; + } + + if (message_size < 500 * 1000) { + return AllReduceStrategyType::ONESHOT; + } + assert(false && "Custom allreduce do not twoshot currently"); + return AllReduceStrategyType::TWOSHOT; +} + +void trtCustomAllReduce(AllReduceParams 
¶ms, + at::ScalarType data_type, + AllReduceStrategyType strat, + cudaStream_t stream); + +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu new file mode 100644 index 00000000000..79ed196f023 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu @@ -0,0 +1,107 @@ +// reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.h + +#include "trt_reduce_internal.cuh" + +#include +#include +#include +#include +#include + +using namespace trt_llm; + +using fptr_t = int64_t; + +class AllReduceMeta { +public: + AllReduceMeta(int64_t rank_id, int64_t world_size, + const std::vector& buffers, + const std::vector& barrier_in, + const std::vector& barrier_out) { + this->rank_id = (int) rank_id; + this->world_size = (int)world_size; + this->buffers = buffers; + this->barrier_in = barrier_in; + this->barrier_out = barrier_out; + } + +public: + int world_size; + int rank_id; + std::vector buffers; + std::vector barrier_in; + std::vector barrier_out; + int barrier_flag = 1; +}; + +// Get the number of bits for a given data type. +inline int get_bits(at::ScalarType dtype) { + switch (dtype) { + case at::ScalarType::Float: + return 32; + case at::ScalarType::Half: + case at::ScalarType::BFloat16: + return 16; + default: + assert(false && "Unsupported data type"); + } +} + +// Check if customized all-reduce kernels can be applied. +inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype) { + // The customized all-reduce kernel has the following requirement(s). + return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0; +} + +fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, + const std::vector& buffers, + const std::vector& barrier_in, + const std::vector& barrier_out) { + auto m = new AllReduceMeta(rank_id, world_size, buffers, barrier_in, barrier_out); + return (fptr_t)m; +} + +void dispose(fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + delete fa; +} + + +void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) { + AllReduceMeta* m = reinterpret_cast(_fa); + auto stream = c10::cuda::getCurrentCUDAStream().stream(); + auto num_elements = inp.numel(); + auto dtype = inp.scalar_type(); + AllReduceStrategyType strategy = SelectImplementation( + num_elements * ((get_bits(dtype) + 7) / 8), m->world_size); + + // should be gurantee in python code + assert(strategy == AllReduceStrategyType::ONESHOT); + assert(CanApplyCustomAllReduce(num_elements, dtype)); + + // Initialize the all-reduce kernel arguments. 
+ int world_size = m->world_size; + + AllReduceParams params; + params.ranks_per_node = world_size; + params.rank = m->rank_id; + params.local_rank = m->rank_id; + params.local_input_buffer_ptr = inp.data_ptr(); + params.local_output_buffer_ptr = out.data_ptr(); + params.elts_total = inp.numel(); + params.elts_size = inp.element_size(); + params.barrier_flag = ++(m->barrier_flag); + + for (int i = 0; i < world_size; ++i) { + params.peer_comm_buffer_ptrs[i] = reinterpret_cast(m->buffers[i]); + } + for (int i = 0; i < world_size; ++i) { + params.peer_barrier_ptrs_in[i] = reinterpret_cast(m->barrier_in[i]); + } + for (int i = 0; i < world_size; ++i) { + params.peer_barrier_ptrs_out[i] = reinterpret_cast(m->barrier_out[i]); + } + + auto data_type = out.scalar_type(); + trtCustomAllReduce(params, data_type, strategy, stream); +} diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py index 21870032e5a..7a3ceb2bd53 100644 --- a/sgl-kernel/src/sgl-kernel/ops/__init__.py +++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py @@ -1,5 +1,20 @@ +from .custom_reduce_cuda import all_reduce as _all_reduce +from .custom_reduce_cuda import dispose as _dispose +from .custom_reduce_cuda import init_custom_ar as _init_custom_ar from .warp_reduce_cuda import reduce as _reduce def warp_reduce(input_tensor): return _reduce(input_tensor) + + +def init_custom_reduce(rank_id, num_devices, buffers, barrier_in, barrier_out): + return _init_custom_ar(rank_id, num_devices, buffers, barrier_in, barrier_out) + + +def custom_dispose(fa): + _dispose(fa) + + +def custom_reduce(fa, inp, out): + _all_reduce(fa, inp, out) From 9bc795dab4907cda6db778a26bf7f2b54cf02078 Mon Sep 17 00:00:00 2001 From: Zhangyi <1109276519@qq.com> Date: Sat, 14 Dec 2024 23:34:48 +0800 Subject: [PATCH 2/8] modify sgl-kernel version --- sgl-kernel/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index ae20bcd8baa..a9111119807 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sgl-kernel" -version = "0.0.3.post4" +version = "0.0.2.post5" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.8" From 3a0b36a5965e00b75e5fe15bad57a010e75e7ba7 Mon Sep 17 00:00:00 2001 From: Zhangyi <1109276519@qq.com> Date: Sat, 14 Dec 2024 23:37:21 +0800 Subject: [PATCH 3/8] format code --- sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc | 6 +- .../sgl-kernel/csrc/trt_reduce_internal.cu | 160 +++++++----------- .../sgl-kernel/csrc/trt_reduce_internal.cuh | 53 +++--- .../src/sgl-kernel/csrc/trt_reduce_kernel.cu | 35 ++-- 4 files changed, 104 insertions(+), 150 deletions(-) diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc index 8387a2e5ad7..4d8f732af3e 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc @@ -1,10 +1,8 @@ #include using fptr_t = int64_t; -fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, - const std::vector& buffers, - const std::vector& barrier_in, - const std::vector& barrier_out); +fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector& buffers, + const std::vector& barrier_in, const std::vector& barrier_out); void dispose(fptr_t _fa); void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu 
b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu index cedd1c5268b..7c17e71a9e2 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu @@ -1,4 +1,5 @@ -// reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu +// reference: +// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu /* * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. * @@ -15,39 +16,29 @@ * limitations under the License. */ +#include +#include + #include #include #include #include #include -#include -#include #include "trt_reduce_internal.cuh" //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline __device__ void st_flag_release(uint32_t const& flag, uint32_t* flag_addr) -{ -#if __CUDA_ARCH__ >= 700 - asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); -#else - __threadfence_system(); - asm volatile("st.global.volatile.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); -#endif +static inline __device__ void st_flag_release(uint32_t const& flag, uint32_t* flag_addr) { + asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); } //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr) -{ - uint32_t flag; -#if __CUDA_ARCH__ >= 700 - asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr)); -#else - asm volatile("ld.global.volatile.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr)); -#endif - return flag; +static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr) { + uint32_t flag; + asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr)); + return flag; } namespace trt_llm { @@ -92,7 +83,7 @@ struct PackedOn16Bytes<__nv_bfloat16> { // add two 128b data template -inline __device__ int4 add128b(T &a, T &b) { +inline __device__ int4 add128b(T& a, T& b) { T c; c.unpacked[0] = a.unpacked[0] + b.unpacked[0]; c.unpacked[1] = a.unpacked[1] + b.unpacked[1]; @@ -101,38 +92,29 @@ inline __device__ int4 add128b(T &a, T &b) { return c.packed; } -__inline__ __device__ void multi_gpu_barrier(uint32_t** signals, - uint32_t const flag, - size_t const local_rank, - size_t const world_size, - int const tidx, - int const bidx) -{ - // After this function, at least one block in each GPU has reached the barrier - if (tidx < world_size) - { - // we can think of signals having the shape [world_size, world_size] - // Dimension 0 is the "listening" dimension, dimension 1 is "emitting" dimension - - // Block 0 broadcasts its flag (local_rank on emitting dimension) to all receivers - size_t offset = (flag % 2) ? 
world_size : 0; - - if (bidx == 0) - { - st_flag_release(flag, signals[tidx] + offset + local_rank); - } - - // All blocks check that corresponding block 0 on other GPUs have set the flag - // No deadlock because block #0 is always the first block started - uint32_t* peer_barrier_d = signals[local_rank] + offset + tidx; - while (ld_flag_acquire(peer_barrier_d) != flag) - { - } +__inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank, + size_t const world_size, int const tidx, int const bidx) { + // After this function, at least one block in each GPU has reached the barrier + if (tidx < world_size) { + // we can think of signals having the shape [world_size, world_size] + // Dimension 0 is the "listening" dimension, dimension 1 is "emitting" dimension + + // Block 0 broadcasts its flag (local_rank on emitting dimension) to all receivers + size_t offset = (flag % 2) ? world_size : 0; + + if (bidx == 0) { + st_flag_release(flag, signals[tidx] + offset + local_rank); } - __syncthreads(); -} + // All blocks check that corresponding block 0 on other GPUs have set the flag + // No deadlock because block #0 is always the first block started + uint32_t* peer_barrier_d = signals[local_rank] + offset + tidx; + while (ld_flag_acquire(peer_barrier_d) != flag) { + } + } + __syncthreads(); +} template /* COPY_INPUT = false, PUSH_MODE = false */ static __global__ void oneShotAllReduceKernel(AllReduceParams params) { @@ -170,7 +152,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) { using PackedStruct = typename PackedOn16Bytes::Type; // The source pointers. Distributed round-robin for the different warps. - T const *buffers[RANKS_PER_NODE]; + T const* buffers[RANKS_PER_NODE]; // Start and end offsets of the thread size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS; @@ -178,7 +160,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) { #pragma unroll for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { int rank = (params.local_rank + ii) % RANKS_PER_NODE; - buffers[ii] = reinterpret_cast(params.peer_comm_buffer_ptrs[rank]); + buffers[ii] = reinterpret_cast(params.peer_comm_buffer_ptrs[rank]); } multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx); @@ -189,7 +171,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) { PackedStruct vals[RANKS_PER_NODE]; #pragma unroll for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { - vals[ii].packed = *reinterpret_cast(&buffers[ii][iter_offset]); + vals[ii].packed = *reinterpret_cast(&buffers[ii][iter_offset]); } // Sum the values from the different ranks. @@ -203,8 +185,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) { } // Store to the destination buffer. 
- *reinterpret_cast(&reinterpret_cast( - params.local_output_buffer_ptr)[iter_offset]) = sums.packed; + *reinterpret_cast(&reinterpret_cast(params.local_output_buffer_ptr)[iter_offset]) = sums.packed; } } @@ -215,12 +196,10 @@ inline int divUp(int a, int b) { } inline int roundUp(int a, int n) { - return divUp(a, n) * n; + return divUp(a, n) * n; } -std::tuple kernelLaunchConfig(AllReduceStrategyType algo, - AllReduceParams ¶ms, - size_t elts_per_thread) { +std::tuple kernelLaunchConfig(AllReduceStrategyType algo, AllReduceParams& params, size_t elts_per_thread) { int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE; switch (algo) { case AllReduceStrategyType::ONESHOT: { @@ -242,24 +221,17 @@ std::tuple kernelLaunchConfig(AllReduceStrategyType algo, //////////////////////////////////////////////////////////////////////////////////////////////////// template -void dispatchARKernels(AllReduceStrategyType algo, - AllReduceParams ¶m, - int blocks_per_grid, - int threads_per_block, +void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block, cudaStream_t stream) { - oneShotAllReduceKernel - <<>>(param); + oneShotAllReduceKernel<<>>(param); } template -void invokeOneOrTwoShotAllReduceKernel(AllReduceParams ¶m, - AllReduceStrategyType strat, - cudaStream_t stream) { - +void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) { void* buffer = reinterpret_cast(param.peer_comm_buffer_ptrs[param.rank]); void* local_inp_buffer = param.local_input_buffer_ptr; - CHECKCUDA(cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size, - cudaMemcpyDeviceToDevice, stream)); + CHECKCUDA( + cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size, cudaMemcpyDeviceToDevice, stream)); assert(strat == AllReduceStrategyType::ONESHOT && "Custom allreduce only support oneshot"); auto last_error = cudaGetLastError(); @@ -269,24 +241,19 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams ¶m, } size_t elts_per_thread = 16 / sizeof(T); - auto [blocks_per_grid, threads_per_block] = - kernelLaunchConfig(strat, param, elts_per_thread); + auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread); switch (param.ranks_per_node) { case 2: - dispatchARKernels( - strat, param, blocks_per_grid, threads_per_block, stream); + dispatchARKernels(strat, param, blocks_per_grid, threads_per_block, stream); break; case 4: - dispatchARKernels( - strat, param, blocks_per_grid, threads_per_block, stream); + dispatchARKernels(strat, param, blocks_per_grid, threads_per_block, stream); break; case 6: - dispatchARKernels( - strat, param, blocks_per_grid, threads_per_block, stream); + dispatchARKernels(strat, param, blocks_per_grid, threads_per_block, stream); break; case 8: - dispatchARKernels( - strat, param, blocks_per_grid, threads_per_block, stream); + dispatchARKernels(strat, param, blocks_per_grid, threads_per_block, stream); break; default: break; @@ -298,29 +265,26 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams ¶m, } } -void trtCustomAllReduce(AllReduceParams ¶ms, - at::ScalarType data_type, - AllReduceStrategyType strat, +void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream) { if (params.elts_total == 0) { return; } - switch (data_type) - { - case at::ScalarType::Float: - invokeOneOrTwoShotAllReduceKernel(params, strat, stream); - break; - case 
at::ScalarType::Half: - invokeOneOrTwoShotAllReduceKernel(params, strat, stream); - break; + switch (data_type) { + case at::ScalarType::Float: + invokeOneOrTwoShotAllReduceKernel(params, strat, stream); + break; + case at::ScalarType::Half: + invokeOneOrTwoShotAllReduceKernel(params, strat, stream); + break; #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) - case at::ScalarType::BFloat16: - invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream); - break; + case at::ScalarType::BFloat16: + invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream); + break; #endif - default: - assert(false && "Unsupported data type"); + default: + assert(false && "Unsupported data type"); } } -} +} // namespace trt_llm diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh index 5c937143b0f..2d2767e6794 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh @@ -1,4 +1,5 @@ -// reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp +// reference: +// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp /* * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. * @@ -20,24 +21,23 @@ #include #include -#define FatalError(s) \ - do { \ - std::stringstream _where, _message; \ - _where << __FILE__ << ':' << __LINE__; \ - _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \ - std::cerr << _message.str() << "\nAborting...\n"; \ - assert(false); \ - exit(1); \ +#define FatalError(s) \ + do { \ + std::stringstream _where, _message; \ + _where << __FILE__ << ':' << __LINE__; \ + _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \ + std::cerr << _message.str() << "\nAborting...\n"; \ + assert(false); \ + exit(1); \ } while (0) -#define CHECKCUDA(cmd) \ - do { \ - cudaError_t e = cmd; \ - if (e != cudaSuccess) { \ - printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ - cudaGetErrorString(e)); \ - exit(EXIT_FAILURE); \ - } \ +#define CHECKCUDA(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ } while (0) namespace trt_llm { @@ -61,11 +61,11 @@ struct AllReduceParams { size_t rank_offset; size_t ranks_per_node, rank, local_rank; uint32_t barrier_flag; - uint32_t *peer_barrier_ptrs_in[MAX_RANKS_PER_NODE]; - uint32_t *peer_barrier_ptrs_out[MAX_RANKS_PER_NODE]; - void *peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE]; - void *local_input_buffer_ptr; - void *local_output_buffer_ptr; + uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE]; + uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE]; + void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE]; + void* local_input_buffer_ptr; + void* local_output_buffer_ptr; }; inline size_t GetMaxRequiredWorkspaceSize(int world_size) { @@ -75,8 +75,7 @@ inline size_t GetMaxRequiredWorkspaceSize(int world_size) { return 8 * 1000 * 1000; } -inline AllReduceStrategyType SelectImplementation(size_t message_size, - int world_size) { +inline AllReduceStrategyType SelectImplementation(size_t message_size, int world_size) { const size_t maxWorkspaceSize = GetMaxRequiredWorkspaceSize(world_size); if (message_size > maxWorkspaceSize) { @@ -103,9 +102,7 @@ inline AllReduceStrategyType SelectImplementation(size_t 
message_size, return AllReduceStrategyType::TWOSHOT; } -void trtCustomAllReduce(AllReduceParams ¶ms, - at::ScalarType data_type, - AllReduceStrategyType strat, +void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream); -} +} // namespace trt_llm diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu index 79ed196f023..2a2dcebc89e 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu @@ -1,31 +1,30 @@ // reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.h -#include "trt_reduce_internal.cuh" +#include -#include #include -#include #include -#include +#include +#include + +#include "trt_reduce_internal.cuh" using namespace trt_llm; using fptr_t = int64_t; class AllReduceMeta { -public: - AllReduceMeta(int64_t rank_id, int64_t world_size, - const std::vector& buffers, - const std::vector& barrier_in, - const std::vector& barrier_out) { - this->rank_id = (int) rank_id; + public: + AllReduceMeta(int64_t rank_id, int64_t world_size, const std::vector& buffers, + const std::vector& barrier_in, const std::vector& barrier_out) { + this->rank_id = (int)rank_id; this->world_size = (int)world_size; this->buffers = buffers; this->barrier_in = barrier_in; this->barrier_out = barrier_out; } -public: + public: int world_size; int rank_id; std::vector buffers; @@ -53,10 +52,8 @@ inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype) return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0; } -fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, - const std::vector& buffers, - const std::vector& barrier_in, - const std::vector& barrier_out) { +fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector& buffers, + const std::vector& barrier_in, const std::vector& barrier_out) { auto m = new AllReduceMeta(rank_id, world_size, buffers, barrier_in, barrier_out); return (fptr_t)m; } @@ -66,14 +63,12 @@ void dispose(fptr_t _fa) { delete fa; } - void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) { AllReduceMeta* m = reinterpret_cast(_fa); auto stream = c10::cuda::getCurrentCUDAStream().stream(); auto num_elements = inp.numel(); auto dtype = inp.scalar_type(); - AllReduceStrategyType strategy = SelectImplementation( - num_elements * ((get_bits(dtype) + 7) / 8), m->world_size); + AllReduceStrategyType strategy = SelectImplementation(num_elements * ((get_bits(dtype) + 7) / 8), m->world_size); // should be gurantee in python code assert(strategy == AllReduceStrategyType::ONESHOT); @@ -96,10 +91,10 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) { params.peer_comm_buffer_ptrs[i] = reinterpret_cast(m->buffers[i]); } for (int i = 0; i < world_size; ++i) { - params.peer_barrier_ptrs_in[i] = reinterpret_cast(m->barrier_in[i]); + params.peer_barrier_ptrs_in[i] = reinterpret_cast(m->barrier_in[i]); } for (int i = 0; i < world_size; ++i) { - params.peer_barrier_ptrs_out[i] = reinterpret_cast(m->barrier_out[i]); + params.peer_barrier_ptrs_out[i] = reinterpret_cast(m->barrier_out[i]); } auto data_type = out.scalar_type(); From ce283ff58c2b5fbe8fac7b0827e9d86b1a4e1ae4 Mon Sep 17 00:00:00 2001 From: Zhangyi <1109276519@qq.com> Date: Sun, 15 Dec 2024 00:28:38 +0800 Subject: [PATCH 4/8] add utils.hpp for basic check --- .../sgl-kernel/csrc/trt_reduce_internal.cu | 14 ++------ 
.../sgl-kernel/csrc/trt_reduce_internal.cuh | 20 +---------- sgl-kernel/src/sgl-kernel/csrc/utils.hpp | 35 +++++++++++++++++++ sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc | 9 ++--- 4 files changed, 41 insertions(+), 37 deletions(-) create mode 100644 sgl-kernel/src/sgl-kernel/csrc/utils.hpp diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu index 7c17e71a9e2..04393c8e716 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu @@ -230,15 +230,11 @@ template void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) { void* buffer = reinterpret_cast(param.peer_comm_buffer_ptrs[param.rank]); void* local_inp_buffer = param.local_input_buffer_ptr; - CHECKCUDA( + CHECK_CUDA_SUCCESS( cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size, cudaMemcpyDeviceToDevice, stream)); assert(strat == AllReduceStrategyType::ONESHOT && "Custom allreduce only support oneshot"); - auto last_error = cudaGetLastError(); - if (last_error != cudaSuccess) { - printf("cuda error: %s\n", cudaGetErrorString(last_error)); - assert(false && "Error before launching the kernel"); - } + CHECK_CUDA_SUCCESS(cudaGetLastError()); size_t elts_per_thread = 16 / sizeof(T); auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread); @@ -258,11 +254,7 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategy default: break; } - last_error = cudaGetLastError(); - if (last_error != cudaSuccess) { - printf("cuda error: %s\n", cudaGetErrorString(last_error)); - assert(false && "Error after launching the kernel"); - } + CHECK_CUDA_SUCCESS(cudaGetLastError()); } void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh index 2d2767e6794..46f196c0447 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh @@ -20,25 +20,7 @@ #include #include #include - -#define FatalError(s) \ - do { \ - std::stringstream _where, _message; \ - _where << __FILE__ << ':' << __LINE__; \ - _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \ - std::cerr << _message.str() << "\nAborting...\n"; \ - assert(false); \ - exit(1); \ - } while (0) - -#define CHECKCUDA(cmd) \ - do { \ - cudaError_t e = cmd; \ - if (e != cudaSuccess) { \ - printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ - exit(EXIT_FAILURE); \ - } \ - } while (0) +#include "utils.hpp" namespace trt_llm { constexpr size_t WARP_SIZE = 32; diff --git a/sgl-kernel/src/sgl-kernel/csrc/utils.hpp b/sgl-kernel/src/sgl-kernel/csrc/utils.hpp new file mode 100644 index 00000000000..eefbee8ae56 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/utils.hpp @@ -0,0 +1,35 @@ +#pragma once +#include +#include + +struct cuda_error : public std::runtime_error { + /** + * @brief Constructs a `cuda_error` object with the given `message`. + * + * @param message The error char array used to construct `cuda_error` + */ + cuda_error(const char* message) : std::runtime_error(message) {} + /** + * @brief Constructs a `cuda_error` object with the given `message` string. 
+ * + * @param message The `std::string` used to construct `cuda_error` + */ + cuda_error(std::string const& message) : cuda_error{message.c_str()} {} +}; + +#define CHECK_CUDA_SUCCESS(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + std::stringstream _message; \ + auto s = cudaGetErrorString(e); \ + _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \ + throw cuda_error(_message.str()); \ + } \ + } while (0) + +#define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_CUDA_INPUT(x) \ + CHECK_IS_CUDA(x); \ + CHECK_IS_CONTIGUOUS(x) diff --git a/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc b/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc index efc2f0cd951..6cc3ae152ca 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc +++ b/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc @@ -1,15 +1,10 @@ #include +#include "utils.hpp" torch::Tensor warp_reduce_cuda(torch::Tensor input); -#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) - torch::Tensor warp_reduce(torch::Tensor input) { - CHECK_INPUT(input); + CHECK_CUDA_INPUT(input); return warp_reduce_cuda(input); } From 01fe99eaf6da52d1e5c458ffc1115bfad4073669 Mon Sep 17 00:00:00 2001 From: Zhangyi <1109276519@qq.com> Date: Sun, 15 Dec 2024 11:59:58 +0800 Subject: [PATCH 5/8] add test for custom allreduce in sgl-kernel --- sgl-kernel/tests/test_trt_reduce.py | 243 ++++++++++++++++++++++++++++ 1 file changed, 243 insertions(+) create mode 100644 sgl-kernel/tests/test_trt_reduce.py diff --git a/sgl-kernel/tests/test_trt_reduce.py b/sgl-kernel/tests/test_trt_reduce.py new file mode 100644 index 00000000000..d2ddfcfb4ec --- /dev/null +++ b/sgl-kernel/tests/test_trt_reduce.py @@ -0,0 +1,243 @@ +import ctypes +import logging +import os +import random +import socket +import time +import unittest +from typing import Any, List, Optional, Union + +import ray +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from vllm import _custom_ops as vllm_ops + +from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary + +logger = logging.getLogger(__name__) + + +def get_open_port() -> int: + # try ipv4 + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + except OSError: + # try ipv6 + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def multi_process_parallel( + world_size: int, + cls: Any, + test_target: Any, +) -> None: + + # Using ray helps debugging the error when it failed + # as compared to multiprocessing. 
+ # NOTE: We need to set working_dir for distributed tests, + # otherwise we may get import errors on ray workers + ray.init(log_to_driver=True) + + distributed_init_port = get_open_port() + refs = [] + for rank in range(world_size): + refs.append(test_target.remote(cls, world_size, rank, distributed_init_port)) + ray.get(refs) + + ray.shutdown() + + +class TestCustomAllReduce(unittest.TestCase): + @classmethod + def setUpClass(cls): + random.seed(42) + cls.test_sizes = [512, 4096, 32768, 262144, 2097152] + cls.world_sizes = [2, 4, 6, 8] + + @staticmethod + def create_shared_buffer( + size_in_bytes: int, group: Optional[ProcessGroup] = None + ) -> List[int]: + """ + Creates a shared buffer and returns a list of pointers + representing the buffer on all processes in the group. + """ + lib = CudaRTLibrary() + pointer = lib.cudaMalloc(size_in_bytes) + handle = lib.cudaIpcGetMemHandle(pointer) + world_size = dist.get_world_size(group=group) + rank = dist.get_rank(group=group) + handles = [None] * world_size + dist.all_gather_object(handles, handle, group=group) + + pointers: List[int] = [] + for i, h in enumerate(handles): + if i == rank: + pointers.append(pointer.value) # type: ignore + else: + pointers.append(lib.cudaIpcOpenMemHandle(h).value) # type: ignore + + return pointers + + @staticmethod + def free_shared_buffer( + pointers: List[int], group: Optional[ProcessGroup] = None + ) -> None: + rank = dist.get_rank(group=group) + lib = CudaRTLibrary() + lib.cudaFree(ctypes.c_void_p(pointers[rank])) + + def test_correctness(self): + for world_size in self.world_sizes: + if world_size > torch.cuda.device_count(): + continue + multi_process_parallel(world_size, self, self.correctness) + + def test_performance(self): + for world_size in self.world_sizes: + if world_size > torch.cuda.device_count(): + continue + multi_process_parallel(world_size, self, self.performance) + + def init_custom_allreduce(self, rank, world_size, group): + import sgl_kernel + + buffer_max_size = 8 * 1024 * 1024 + barrier_max_size = 8 * (24 + 2) * 8 + + self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group) + self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group) + self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group) + + self.custom_ptr = sgl_kernel.ops.init_custom_reduce( + rank, + world_size, + self.buffer_ptrs, + self.barrier_in_ptrs, + self.barrier_out_ptrs, + ) + + def custom_allreduce(self, inp, out): + import sgl_kernel + + sgl_kernel.ops.custom_reduce(self.custom_ptr, inp, out) + + def free_custom_allreduce(self, group): + import sgl_kernel + + self.free_shared_buffer(self.buffer_ptrs, group) + self.free_shared_buffer(self.barrier_in_ptrs, group) + self.free_shared_buffer(self.barrier_out_ptrs, group) + sgl_kernel.ops.custom_dispose(self.custom_ptr) + + def init_vllm_allreduce(self, rank, group): + self.vllm_rank = rank + self.vllm_max_size = 8 * 1024 * 1024 + self.vllm_meta_ptrs = self.create_shared_buffer( + vllm_ops.meta_size() + self.vllm_max_size, group=group + ) + self.vllm_buffer_ptrs = self.create_shared_buffer( + self.vllm_max_size, group=group + ) + self.vllm_rank_data = torch.empty( + 8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}") + ) + self.vllm_ptr = vllm_ops.init_custom_ar( + self.vllm_meta_ptrs, self.vllm_rank_data, rank, True + ) + vllm_ops.register_buffer(self.vllm_ptr, self.vllm_buffer_ptrs) + + def vllm_allreduce(self, inp, out): + vllm_ops.all_reduce( + self.vllm_ptr, + inp, + out, + 
self.vllm_buffer_ptrs[self.vllm_rank], + self.vllm_max_size, + ) + + def free_vllm_allreduce(self, group): + vllm_ops.dispose(self.vllm_ptr) + self.free_shared_buffer(self.vllm_meta_ptrs, group) + self.free_shared_buffer(self.vllm_buffer_ptrs, group) + + @staticmethod + def init_distributed_env(world_size, rank, distributed_init_port): + del os.environ["CUDA_VISIBLE_DEVICES"] + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + ranks = [i for i in range(world_size)] + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + dist.init_process_group( + backend="nccl", + init_method=distributed_init_method, + rank=rank, + world_size=world_size, + ) + group = torch.distributed.new_group(ranks, backend="gloo") + return group + + # compare result with torch.distributed + @ray.remote(num_gpus=1, max_calls=1) + def correctness(self, world_size, rank, distributed_init_port): + group = self.init_distributed_env(world_size, rank, distributed_init_port) + + self.init_custom_allreduce(rank=rank, world_size=world_size, group=group) + + test_loop = 10 + for sz in self.test_sizes: + for dtype in [torch.float32, torch.float16, torch.bfloat16]: + for _ in range(test_loop): + inp1 = torch.randint( + 1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device() + ) + out1 = torch.empty_like(inp1) + self.custom_allreduce(inp1, out1) + + dist.all_reduce(inp1, group=group) + torch.testing.assert_close(out1, inp1) + + self.free_custom_allreduce(group) + + # compare performance with vllm + @ray.remote(num_gpus=1, max_calls=1) + def performance(self, world_size, rank, distributed_init_port): + group = self.init_distributed_env(world_size, rank, distributed_init_port) + + self.init_vllm_allreduce(rank, group) + self.init_custom_allreduce(rank=rank, world_size=world_size, group=group) + + for sz in self.test_sizes: + inp1 = torch.randint( + 1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device() + ) + out1 = torch.empty_like(inp1) + test_loop = 5000 + start = time.time() + for _ in range(test_loop): + self.custom_allreduce(inp1, out1) + elapse_custom = time.time() - start + + start = time.time() + for _ in range(test_loop): + self.vllm_allreduce(inp1, out1) + elapse_vllm = time.time() - start + + if rank == 0: + logger.warning( + f"test_size = {sz}, world_size = {world_size}, " + f"vllm time = {elapse_vllm * 1000 / test_loop:.4f}us," + f"custom time = {elapse_custom * 1000 / test_loop:.4f}us" + ) + + self.free_custom_allreduce(group) + self.free_vllm_allreduce(group) + + +if __name__ == "__main__": + unittest.main() From a4eb61771201069d1e43fa171157fb37856bbb53 Mon Sep 17 00:00:00 2001 From: Zhangyi <1109276519@qq.com> Date: Sun, 15 Dec 2024 12:31:44 +0800 Subject: [PATCH 6/8] fix different world_size test_sizes --- sgl-kernel/tests/test_trt_reduce.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/sgl-kernel/tests/test_trt_reduce.py b/sgl-kernel/tests/test_trt_reduce.py index d2ddfcfb4ec..09a11908018 100644 --- a/sgl-kernel/tests/test_trt_reduce.py +++ b/sgl-kernel/tests/test_trt_reduce.py @@ -56,7 +56,12 @@ class TestCustomAllReduce(unittest.TestCase): @classmethod def setUpClass(cls): random.seed(42) - cls.test_sizes = [512, 4096, 32768, 262144, 2097152] + cls.test_sizes = { + 2: [512, 4096, 32768, 262144, 2097152], + 4: [512, 4096, 32768, 131072], + 6: [512, 4096, 32768, 65536], + 8: [512, 4096, 32768, 65536], + } cls.world_sizes = [2, 4, 6, 8] @staticmethod @@ -190,7 +195,7 @@ def correctness(self, world_size, rank, 
distributed_init_port): self.init_custom_allreduce(rank=rank, world_size=world_size, group=group) test_loop = 10 - for sz in self.test_sizes: + for sz in self.test_sizes[world_size]: for dtype in [torch.float32, torch.float16, torch.bfloat16]: for _ in range(test_loop): inp1 = torch.randint( @@ -212,7 +217,7 @@ def performance(self, world_size, rank, distributed_init_port): self.init_vllm_allreduce(rank, group) self.init_custom_allreduce(rank=rank, world_size=world_size, group=group) - for sz in self.test_sizes: + for sz in self.test_sizes[world_size]: inp1 = torch.randint( 1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device() ) From f3afd7b610fb52091f15d80eab3030ad937d6728 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Sat, 14 Dec 2024 20:33:27 -0800 Subject: [PATCH 7/8] upd --- sgl-kernel/Makefile | 2 +- .../sgl-kernel/csrc/trt_reduce_internal.cuh | 1 + sgl-kernel/src/sgl-kernel/csrc/utils.hpp | 25 ++++++++++--------- sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc | 1 + 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/sgl-kernel/Makefile b/sgl-kernel/Makefile index ec86d684877..452053cb214 100644 --- a/sgl-kernel/Makefile +++ b/sgl-kernel/Makefile @@ -19,4 +19,4 @@ test: @pytest tests/ format: - @find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black + @find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh index 46f196c0447..01652a22a86 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh @@ -20,6 +20,7 @@ #include #include #include + #include "utils.hpp" namespace trt_llm { diff --git a/sgl-kernel/src/sgl-kernel/csrc/utils.hpp b/sgl-kernel/src/sgl-kernel/csrc/utils.hpp index eefbee8ae56..bbdc6311be9 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/utils.hpp +++ b/sgl-kernel/src/sgl-kernel/csrc/utils.hpp @@ -1,7 +1,8 @@ #pragma once -#include #include +#include + struct cuda_error : public std::runtime_error { /** * @brief Constructs a `cuda_error` object with the given `message`. 
@@ -17,19 +18,19 @@ struct cuda_error : public std::runtime_error { cuda_error(std::string const& message) : cuda_error{message.c_str()} {} }; -#define CHECK_CUDA_SUCCESS(cmd) \ - do { \ - cudaError_t e = cmd; \ - if (e != cudaSuccess) { \ - std::stringstream _message; \ - auto s = cudaGetErrorString(e); \ - _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \ - throw cuda_error(_message.str()); \ - } \ +#define CHECK_CUDA_SUCCESS(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + std::stringstream _message; \ + auto s = cudaGetErrorString(e); \ + _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \ + throw cuda_error(_message.str()); \ + } \ } while (0) #define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") #define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_CUDA_INPUT(x) \ - CHECK_IS_CUDA(x); \ +#define CHECK_CUDA_INPUT(x) \ + CHECK_IS_CUDA(x); \ CHECK_IS_CONTIGUOUS(x) diff --git a/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc b/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc index 6cc3ae152ca..379b4cc15bf 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc +++ b/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc @@ -1,4 +1,5 @@ #include + #include "utils.hpp" torch::Tensor warp_reduce_cuda(torch::Tensor input); From bf3ed13de66eab3c008e007f0c0a4653679c99b4 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Sat, 14 Dec 2024 21:13:46 -0800 Subject: [PATCH 8/8] upd --- sgl-kernel/CMakeLists.txt | 66 ++++++++++++++++++++++++++++----------- sgl-kernel/Makefile | 2 +- 2 files changed, 48 insertions(+), 20 deletions(-) diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index c635b75c348..adb81fa2b0f 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -1,47 +1,75 @@ cmake_minimum_required(VERSION 3.18) project(sgl-kernel LANGUAGES CXX CUDA) +# Basic settings set(CMAKE_EXPORT_COMPILE_COMMANDS ON) - set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) - set(CMAKE_CUDA_STANDARD 17) set(CMAKE_CUDA_STANDARD_REQUIRED ON) -find_package(PythonInterp 3 REQUIRED) -find_package(PythonLibs 3 REQUIRED) +# Set CUDA architectures +set(CMAKE_CUDA_ARCHITECTURES "75;80;86;89;90") +message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") + +find_package(Python3 COMPONENTS Interpreter Development REQUIRED) +# Find PyTorch execute_process( - COMMAND ${PYTHON_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)" + COMMAND ${Python3_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)" OUTPUT_VARIABLE TORCH_CMAKE_PATH OUTPUT_STRIP_TRAILING_WHITESPACE ) - -message(STATUS "PYTHON_EXECUTABLE: ${PYTHON_EXECUTABLE}") -message(STATUS "TORCH_CMAKE_PATH: ${TORCH_CMAKE_PATH}") - list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PATH}") find_package(Torch REQUIRED) -include_directories(${PYTHON_INCLUDE_DIRS}) - +# Warp Reduce library add_library(warp_reduce SHARED src/sgl-kernel/csrc/warp_reduce.cc src/sgl-kernel/csrc/warp_reduce_kernel.cu ) -target_include_directories(warp_reduce PRIVATE - ${CUDA_INCLUDE_DIRS} - ${TORCH_INCLUDE_DIRS} +target_include_directories(warp_reduce + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc + ${CUDA_INCLUDE_DIRS} + ${TORCH_INCLUDE_DIRS} ) -target_link_libraries(warp_reduce PRIVATE - ${TORCH_LIBRARIES} - ${PYTHON_LIBRARIES} +target_link_libraries(warp_reduce + PRIVATE + ${TORCH_LIBRARIES} + Python3::Python ) -set_target_properties(warp_reduce PROPERTIES - 
    CUDA_SEPARABLE_COMPILATION ON
 )
+
+# TRT Reduce library
+add_library(trt_reduce SHARED
+    src/sgl-kernel/csrc/trt_reduce.cc
+    src/sgl-kernel/csrc/trt_reduce_internal.cu
+    src/sgl-kernel/csrc/trt_reduce_kernel.cu
 )
+
+target_include_directories(trt_reduce
+  PRIVATE
+    ${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc
+    ${CUDA_INCLUDE_DIRS}
+    ${TORCH_INCLUDE_DIRS}
 )
+
+target_link_libraries(trt_reduce
+  PRIVATE
+    ${TORCH_LIBRARIES}
+    Python3::Python
 )
+
+# Set common properties for both libraries
+foreach(target warp_reduce trt_reduce)
+  set_target_properties(${target} PROPERTIES
+    CUDA_SEPARABLE_COMPILATION ON
+    POSITION_INDEPENDENT_CODE ON
+    CUDA_RESOLVE_DEVICE_SYMBOLS ON
+    PREFIX ""
+    SUFFIX ".so"
+  )
+endforeach()
diff --git a/sgl-kernel/Makefile b/sgl-kernel/Makefile
index 452053cb214..7a041b1ed40 100644
--- a/sgl-kernel/Makefile
+++ b/sgl-kernel/Makefile
@@ -10,7 +10,7 @@ install:
 	@pip install -e .
 
 build:
-	@python3 setup.py bdist_wheel
+	@export MAX_JOBS=$(nproc) && python3 setup.py bdist_wheel
 
 clean:
 	@rm -rf build dist *.egg-info
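
For reference, here is a minimal per-rank sketch of how the new Python ops fit together, following the wiring in sgl-kernel/tests/test_trt_reduce.py above. It is not part of the patches: the helper names create_shared_buffer and run_once are illustrative, and it assumes torch.distributed is already initialized with one GPU per rank (the test uses a gloo group for the IPC handle exchange).

import ctypes

import torch
import torch.distributed as dist

import sgl_kernel
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary


def create_shared_buffer(size_in_bytes, group=None):
    # Same IPC exchange as TestCustomAllReduce.create_shared_buffer in the test:
    # each rank cudaMallocs a buffer, all-gathers its IPC handle, and opens the
    # peers' handles so every rank holds device pointers for all ranks.
    lib = CudaRTLibrary()
    pointer = lib.cudaMalloc(size_in_bytes)
    handle = lib.cudaIpcGetMemHandle(pointer)
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    handles = [None] * world_size
    dist.all_gather_object(handles, handle, group=group)
    return [
        pointer.value if i == rank else lib.cudaIpcOpenMemHandle(h).value
        for i, h in enumerate(handles)
    ]


def run_once(rank, world_size, group=None):
    # Sizes mirror the test: an 8 MB data buffer and 8 * (24 + 2) * 8 bytes of
    # barrier flags (MAX_ALL_REDUCE_BLOCKS = 24, MAX_RANKS_PER_NODE = 8 in
    # trt_reduce_internal.cuh).
    buffer_ptrs = create_shared_buffer(8 * 1024 * 1024, group)
    barrier_in_ptrs = create_shared_buffer(8 * (24 + 2) * 8, group)
    barrier_out_ptrs = create_shared_buffer(8 * (24 + 2) * 8, group)

    fa = sgl_kernel.ops.init_custom_reduce(
        rank, world_size, buffer_ptrs, barrier_in_ptrs, barrier_out_ptrs
    )

    # The message must fit in the shared buffer and its element count must be a
    # multiple of 16 bytes worth of elements (see CanApplyCustomAllReduce).
    inp = torch.ones(4096, dtype=torch.float16, device=f"cuda:{rank}")
    out = torch.empty_like(inp)
    sgl_kernel.ops.custom_reduce(fa, inp, out)  # out == inp summed over all ranks

    sgl_kernel.ops.custom_dispose(fa)
    lib = CudaRTLibrary()
    for ptrs in (buffer_ptrs, barrier_in_ptrs, barrier_out_ptrs):
        lib.cudaFree(ctypes.c_void_p(ptrs[dist.get_rank(group=group)]))

As in the test, only the one-shot strategy is exercised, so callers are expected to keep messages small enough that SelectImplementation picks ONESHOT; larger messages should fall back to the regular NCCL all-reduce path.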