format code
yizhang2077 committed Dec 14, 2024
1 parent 9bc795d commit 3a0b36a
Showing 4 changed files with 104 additions and 150 deletions.
6 changes: 2 additions & 4 deletions sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc
@@ -1,10 +1,8 @@
#include <torch/extension.h>

using fptr_t = int64_t;
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size,
const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& barrier_in,
const std::vector<fptr_t>& barrier_out);
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out);
void dispose(fptr_t _fa);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
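The registration that exposes these three entry points to Python is collapsed below this hunk, so the following sketch is only an assumption about how a typical torch-extension binding for them could look (module name and docstrings are made up, not taken from the repository):

#include <torch/extension.h>

// Hypothetical binding sketch; the real registration is collapsed in this diff.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("init_custom_ar", &init_custom_ar,
        "create a custom all-reduce handle from peer buffers and barrier arrays");
  m.def("dispose", &dispose, "destroy a handle returned by init_custom_ar");
  m.def("all_reduce", &all_reduce, "run the custom all-reduce from inp into out");
}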

160 changes: 62 additions & 98 deletions sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu
@@ -1,4 +1,5 @@
// reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu
// reference:
// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
@@ -15,39 +16,29 @@
* limitations under the License.
*/

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <tuple>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include "trt_reduce_internal.cuh"

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ void st_flag_release(uint32_t const& flag, uint32_t* flag_addr)
{
#if __CUDA_ARCH__ >= 700
asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
#else
__threadfence_system();
asm volatile("st.global.volatile.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
#endif
static inline __device__ void st_flag_release(uint32_t const& flag, uint32_t* flag_addr) {
asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr)
{
uint32_t flag;
#if __CUDA_ARCH__ >= 700
asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
#else
asm volatile("ld.global.volatile.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
#endif
return flag;
static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr) {
uint32_t flag;
asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
return flag;
}
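The pair above is the usual cross-device publish/wait handshake: a release store makes the writer's earlier global-memory writes visible before the flag lands, and an acquire load makes those writes visible to the reader once the flag is observed. A minimal sketch of that pattern, with hypothetical helper names that are not part of this commit:

// Minimal sketch, assuming st_flag_release/ld_flag_acquire above are in scope.
static inline __device__ void publish_flag(uint32_t flag, uint32_t* peer_flag_addr) {
  // Release: all prior global-memory writes become visible before the flag does.
  st_flag_release(flag, peer_flag_addr);
}

static inline __device__ void spin_until_flag(uint32_t flag, uint32_t* local_flag_addr) {
  // Acquire: once the expected flag value is seen, the writer's data is visible too.
  while (ld_flag_acquire(local_flag_addr) != flag) {
  }
}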

namespace trt_llm {
@@ -92,7 +83,7 @@ struct PackedOn16Bytes<__nv_bfloat16> {

// add two 128b data
template <typename T>
inline __device__ int4 add128b(T &a, T &b) {
inline __device__ int4 add128b(T& a, T& b) {
T c;
c.unpacked[0] = a.unpacked[0] + b.unpacked[0];
c.unpacked[1] = a.unpacked[1] + b.unpacked[1];
@@ -101,38 +92,29 @@ inline __device__ int4 add128b(T &a, T &b) {
return c.packed;
}

__inline__ __device__ void multi_gpu_barrier(uint32_t** signals,
uint32_t const flag,
size_t const local_rank,
size_t const world_size,
int const tidx,
int const bidx)
{
// After this function, at least one block in each GPU has reached the barrier
if (tidx < world_size)
{
// we can think of signals having the shape [world_size, world_size]
// Dimension 0 is the "listening" dimension, dimension 1 is "emitting" dimension

// Block 0 broadcasts its flag (local_rank on emitting dimension) to all receivers
size_t offset = (flag % 2) ? world_size : 0;

if (bidx == 0)
{
st_flag_release(flag, signals[tidx] + offset + local_rank);
}

// All blocks check that corresponding block 0 on other GPUs have set the flag
// No deadlock because block #0 is always the first block started
uint32_t* peer_barrier_d = signals[local_rank] + offset + tidx;
while (ld_flag_acquire(peer_barrier_d) != flag)
{
}
__inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
size_t const world_size, int const tidx, int const bidx) {
// After this function, at least one block in each GPU has reached the barrier
if (tidx < world_size) {
// we can think of signals having the shape [world_size, world_size]
// Dimension 0 is the "listening" dimension, dimension 1 is "emitting" dimension

// Block 0 broadcasts its flag (local_rank on emitting dimension) to all receivers
size_t offset = (flag % 2) ? world_size : 0;

if (bidx == 0) {
st_flag_release(flag, signals[tidx] + offset + local_rank);
}

__syncthreads();
}
// All blocks check that corresponding block 0 on other GPUs have set the flag
// No deadlock because block #0 is always the first block started
uint32_t* peer_barrier_d = signals[local_rank] + offset + tidx;
while (ld_flag_acquire(peer_barrier_d) != flag) {
}
}

__syncthreads();
}
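The indexing above treats each rank's barrier buffer as a flattened two-by-world_size array: the flag's parity selects the half, the emitting rank selects the slot within it. A worked example with hypothetical values (world_size = 4, local_rank = 1, flag = 3), purely for illustration:

// Illustration only: world_size = 4, local_rank = 1, flag = 3.
//   offset = (3 % 2) ? 4 : 0                  -> 4 (odd flags use the upper half)
//   block 0, thread tidx writes flag 3 to signals[tidx] + 4 + 1,
//     i.e. rank 1 announces itself in slot 4 + 1 of every peer's barrier buffer.
//   every block, thread tidx then spins on signals[1] + 4 + tidx,
//     i.e. rank 1 waits until each peer tidx has published flag 3 into rank 1's buffer.
// Alternating halves by flag parity lets back-to-back barriers reuse the same
// buffers without an explicit reset between them.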

template <typename T, int RANKS_PER_NODE> /* COPY_INPUT = false, PUSH_MODE = false */
static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
@@ -170,15 +152,15 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
using PackedStruct = typename PackedOn16Bytes<T>::Type;

// The source pointers. Distributed round-robin for the different warps.
T const *buffers[RANKS_PER_NODE];
T const* buffers[RANKS_PER_NODE];

// Start and end offsets of the thread
size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS;
size_t chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
int rank = (params.local_rank + ii) % RANKS_PER_NODE;
buffers[ii] = reinterpret_cast<T *>(params.peer_comm_buffer_ptrs[rank]);
buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
}

multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
@@ -189,7 +171,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
PackedStruct vals[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
vals[ii].packed = *reinterpret_cast<int4 const *>(&buffers[ii][iter_offset]);
vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][iter_offset]);
}

// Sum the values from the different ranks.
@@ -203,8 +185,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
}

// Store to the destination buffer.
*reinterpret_cast<int4 *>(&reinterpret_cast<T *>(
params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
*reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
}
}
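In this one-shot scheme every rank ends up with the full reduced result: a thread owns a 16-byte slice per iteration, gathers that slice from all RANKS_PER_NODE peer buffers, sums the packed values with add128b, and writes the sum to its local output. A concrete trace of the visible index arithmetic, using hypothetical values (NUM_ELTS is defined in a collapsed part of the kernel; for 16-byte loads of half it would be 8, which is assumed here):

// Illustration only: T = half, NUM_ELTS assumed to be 16 / sizeof(half) = 8,
// params.elts_per_block = 4096, params.elts_per_rank = 1 << 20.
//   block 3, thread 5:
//     chunk_start = 3 * 4096 + 5 * 8        = 12328
//     chunk_end   = min(4 * 4096, 1 << 20)  = 16384
//   each iteration at iter_offset loads one int4 (eight halves) from every
//   rank's buffer, adds them with add128b, and stores the packed sum to
//   local_output_buffer_ptr at the same offset (the loop header itself is
//   collapsed in this hunk).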

@@ -215,12 +196,10 @@ inline int divUp(int a, int b) {
}

inline int roundUp(int a, int n) {
return divUp(a, n) * n;
return divUp(a, n) * n;
}

std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo,
AllReduceParams &params,
size_t elts_per_thread) {
std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReduceParams& params, size_t elts_per_thread) {
int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE;
switch (algo) {
case AllReduceStrategyType::ONESHOT: {
@@ -242,24 +221,17 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo,
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, int RANKS_PER_NODE>
void dispatchARKernels(AllReduceStrategyType algo,
AllReduceParams &param,
int blocks_per_grid,
int threads_per_block,
void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block,
cudaStream_t stream) {
oneShotAllReduceKernel<T, RANKS_PER_NODE>
<<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
oneShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
}

template <typename T>
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams &param,
AllReduceStrategyType strat,
cudaStream_t stream) {

void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
void* buffer = reinterpret_cast<void*>(param.peer_comm_buffer_ptrs[param.rank]);
void* local_inp_buffer = param.local_input_buffer_ptr;
CHECKCUDA(cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size,
cudaMemcpyDeviceToDevice, stream));
CHECKCUDA(
cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size, cudaMemcpyDeviceToDevice, stream));

assert(strat == AllReduceStrategyType::ONESHOT && "Custom allreduce only support oneshot");
auto last_error = cudaGetLastError();
@@ -269,24 +241,19 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams &param,
}

size_t elts_per_thread = 16 / sizeof(T);
auto [blocks_per_grid, threads_per_block] =
kernelLaunchConfig(strat, param, elts_per_thread);
auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
switch (param.ranks_per_node) {
case 2:
dispatchARKernels<T, 2>(
strat, param, blocks_per_grid, threads_per_block, stream);
dispatchARKernels<T, 2>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
case 4:
dispatchARKernels<T, 4>(
strat, param, blocks_per_grid, threads_per_block, stream);
dispatchARKernels<T, 4>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
case 6:
dispatchARKernels<T, 6>(
strat, param, blocks_per_grid, threads_per_block, stream);
dispatchARKernels<T, 6>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
case 8:
dispatchARKernels<T, 8>(
strat, param, blocks_per_grid, threads_per_block, stream);
dispatchARKernels<T, 8>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
default:
break;
@@ -298,29 +265,26 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams &param,
}
}

void trtCustomAllReduce(AllReduceParams &params,
at::ScalarType data_type,
AllReduceStrategyType strat,
void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat,
cudaStream_t stream) {
if (params.elts_total == 0) {
return;
}

switch (data_type)
{
case at::ScalarType::Float:
invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);
break;
case at::ScalarType::Half:
invokeOneOrTwoShotAllReduceKernel<half>(params, strat, stream);
break;
switch (data_type) {
case at::ScalarType::Float:
invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);
break;
case at::ScalarType::Half:
invokeOneOrTwoShotAllReduceKernel<half>(params, strat, stream);
break;
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16:
invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream);
break;
case at::ScalarType::BFloat16:
invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream);
break;
#endif
default:
assert(false && "Unsupported data type");
default:
assert(false && "Unsupported data type");
}
}
}
} // namespace trt_llm
53 changes: 25 additions & 28 deletions sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
@@ -1,4 +1,5 @@
// reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp
// reference:
// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
@@ -20,24 +21,23 @@
#include <stdint.h>
#include <torch/all.h>

#define FatalError(s) \
do { \
std::stringstream _where, _message; \
_where << __FILE__ << ':' << __LINE__; \
_message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \
std::cerr << _message.str() << "\nAborting...\n"; \
assert(false); \
exit(1); \
#define FatalError(s) \
do { \
std::stringstream _where, _message; \
_where << __FILE__ << ':' << __LINE__; \
_message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \
std::cerr << _message.str() << "\nAborting...\n"; \
assert(false); \
exit(1); \
} while (0)

#define CHECKCUDA(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
cudaGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
#define CHECKCUDA(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while (0)
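A minimal usage sketch for the two macros, with a hypothetical call site that is not taken from this commit (it assumes the CUDA runtime and the standard headers used by FatalError are already included):

// Hypothetical call site illustrating FatalError and CHECKCUDA.
inline void copy_or_die(void* dst, const void* src, size_t bytes, cudaStream_t stream) {
  if (dst == nullptr || src == nullptr) {
    // Prints the message with file/line context, asserts, then exits.
    FatalError("copy_or_die called with a null pointer");
  }
  // Aborts with the CUDA error string and file/line if the call fails.
  CHECKCUDA(cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, stream));
}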

namespace trt_llm {
@@ -61,11 +61,11 @@ struct AllReduceParams {
size_t rank_offset;
size_t ranks_per_node, rank, local_rank;
uint32_t barrier_flag;
uint32_t *peer_barrier_ptrs_in[MAX_RANKS_PER_NODE];
uint32_t *peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
void *peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE];
void *local_input_buffer_ptr;
void *local_output_buffer_ptr;
uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE];
uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE];
void* local_input_buffer_ptr;
void* local_output_buffer_ptr;
};

inline size_t GetMaxRequiredWorkspaceSize(int world_size) {
@@ -75,8 +75,7 @@ inline size_t GetMaxRequiredWorkspaceSize(int world_size) {
return 8 * 1000 * 1000;
}

inline AllReduceStrategyType SelectImplementation(size_t message_size,
int world_size) {
inline AllReduceStrategyType SelectImplementation(size_t message_size, int world_size) {
const size_t maxWorkspaceSize = GetMaxRequiredWorkspaceSize(world_size);

if (message_size > maxWorkspaceSize) {
@@ -103,9 +102,7 @@ inline AllReduceStrategyType SelectImplementation(size_t message_size,
return AllReduceStrategyType::TWOSHOT;
}

void trtCustomAllReduce(AllReduceParams &params,
at::ScalarType data_type,
AllReduceStrategyType strat,
void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat,
cudaStream_t stream);

}
} // namespace trt_llm
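Taken together, a host-side caller fills an AllReduceParams with the peer buffer and barrier pointers exchanged at initialization, derives a strategy from the message size, and launches the reduce on a stream. A condensed sketch under stated assumptions (the IPC handle exchange and params setup are not shown in this diff, and the .cu above currently asserts that only ONESHOT is supported):

// Hedged usage sketch: assumes params' pointer arrays and element counts were
// populated elsewhere (e.g. by the init path) and the message fits one-shot limits.
inline void run_custom_all_reduce(trt_llm::AllReduceParams& params, at::ScalarType dtype,
                                  cudaStream_t stream) {
  size_t message_size = params.elts_total * params.elts_size;
  trt_llm::AllReduceStrategyType strat =
      trt_llm::SelectImplementation(message_size, static_cast<int>(params.ranks_per_node));
  trt_llm::trtCustomAllReduce(params, dtype, strat, stream);
}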
