sgl-kernel support tensorrt llm custom allreduce
yizhang2077 committed Dec 14, 2024
1 parent 2f9bd0f commit 69df322
Showing 8 changed files with 607 additions and 4 deletions.
2 changes: 1 addition & 1 deletion sgl-kernel/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "sgl-kernel"
version = "0.0.2.post4"
version = "0.0.3.post4"
description = "Kernel Library for SGLang"
readme = "README.md"
requires-python = ">=3.8"
26 changes: 25 additions & 1 deletion sgl-kernel/setup.py
@@ -84,7 +84,31 @@ def update_wheel_platform_tag():
},
libraries=["c10", "torch", "torch_python"],
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
-        )
+        ),
+        CUDAExtension(
+            "sgl_kernel.ops.custom_reduce_cuda",
+            [
+                "src/sgl-kernel/csrc/trt_reduce_internal.cu",
+                "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
+                "src/sgl-kernel/csrc/trt_reduce.cc",
+            ],
+            extra_compile_args={
+                "nvcc": [
+                    "-O3",
+                    "-Xcompiler",
+                    "-fPIC",
+                    "-gencode=arch=compute_75,code=sm_75",
+                    "-gencode=arch=compute_80,code=sm_80",
+                    "-gencode=arch=compute_89,code=sm_89",
+                    "-gencode=arch=compute_90,code=sm_90",
+                    "-U__CUDA_NO_HALF_OPERATORS__",
+                    "-U__CUDA_NO_HALF2_OPERATORS__",
+                ],
+                "cxx": ["-O3"],
+            },
+            libraries=["c10", "torch", "torch_python"],
+            extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
+        ),
],
cmdclass={"build_ext": BuildExtension},
install_requires=["torch"],
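A note on the new extension's build flags: the four -gencode pairs generate device code for sm_75, sm_80, sm_89 and sm_90 (Turing, Ampere, Ada, Hopper), and the -U__CUDA_NO_HALF_OPERATORS__ / -U__CUDA_NO_HALF2_OPERATORS__ undefines re-enable the native half/half2 arithmetic operators that the packed 128-bit adds in trt_reduce_internal.cu rely on (PyTorch defines those guard macros by default in extension builds). The $ORIGIN-relative rpath lets the compiled module locate the bundled torch libraries at import time.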
9 changes: 7 additions & 2 deletions sgl-kernel/src/sgl-kernel/__init__.py
@@ -1,3 +1,8 @@
-from .ops import warp_reduce
+from .ops import custom_dispose, custom_reduce, init_custom_reduce, warp_reduce

-__all__ = ["warp_reduce"]
+__all__ = [
+    "warp_reduce",
+    "init_custom_reduce",
+    "custom_dispose",
+    "custom_reduce",
+]
15 changes: 15 additions & 0 deletions sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc
@@ -0,0 +1,15 @@
#include <torch/extension.h>

using fptr_t = int64_t;
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size,
const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& barrier_in,
const std::vector<fptr_t>& barrier_out);
void dispose(fptr_t _fa);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
m.def("dispose", &dispose, "dispose custom allreduce meta");
m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
}
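For orientation, here is a minimal sketch of how these three bindings could be driven from Python. It is illustrative only: the pointer lists below are placeholders, since in real use each rank allocates peer-visible communication and barrier buffers and exchanges their device addresses across processes (e.g. via CUDA IPC) — plumbing that lives in wrapper code outside this excerpt and is surfaced as the init_custom_reduce / custom_dispose / custom_reduce exports above.

import torch
from sgl_kernel.ops import custom_reduce_cuda

rank_id, world_size = 0, 2
inp = torch.randn(4096, dtype=torch.float16, device="cuda")
out = torch.empty_like(inp)

# Placeholders: int64 device addresses (fptr_t) of the peer-visible comm and
# barrier buffers, one entry per rank, gathered from all participating ranks.
buffers, barrier_in, barrier_out = ..., ..., ...

fa = custom_reduce_cuda.init_custom_ar(rank_id, world_size, buffers, barrier_in, barrier_out)
custom_reduce_cuda.all_reduce(fa, inp, out)  # out = elementwise sum of inp across ranks
custom_reduce_cuda.dispose(fa)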
326 changes: 326 additions & 0 deletions sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu
@@ -0,0 +1,326 @@
// reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <tuple>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include "trt_reduce_internal.cuh"

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ void st_flag_release(uint32_t const& flag, uint32_t* flag_addr)
{
#if __CUDA_ARCH__ >= 700
asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
#else
__threadfence_system();
asm volatile("st.global.volatile.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr)
{
uint32_t flag;
#if __CUDA_ARCH__ >= 700
asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
#else
asm volatile("ld.global.volatile.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
#endif
return flag;
}
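// Note: st_flag_release and ld_flag_acquire form a release-acquire pair at
// system scope, so every write a GPU issued before setting its flag is visible
// to a peer that observes that flag. On pre-Volta parts, which lack the
// .release/.acquire qualifiers, the same ordering is emulated with
// __threadfence_system() plus volatile accesses.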

namespace trt_llm {
////////////////////////////////////////////////////////////////////////////////////////////////////

// Type converters that pack the element type into a 128-bit data type
//
using PackedFloat = union {
int4 packed;
float unpacked[4];
};

using PackedHalf = union {
int4 packed;
half2 unpacked[4];
};

template <typename T>
struct PackedOn16Bytes {};

template <>
struct PackedOn16Bytes<float> {
using Type = PackedFloat;
};

template <>
struct PackedOn16Bytes<half> {
using Type = PackedHalf;
};

#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
using PackedBFloat16 = union {
int4 packed;
__nv_bfloat162 unpacked[4];
};

template <>
struct PackedOn16Bytes<__nv_bfloat16> {
using Type = PackedBFloat16;
};
#endif

// add two 128b data
template <typename T>
inline __device__ int4 add128b(T &a, T &b) {
T c;
c.unpacked[0] = a.unpacked[0] + b.unpacked[0];
c.unpacked[1] = a.unpacked[1] + b.unpacked[1];
c.unpacked[2] = a.unpacked[2] + b.unpacked[2];
c.unpacked[3] = a.unpacked[3] + b.unpacked[3];
return c.packed;
}
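// Worked example of the packing: for T = half, one int4 (16 bytes) carries
// eight fp16 values as four half2 lanes, so add128b performs four half2 SIMD
// additions; for float it is four lanes, and for __nv_bfloat16 eight values
// via __nv_bfloat162 (hence the compute_80 guard above).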

__inline__ __device__ void multi_gpu_barrier(uint32_t** signals,
uint32_t const flag,
size_t const local_rank,
size_t const world_size,
int const tidx,
int const bidx)
{
// After this function, at least one block in each GPU has reached the barrier
if (tidx < world_size)
{
// we can think of signals having the shape [world_size, world_size]
// Dimension 0 is the "listening" dimension, dimension 1 is the "emitting" dimension

// Block 0 broadcasts its flag (local_rank on emitting dimension) to all receivers
size_t offset = (flag % 2) ? world_size : 0;

if (bidx == 0)
{
st_flag_release(flag, signals[tidx] + offset + local_rank);
}

// All blocks check that the corresponding block 0 on the other GPUs has set the flag
// No deadlock because block #0 is always the first block started
uint32_t* peer_barrier_d = signals[local_rank] + offset + tidx;
while (ld_flag_acquire(peer_barrier_d) != flag)
{
}
}

__syncthreads();
}
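// The (flag % 2) offset double-buffers the signal slots: consecutive barrier
// generations alternate between the two world_size-sized halves of the signal
// area, so a rank that races ahead into the next barrier cannot overwrite a
// slot that a slower peer is still spinning on for the current one.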


template <typename T, int RANKS_PER_NODE> /* COPY_INPUT = false, PUSH_MODE = false */
static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
// Suppose that two GPUs participate in the AR exchange, and we start four blocks.
// The message is partitioned into chunks as detailed below:
// message
// |-------------------|
// GPU 0 | B0 | B1 | B2 | B3 |
// GPU 1 | B0 | B1 | B2 | B3 |
//
// Here the step-by-step behavior of one block:
// 1. B0 copies the chunk it is responsible for from local_input to the shareable buffer
// 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier)
// 3. B0 on GPU 0 pulls and sums the chunk from GPU 1, then writes the result to local_output
//
// With COPY_INPUT == false, step 1. is skipped, and a gpu_barrier is used instead of the block barrier in step 2.
// We only need to know that the other GPU has arrived at the AR kernel; that means its data is ready
//
// With PUSH_MODE, we consider that the shared buffer is of size:
// params.peer_comm_buffer_ptrs: [world_size, world_size, message_size]
//
// Here the step-by-step behavior of one block:
// 1. B0 pushes the chunk it is responsible for into all other GPUs:
//    params.peer_comm_buffer_ptrs[:, local_gpu, B0 slice]
// 2. Block barrier, so the pushed chunks are visible to the other GPUs
// 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]

int const bidx = blockIdx.x;
int const tidx = threadIdx.x;

// The number of elements packed into one 128-bit value for comms
static constexpr int NUM_ELTS = 16 / sizeof(T);

// Packed data type for comms
using PackedStruct = typename PackedOn16Bytes<T>::Type;

// The source pointers. Distributed round-robin for the different warps.
T const *buffers[RANKS_PER_NODE];

// Start and end offsets of the thread
size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS;
size_t chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
int rank = (params.local_rank + ii) % RANKS_PER_NODE;
buffers[ii] = reinterpret_cast<T *>(params.peer_comm_buffer_ptrs[rank]);
}

multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);

// Each block accumulates the values from the different GPUs on the same node.
for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
// Iterate over the different ranks/devices on the node to load the values.
PackedStruct vals[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
vals[ii].packed = *reinterpret_cast<int4 const *>(&buffers[ii][iter_offset]);
}

// Sum the values from the different ranks.
PackedStruct sums;
sums.packed = {0, 0, 0, 0};
#pragma unroll
for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
// Always reduce from rank 0 to ensure stable reduce order.
int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE;
sums.packed = add128b(sums, vals[ii]);
}

// Store to the destination buffer.
*reinterpret_cast<int4 *>(&reinterpret_cast<T *>(
params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline int divUp(int a, int b) {
return (a + b - 1) / b;
}

inline int roundUp(int a, int n) {
return divUp(a, n) * n;
}

std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo,
AllReduceParams &params,
size_t elts_per_thread) {
int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE;
switch (algo) {
case AllReduceStrategyType::ONESHOT: {
assert(params.elts_total % elts_per_thread == 0);
size_t const total_threads = roundUp(params.elts_total / elts_per_thread, WARP_SIZE);
threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
blocks_per_grid = std::min(static_cast<int>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
params.elts_per_block = roundUp(divUp(params.elts_total, blocks_per_grid), elts_per_thread);
params.elts_per_rank = params.elts_total;
break;
}
default:
assert(false && "Algorithm not supported here.");
}

return std::make_tuple(blocks_per_grid, threads_per_block);
}
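// Worked example, assuming DEFAULT_BLOCK_SIZE = 512 and MAX_ALL_REDUCE_BLOCKS = 24
// (the real constants live in trt_reduce_internal.cuh, which is not shown in this
// diff): for fp16 with elts_total = 98304 and elts_per_thread = 16 / 2 = 8,
//   total_threads     = roundUp(98304 / 8, 32)       = 12288
//   threads_per_block = min(512, 12288)              = 512
//   blocks_per_grid   = min(24, divUp(12288, 512))   = 24
//   elts_per_block    = roundUp(divUp(98304, 24), 8) = 4096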

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, int RANKS_PER_NODE>
void dispatchARKernels(AllReduceStrategyType algo,
AllReduceParams &param,
int blocks_per_grid,
int threads_per_block,
cudaStream_t stream) {
oneShotAllReduceKernel<T, RANKS_PER_NODE>
<<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
}

template <typename T>
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams &param,
AllReduceStrategyType strat,
cudaStream_t stream) {
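// Stage this rank's local input into its own peer-visible comm buffer first;
// the kernel runs with COPY_INPUT = false, so after the barrier each rank
// pulls the staged chunks of its peers directly from these buffers.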

void* buffer = reinterpret_cast<void*>(param.peer_comm_buffer_ptrs[param.rank]);
void* local_inp_buffer = param.local_input_buffer_ptr;
CHECKCUDA(cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size,
cudaMemcpyDeviceToDevice, stream));

assert(strat == AllReduceStrategyType::ONESHOT && "Custom allreduce only supports oneshot");
auto last_error = cudaGetLastError();
if (last_error != cudaSuccess) {
printf("cuda error: %s\n", cudaGetErrorString(last_error));
assert(false && "Error before launching the kernel");
}

size_t elts_per_thread = 16 / sizeof(T);
auto [blocks_per_grid, threads_per_block] =
kernelLaunchConfig(strat, param, elts_per_thread);
switch (param.ranks_per_node) {
case 2:
dispatchARKernels<T, 2>(
strat, param, blocks_per_grid, threads_per_block, stream);
break;
case 4:
dispatchARKernels<T, 4>(
strat, param, blocks_per_grid, threads_per_block, stream);
break;
case 6:
dispatchARKernels<T, 6>(
strat, param, blocks_per_grid, threads_per_block, stream);
break;
case 8:
dispatchARKernels<T, 8>(
strat, param, blocks_per_grid, threads_per_block, stream);
break;
default:
break;
}
last_error = cudaGetLastError();
if (last_error != cudaSuccess) {
printf("cuda error: %s\n", cudaGetErrorString(last_error));
assert(false && "Error after launching the kernel");
}
}

void trtCustomAllReduce(AllReduceParams &params,
at::ScalarType data_type,
AllReduceStrategyType strat,
cudaStream_t stream) {
if (params.elts_total == 0) {
return;
}

switch (data_type) {
case at::ScalarType::Float:
invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);
break;
case at::ScalarType::Half:
invokeOneOrTwoShotAllReduceKernel<half>(params, strat, stream);
break;
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16:
invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream);
break;
#endif
default:
assert(false && "Unsupported data type");
}
}
}  // namespace trt_llm
