From 69df322fd6cb5c9d4ba9ed9aa8174e44c33cf6bc Mon Sep 17 00:00:00 2001
From: Zhangyi <1109276519@qq.com>
Date: Sat, 14 Dec 2024 22:15:20 +0800
Subject: [PATCH] sgl-kernel support tensorrt llm custom allreduce

---
 sgl-kernel/pyproject.toml                     |   2 +-
 sgl-kernel/setup.py                           |  26 +-
 sgl-kernel/src/sgl-kernel/__init__.py         |   9 +-
 sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc  |  15 +
 .../sgl-kernel/csrc/trt_reduce_internal.cu    | 326 ++++++++++++++++++
 .../sgl-kernel/csrc/trt_reduce_internal.cuh   | 111 ++++++
 .../src/sgl-kernel/csrc/trt_reduce_kernel.cu  | 107 ++++++
 sgl-kernel/src/sgl-kernel/ops/__init__.py     |  15 +
 8 files changed, 607 insertions(+), 4 deletions(-)
 create mode 100644 sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc
 create mode 100644 sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu
 create mode 100644 sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
 create mode 100644 sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu

diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml
index 4fbdd9dae23..ae20bcd8baa 100644
--- a/sgl-kernel/pyproject.toml
+++ b/sgl-kernel/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sgl-kernel"
-version = "0.0.2.post4"
+version = "0.0.3.post4"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.8"
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index 724c634bc25..d81e9da00e8 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -84,7 +84,31 @@ def update_wheel_platform_tag():
             },
             libraries=["c10", "torch", "torch_python"],
             extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
-        )
+        ),
+        CUDAExtension(
+            "sgl_kernel.ops.custom_reduce_cuda",
+            [
+                "src/sgl-kernel/csrc/trt_reduce_internal.cu",
+                "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
+                "src/sgl-kernel/csrc/trt_reduce.cc",
+            ],
+            extra_compile_args={
+                "nvcc": [
+                    "-O3",
+                    "-Xcompiler",
+                    "-fPIC",
+                    "-gencode=arch=compute_75,code=sm_75",
+                    "-gencode=arch=compute_80,code=sm_80",
+                    "-gencode=arch=compute_89,code=sm_89",
+                    "-gencode=arch=compute_90,code=sm_90",
+                    "-U__CUDA_NO_HALF_OPERATORS__",
+                    "-U__CUDA_NO_HALF2_OPERATORS__",
+                ],
+                "cxx": ["-O3"],
+            },
+            libraries=["c10", "torch", "torch_python"],
+            extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
+        ),
     ],
     cmdclass={"build_ext": BuildExtension},
     install_requires=["torch"],
diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py
index edf3921db79..9876d4e5ee2 100644
--- a/sgl-kernel/src/sgl-kernel/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/__init__.py
@@ -1,3 +1,8 @@
-from .ops import warp_reduce
+from .ops import custom_dispose, custom_reduce, init_custom_reduce, warp_reduce
 
-__all__ = ["warp_reduce"]
+__all__ = [
+    "warp_reduce",
+    "init_custom_reduce",
+    "custom_dispose",
+    "custom_reduce",
+]
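For orientation, the packaging changes above only add a second CUDAExtension and re-export three new wrappers. A minimal import check of the rebuilt package (illustrative only, not part of the patch):

# Illustrative smoke test: the rebuilt wheel should expose the new entry points.
import sgl_kernel

for name in ("warp_reduce", "init_custom_reduce", "custom_reduce", "custom_dispose"):
    assert hasattr(sgl_kernel, name), name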
diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc
new file mode 100644
index 00000000000..8387a2e5ad7
--- /dev/null
+++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc
@@ -0,0 +1,15 @@
+#include <torch/extension.h>
+
+using fptr_t = int64_t;
+fptr_t init_custom_ar(int64_t rank_id, int64_t world_size,
+                      const std::vector<fptr_t>& buffers,
+                      const std::vector<fptr_t>& barrier_in,
+                      const std::vector<fptr_t>& barrier_out);
+void dispose(fptr_t _fa);
+void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
+  m.def("dispose", &dispose, "dispose custom allreduce meta");
+  m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
+}
diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu
new file mode 100644
index 00000000000..cedd1c5268b
--- /dev/null
+++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu
@@ -0,0 +1,326 @@
+// reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu
+/*
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <torch/extension.h>
+
+#include <cassert>
+#include <cstdint>
+#include <tuple>
+
+#include "trt_reduce_internal.cuh"
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+static inline __device__ void st_flag_release(uint32_t const& flag, uint32_t* flag_addr)
+{
+#if __CUDA_ARCH__ >= 700
+  asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
+#else
+  __threadfence_system();
+  asm volatile("st.global.volatile.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr)
+{
+  uint32_t flag;
+#if __CUDA_ARCH__ >= 700
+  asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
+#else
+  asm volatile("ld.global.volatile.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
+#endif
+  return flag;
+}
+
+namespace trt_llm {
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Type Converter that packs data format to 128 bits data type
+//
+using PackedFloat = union {
+  int4 packed;
+  float unpacked[4];
+};
+
+using PackedHalf = union {
+  int4 packed;
+  half2 unpacked[4];
+};
+
+template <typename T>
+struct PackedOn16Bytes {};
+
+template <>
+struct PackedOn16Bytes<float> {
+  using Type = PackedFloat;
+};
+
+template <>
+struct PackedOn16Bytes<half> {
+  using Type = PackedHalf;
+};
+
+#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
+using PackedBFloat16 = union {
+  int4 packed;
+  __nv_bfloat162 unpacked[4];
+};
+
+template <>
+struct PackedOn16Bytes<__nv_bfloat16> {
+  using Type = PackedBFloat16;
+};
+#endif
+
+// add two 128b data
+template <typename T>
+inline __device__ int4 add128b(T& a, T& b) {
+  T c;
+  c.unpacked[0] = a.unpacked[0] + b.unpacked[0];
+  c.unpacked[1] = a.unpacked[1] + b.unpacked[1];
+  c.unpacked[2] = a.unpacked[2] + b.unpacked[2];
+  c.unpacked[3] = a.unpacked[3] + b.unpacked[3];
+  return c.packed;
+}
+
+__inline__ __device__ void multi_gpu_barrier(uint32_t** signals,
+                                             uint32_t const flag,
+                                             size_t const local_rank,
+                                             size_t const world_size,
+                                             int const tidx,
+                                             int const bidx)
+{
+  // After this function, at least one block in each GPU has reached the barrier
+  if (tidx < world_size)
+  {
+    // we can think of signals having the shape [world_size, world_size]
+    // Dimension 0 is the "listening" dimension, dimension 1 is the "emitting" dimension
+
+    // Block 0 broadcasts its flag (local_rank on the emitting dimension) to all receivers
+    size_t offset = (flag % 2) ? world_size : 0;
+
+    if (bidx == 0)
+    {
+      st_flag_release(flag, signals[tidx] + offset + local_rank);
+    }
+
+    // All blocks check that the corresponding block 0 on other GPUs has set the flag
+    // No deadlock because block #0 is always the first block started
+    uint32_t* peer_barrier_d = signals[local_rank] + offset + tidx;
+    while (ld_flag_acquire(peer_barrier_d) != flag)
+    {
+    }
+  }
+
+  __syncthreads();
+}
+
+
+template <typename T, int RANKS_PER_NODE> /* COPY_INPUT = false, PUSH_MODE = false */
+static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
+  // Suppose that two GPUs participate in the AR exchange, and we start four blocks.
+  // The message is partitioned into chunks as detailed below:
+  //               message
+  //       |-------------------|
+  // GPU 0 | B0 | B1 | B2 | B3 |
+  // GPU 1 | B0 | B1 | B2 | B3 |
+  //
+  // Here is the step-by-step behavior of one block:
+  // 1. B0 copies the chunk it is responsible for, from local_input to the shareable buffer
+  // 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier)
+  // 3. B0 on GPU 0 pulls and sums the chunk from GPU 1, and writes the result to local_output
+  //
+  // With COPY_INPUT == false, skip step 1 and use a gpu_barrier instead of a block barrier during step 2.
+  // We only need to know that the other GPU has arrived at the AR kernel; that means its data is ready.
+  //
+  // With PUSH_MODE, we consider that the shared buffer is of size:
+  //   params.peer_comm_buffer_ptrs: [world_size, world_size, message_size]
+  //
+  // Here is the step-by-step behavior of one block:
+  // 1. B0 pushes the chunk it is responsible for to all other GPUs:
+  //    params.peer_comm_buffer_ptrs[:, local_gpu, B0 slice]
+  // 2. block sync so that the pushed chunk is visible to the other GPUs
+  // 3. Reduce along the second dimension: params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]
+
+  int const bidx = blockIdx.x;
+  int const tidx = threadIdx.x;
+
+  // The number of elements packed into one 16-byte access for comms
+  static constexpr int NUM_ELTS = 16 / sizeof(T);
+
+  // Packed data type for comms
+  using PackedStruct = typename PackedOn16Bytes<T>::Type;
+
+  // The source pointers. Distributed round-robin for the different warps.
+  T const* buffers[RANKS_PER_NODE];
+
+  // Start and end offsets of the thread
+  size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS;
+  size_t chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
+#pragma unroll
+  for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
+    int rank = (params.local_rank + ii) % RANKS_PER_NODE;
+    buffers[ii] = reinterpret_cast<T const*>(params.peer_comm_buffer_ptrs[rank]);
+  }
+
+  multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
+
+  // Each block accumulates the values from the different GPUs on the same node.
+  for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
+    // Iterate over the different ranks/devices on the node to load the values.
+    PackedStruct vals[RANKS_PER_NODE];
+#pragma unroll
+    for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
+      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][iter_offset]);
+    }
+
+    // Sum the values from the different ranks.
+    PackedStruct sums;
+    sums.packed = {0, 0, 0, 0};
+#pragma unroll
+    for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
+      // Always start the reduction from rank 0 to keep the reduce order stable across GPUs.
+      int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE;
+      sums.packed = add128b(sums, vals[ii]);
+    }
+
+    // Store to the destination buffer.
+    *reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline int divUp(int a, int b) {
+  return (a + b - 1) / b;
+}
+
+inline int roundUp(int a, int n) {
+  return divUp(a, n) * n;
+}
+
+std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo,
+                                        AllReduceParams& params,
+                                        size_t elts_per_thread) {
+  int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE;
+  switch (algo) {
+    case AllReduceStrategyType::ONESHOT: {
+      assert(params.elts_total % elts_per_thread == 0);
+      size_t const total_threads = roundUp(params.elts_total / elts_per_thread, WARP_SIZE);
+      threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
+      blocks_per_grid = std::min(static_cast<int>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
+      params.elts_per_block = roundUp(divUp(params.elts_total, blocks_per_grid), elts_per_thread);
+      params.elts_per_rank = params.elts_total;
+      break;
+    }
+    default:
+      assert(false && "Algorithm not supported here.");
+  }
+
+  return std::make_tuple(blocks_per_grid, threads_per_block);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename T, int RANKS_PER_NODE>
+void dispatchARKernels(AllReduceStrategyType algo,
+                       AllReduceParams& param,
+                       int blocks_per_grid,
+                       int threads_per_block,
+                       cudaStream_t stream) {
+  oneShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
+}
+
+template <typename T>
+void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param,
+                                       AllReduceStrategyType strat,
+                                       cudaStream_t stream) {
+
+  void* buffer = reinterpret_cast<void*>(param.peer_comm_buffer_ptrs[param.rank]);
+  void* local_inp_buffer = param.local_input_buffer_ptr;
+  CHECKCUDA(cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size,
+                            cudaMemcpyDeviceToDevice, stream));
+
+  assert(strat == AllReduceStrategyType::ONESHOT && "Custom allreduce only supports oneshot");
+  auto last_error = cudaGetLastError();
+  if (last_error != cudaSuccess) {
+    printf("cuda error: %s\n", cudaGetErrorString(last_error));
+    assert(false && "Error before launching the kernel");
+  }
+
+  size_t elts_per_thread = 16 / sizeof(T);
+  auto [blocks_per_grid, threads_per_block] =
+      kernelLaunchConfig(strat, param, elts_per_thread);
+  switch (param.ranks_per_node) {
+    case 2:
+      dispatchARKernels<T, 2>(strat, param, blocks_per_grid, threads_per_block, stream);
+      break;
+    case 4:
+      dispatchARKernels<T, 4>(strat, param, blocks_per_grid, threads_per_block, stream);
+      break;
+    case 6:
+      dispatchARKernels<T, 6>(strat, param, blocks_per_grid, threads_per_block, stream);
+      break;
+    case 8:
+      dispatchARKernels<T, 8>(strat, param, blocks_per_grid, threads_per_block, stream);
+      break;
+    default:
+      break;
+  }
+  last_error = cudaGetLastError();
+  if (last_error != cudaSuccess) {
+    printf("cuda error: %s\n", cudaGetErrorString(last_error));
+    assert(false && "Error after launching the kernel");
+  }
+}
+
+void trtCustomAllReduce(AllReduceParams& params,
+                        at::ScalarType data_type,
+                        AllReduceStrategyType strat,
+                        cudaStream_t stream) {
+  if (params.elts_total == 0) {
+    return;
+  }
+
+  switch (data_type) {
+    case at::ScalarType::Float:
+      invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);
+      break;
+    case at::ScalarType::Half:
+      invokeOneOrTwoShotAllReduceKernel<half>(params, strat, stream);
+      break;
+#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
+    case at::ScalarType::BFloat16:
+      invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream);
+      break;
+#endif
+    default:
+      assert(false && "Unsupported data type");
+  }
+}
+}  // namespace trt_llm
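As a sanity check on the launch math, here is a small Python mirror of the ONESHOT branch of kernelLaunchConfig (illustrative only; the constants match trt_reduce_internal.cuh below):

# Python mirror of kernelLaunchConfig's ONESHOT branch (illustrative only).
WARP_SIZE, MAX_ALL_REDUCE_BLOCKS, DEFAULT_BLOCK_SIZE = 32, 24, 1024

def round_up(a, n):
    return -(-a // n) * n  # ceil-divide, then scale back up

def oneshot_launch_config(elts_total, elt_bytes):
    elts_per_thread = 16 // elt_bytes
    assert elts_total % elts_per_thread == 0
    total_threads = round_up(elts_total // elts_per_thread, WARP_SIZE)
    threads_per_block = min(DEFAULT_BLOCK_SIZE, total_threads)
    blocks_per_grid = min(MAX_ALL_REDUCE_BLOCKS, -(-total_threads // threads_per_block))
    elts_per_block = round_up(-(-elts_total // blocks_per_grid), elts_per_thread)
    return blocks_per_grid, threads_per_block, elts_per_block

print(oneshot_launch_config(8192, 2))     # (1, 1024, 8192): one block covers the message
print(oneshot_launch_config(1 << 20, 2))  # (24, 1024, 43696): grid capped at 24 blocks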
diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
new file mode 100644
index 00000000000..5c937143b0f
--- /dev/null
+++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
@@ -0,0 +1,111 @@
+// reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp
+/*
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include <cuda_runtime.h>
+#include <torch/extension.h>
+#include <sstream>
+
+#define FatalError(s)                                                    \
+  do {                                                                   \
+    std::stringstream _where, _message;                                  \
+    _where << __FILE__ << ':' << __LINE__;                               \
+    _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__;    \
+    std::cerr << _message.str() << "\nAborting...\n";                    \
+    assert(false);                                                       \
+    exit(1);                                                             \
+  } while (0)
+
+#define CHECKCUDA(cmd)                                             \
+  do {                                                             \
+    cudaError_t e = cmd;                                           \
+    if (e != cudaSuccess) {                                        \
+      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__,          \
+             __LINE__, cudaGetErrorString(e));                     \
+      exit(EXIT_FAILURE);                                          \
+    }                                                              \
+  } while (0)
+
+namespace trt_llm {
+constexpr size_t WARP_SIZE = 32;
+constexpr size_t MAX_ALL_REDUCE_BLOCKS = 24;
+constexpr size_t MAX_RANKS_PER_NODE = 8;
+constexpr size_t DEFAULT_BLOCK_SIZE = 1024;
+
+enum class AllReduceStrategyType : int8_t {
+  RING = 0,
+  ONESHOT = 1,
+  TWOSHOT = 2,
+  AUTO = 3,
+};
+
+struct AllReduceParams {
+  size_t elts_size;
+  size_t elts_total;
+  size_t elts_per_rank;
+  size_t elts_per_block;
+  size_t rank_offset;
+  size_t ranks_per_node, rank, local_rank;
+  uint32_t barrier_flag;
+  uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE];
+  uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
+  void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE];
+  void* local_input_buffer_ptr;
+  void* local_output_buffer_ptr;
+};
+
+inline size_t GetMaxRequiredWorkspaceSize(int world_size) {
+  if (world_size <= 2) {
+    return 16 * 1000 * 1000;
+  }
+  return 8 * 1000 * 1000;
+}
+
+inline AllReduceStrategyType SelectImplementation(size_t message_size, int world_size) {
+  const size_t maxWorkspaceSize = GetMaxRequiredWorkspaceSize(world_size);
+
+  if (message_size > maxWorkspaceSize) {
+    assert(false && "Custom allreduce does not support the RING strategy currently");
+    return AllReduceStrategyType::RING;
+  }
+
+  if (world_size <= 2) {
+    return AllReduceStrategyType::ONESHOT;
+  }
+
+  if (world_size <= 4) {
+    if (message_size < 1 * 1000 * 1000) {
+      return AllReduceStrategyType::ONESHOT;
+    }
+    assert(false && "Custom allreduce does not support the TWOSHOT strategy currently");
+    return AllReduceStrategyType::TWOSHOT;
+  }
+
+  if (message_size < 500 * 1000) {
+    return AllReduceStrategyType::ONESHOT;
+  }
+  assert(false && "Custom allreduce does not support the TWOSHOT strategy currently");
+  return AllReduceStrategyType::TWOSHOT;
+}
+
+void trtCustomAllReduce(AllReduceParams& params,
+                        at::ScalarType data_type,
+                        AllReduceStrategyType strat,
+                        cudaStream_t stream);
+
+}  // namespace trt_llm
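The header's selection policy is easy to tabulate. A Python mirror of SelectImplementation (illustrative only; sizes in bytes), showing which cases reach the supported ONESHOT path while RING/TWOSHOT are still fenced off by asserts:

# Python mirror of SelectImplementation (illustrative only; message_size in bytes).
def select_strategy(message_size, world_size):
    max_workspace = 16_000_000 if world_size <= 2 else 8_000_000
    if message_size > max_workspace:
        return "RING (currently asserts)"
    if world_size <= 2:
        return "ONESHOT"
    if world_size <= 4:
        return "ONESHOT" if message_size < 1_000_000 else "TWOSHOT (currently asserts)"
    return "ONESHOT" if message_size < 500_000 else "TWOSHOT (currently asserts)"

for size, ws in [(256 * 1024, 2), (256 * 1024, 8), (2_000_000, 4), (32_000_000, 8)]:
    print(f"{size:>10} bytes, tp={ws}: {select_strategy(size, ws)}")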
diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu
new file mode 100644
index 00000000000..79ed196f023
--- /dev/null
+++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu
@@ -0,0 +1,107 @@
+// reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.h
+
+#include "trt_reduce_internal.cuh"
+
+#include <c10/cuda/CUDAStream.h>
+#include <torch/extension.h>
+#include <cassert>
+#include <cstdint>
+#include <vector>
+
+using namespace trt_llm;
+
+using fptr_t = int64_t;
+
+class AllReduceMeta {
+ public:
+  AllReduceMeta(int64_t rank_id, int64_t world_size,
+                const std::vector<fptr_t>& buffers,
+                const std::vector<fptr_t>& barrier_in,
+                const std::vector<fptr_t>& barrier_out) {
+    this->rank_id = (int)rank_id;
+    this->world_size = (int)world_size;
+    this->buffers = buffers;
+    this->barrier_in = barrier_in;
+    this->barrier_out = barrier_out;
+  }
+
+ public:
+  int world_size;
+  int rank_id;
+  std::vector<fptr_t> buffers;
+  std::vector<fptr_t> barrier_in;
+  std::vector<fptr_t> barrier_out;
+  int barrier_flag = 1;
+};
+
+// Get the number of bits for a given data type.
+inline int get_bits(at::ScalarType dtype) {
+  switch (dtype) {
+    case at::ScalarType::Float:
+      return 32;
+    case at::ScalarType::Half:
+    case at::ScalarType::BFloat16:
+      return 16;
+    default:
+      assert(false && "Unsupported data type");
+  }
+}
+
+// Check if the customized all-reduce kernel can be applied.
+inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype) {
+  // The customized all-reduce kernel has the following requirement(s):
+  // the message must split into whole 16-byte packets.
+  return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0;
+}
+
+fptr_t init_custom_ar(int64_t rank_id, int64_t world_size,
+                      const std::vector<fptr_t>& buffers,
+                      const std::vector<fptr_t>& barrier_in,
+                      const std::vector<fptr_t>& barrier_out) {
+  auto m = new AllReduceMeta(rank_id, world_size, buffers, barrier_in, barrier_out);
+  return (fptr_t)m;
+}
+
+void dispose(fptr_t _fa) {
+  auto fa = reinterpret_cast<AllReduceMeta*>(_fa);
+  delete fa;
+}
+
+void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
+  AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
+  auto stream = c10::cuda::getCurrentCUDAStream().stream();
+  auto num_elements = inp.numel();
+  auto dtype = inp.scalar_type();
+  AllReduceStrategyType strategy = SelectImplementation(
+      num_elements * ((get_bits(dtype) + 7) / 8), m->world_size);
+
+  // These should be guaranteed by the Python caller.
+  assert(strategy == AllReduceStrategyType::ONESHOT);
+  assert(CanApplyCustomAllReduce(num_elements, dtype));
+
+  // Initialize the all-reduce kernel arguments.
+  int world_size = m->world_size;
+
+  AllReduceParams params;
+  params.ranks_per_node = world_size;
+  params.rank = m->rank_id;
+  params.local_rank = m->rank_id;
+  params.local_input_buffer_ptr = inp.data_ptr();
+  params.local_output_buffer_ptr = out.data_ptr();
+  params.elts_total = inp.numel();
+  params.elts_size = inp.element_size();
+  params.barrier_flag = ++(m->barrier_flag);
+
+  for (int i = 0; i < world_size; ++i) {
+    params.peer_comm_buffer_ptrs[i] = reinterpret_cast<void*>(m->buffers[i]);
+  }
+  for (int i = 0; i < world_size; ++i) {
+    params.peer_barrier_ptrs_in[i] = reinterpret_cast<uint32_t*>(m->barrier_in[i]);
+  }
+  for (int i = 0; i < world_size; ++i) {
+    params.peer_barrier_ptrs_out[i] = reinterpret_cast<uint32_t*>(m->barrier_out[i]);
+  }
+
+  auto data_type = out.scalar_type();
+  trtCustomAllReduce(params, data_type, strategy, stream);
+}
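CanApplyCustomAllReduce boils down to a 16-byte packing constraint; the Python layer is expected to enforce it (together with the ONESHOT size limits) before routing a tensor to this path. An equivalent check, for illustration:

# Illustrative Python equivalent of CanApplyCustomAllReduce: the message must
# split into whole 128-bit packets for the packed loads/stores in the kernel.
DTYPE_BITS = {"float32": 32, "float16": 16, "bfloat16": 16}

def can_apply_custom_all_reduce(num_elements, dtype):
    elts_per_packet = 16 // ((DTYPE_BITS[dtype] + 7) // 8)  # 4 for fp32, 8 for fp16/bf16
    return num_elements % elts_per_packet == 0

print(can_apply_custom_all_reduce(4096, "float16"))  # True
print(can_apply_custom_all_reduce(4100, "float16"))  # False: 4100 % 8 != 0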
diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py
index 21870032e5a..7a3ceb2bd53 100644
--- a/sgl-kernel/src/sgl-kernel/ops/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py
@@ -1,5 +1,20 @@
+from .custom_reduce_cuda import all_reduce as _all_reduce
+from .custom_reduce_cuda import dispose as _dispose
+from .custom_reduce_cuda import init_custom_ar as _init_custom_ar
 from .warp_reduce_cuda import reduce as _reduce
 
 
 def warp_reduce(input_tensor):
     return _reduce(input_tensor)
+
+
+def init_custom_reduce(rank_id, num_devices, buffers, barrier_in, barrier_out):
+    return _init_custom_ar(rank_id, num_devices, buffers, barrier_in, barrier_out)
+
+
+def custom_dispose(fa):
+    _dispose(fa)
+
+
+def custom_reduce(fa, inp, out):
+    _all_reduce(fa, inp, out)
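Finally, a sketch of how the wrappers are meant to be driven, one process per GPU. Allocating the shared communication and barrier buffers and exchanging their device pointers across ranks (e.g. via CUDA IPC) is outside this patch, so the pointer lists below are assumed inputs, and run_custom_allreduce is a hypothetical helper, not part of the API:

# Hypothetical per-rank driver (illustrative). buffer_ptrs / barrier_in_ptrs /
# barrier_out_ptrs are assumed to be lists of integer device pointers, one per
# rank, produced by a buffer-exchange step that this patch does not include.
import torch
from sgl_kernel import custom_dispose, custom_reduce, init_custom_reduce


def run_custom_allreduce(rank, world_size, buffer_ptrs, barrier_in_ptrs, barrier_out_ptrs, inp):
    # inp: contiguous CUDA tensor; its numel must satisfy the 16-byte packing rule
    # and the message must fit in the pre-allocated ONESHOT workspace.
    handle = init_custom_reduce(rank, world_size, buffer_ptrs, barrier_in_ptrs, barrier_out_ptrs)
    out = torch.empty_like(inp)
    custom_reduce(handle, inp, out)  # one-shot all-reduce: out = elementwise sum of inp across ranks
    custom_dispose(handle)
    return out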