From 300d18e4f313cb01c731a38c92014f406b4aa790 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Wed, 25 Dec 2024 11:19:55 -0800 Subject: [PATCH 1/5] upd --- sgl-kernel/setup.py | 22 +++ sgl-kernel/src/sgl-kernel/__init__.py | 9 +- .../sgl-kernel/csrc/moe_align_sum_kernels.cc | 9 ++ .../sgl-kernel/csrc/moe_align_sum_kernels.cu | 143 ++++++++++++++++++ sgl-kernel/src/sgl-kernel/ops/__init__.py | 19 +++ 5 files changed, 201 insertions(+), 1 deletion(-) create mode 100644 sgl-kernel/src/sgl-kernel/csrc/moe_align_sum_kernels.cc create mode 100644 sgl-kernel/src/sgl-kernel/csrc/moe_align_sum_kernels.cu diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index d81e9da00e8..4050321fd2c 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -109,6 +109,28 @@ def update_wheel_platform_tag(): libraries=["c10", "torch", "torch_python"], extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"], ), + CUDAExtension( + "sgl_kernel.ops.moe_align_block_size", + [ + "src/sgl-kernel/csrc/moe_align_sum_kernels.cu", + ], + extra_compile_args={ + "nvcc": [ + "-O3", + "-Xcompiler", + "-fPIC", + "-gencode=arch=compute_75,code=sm_75", + "-gencode=arch=compute_80,code=sm_80", + "-gencode=arch=compute_89,code=sm_89", + "-gencode=arch=compute_90,code=sm_90", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + ], + "cxx": ["-O3"], + }, + libraries=["c10", "torch", "torch_python"], + extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"], + ), ], cmdclass={"build_ext": BuildExtension}, install_requires=["torch"], diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py index 9876d4e5ee2..b0aa791b196 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/src/sgl-kernel/__init__.py @@ -1,8 +1,15 @@ -from .ops import custom_dispose, custom_reduce, init_custom_reduce, warp_reduce +from .ops import ( + custom_dispose, + custom_reduce, + init_custom_reduce, + moe_align_block_size, + warp_reduce, +) __all__ = [ "warp_reduce", "init_custom_reduce", "custom_dispose", "custom_reduce", + "moe_align_block_size", ] diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe_align_sum_kernels.cc b/sgl-kernel/src/sgl-kernel/csrc/moe_align_sum_kernels.cc new file mode 100644 index 00000000000..6903a357648 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/moe_align_sum_kernels.cc @@ -0,0 +1,9 @@ +#include +#include + +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, + torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)"); +} \ No newline at end of file diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe_align_sum_kernels.cu b/sgl-kernel/src/sgl-kernel/csrc/moe_align_sum_kernels.cu new file mode 100644 index 00000000000..f54f8f08339 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/moe_align_sum_kernels.cu @@ -0,0 +1,143 @@ +// Adapted from https://github.com/vllm-project/vllm/blob/v0.6.5/csrc/moe/moe_align_sum_kernels.cu + +#include +#include +// #include +#include +#include + +// #include +#include + +#ifdef USE_ROCM +#include +#endif + +#ifndef USE_ROCM +#define WARP_SIZE 32 +#else +#define WARP_SIZE warpSize +#endif + +#ifndef USE_ROCM +#define DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) +#else +#define DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + 
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) +#endif + +#define CEILDIV(x, y) (((x) + (y)-1) / (y)) + +#define DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) + +#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) + +__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) { + // don't worry about overflow because num_experts is relatively small + return row * total_col + col; +} + +template +__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, + int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, + int32_t block_size, size_t numel) { + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; + + extern __shared__ int32_t shared_mem[]; + + int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (blockDim.x + 1, num_experts) + int32_t* cumsum = shared_mem + (blockDim.x + 1) * num_experts; // 1d tensor with shape (num_experts + 1) + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; + } + + /** + * In the first step we compute token_cnts[thread_index + 1][expert_index], + * which counts how many tokens in the token shard of thread_index are + * assigned to expert expert_index. + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; + } + + __syncthreads(); + + // For each expert we accumulate the token counts from the different threads. + if (threadIdx.x < num_experts) { + tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; + for (int i = 1; i <= blockDim.x; ++i) { + tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; + } + } + + __syncthreads(); + + // We accumulate the token counts of all experts in thread 0. + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i - 1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size; + } + *total_tokens_post_pad = cumsum[num_experts]; + } + + __syncthreads(); + + /** + * For each expert, each thread processes the tokens of the corresponding + * blocks and stores the corresponding expert_id for each block. + */ + if (threadIdx.x < num_experts) { + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) { + expert_ids[i / block_size] = threadIdx.x; + } + } + + /** + * Each thread processes a token shard, calculating the index of each token + * after sorting by expert number. Given the example topk_ids = + * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, + * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a + * padding value(preset in python). 
+ */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int32_t expert_id = topk_ids[i]; + /** The cumsum[expert_id] stores the starting index of the tokens that the + * expert with expert_id needs to process, and + * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens + * processed by the expert with expert_id within the current thread's token + * shard. + */ + int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; + } +} + +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, + torch::Tensor sorted_token_ids, torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad) { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + // calc needed amount of shared mem for `tokens_cnts` and `cumsum` + // tensors + const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); + const int32_t shared_mem = ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); + + // set dynamic shared mem + auto kernel = moe_align_block_size_kernel; + AT_CUDA_CHECK(DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void*)kernel, shared_mem)); + kernel<<<1, num_thread, shared_mem, stream>>>( + topk_ids.data_ptr(), sorted_token_ids.data_ptr(), experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), num_experts, block_size, topk_ids.numel()); + }); +} diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py index 7a3ceb2bd53..1ca551b6bdb 100644 --- a/sgl-kernel/src/sgl-kernel/ops/__init__.py +++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py @@ -1,6 +1,7 @@ from .custom_reduce_cuda import all_reduce as _all_reduce from .custom_reduce_cuda import dispose as _dispose from .custom_reduce_cuda import init_custom_ar as _init_custom_ar +from .moe_align_block_size import moe_align_block_size as _moe_align_block_size from .warp_reduce_cuda import reduce as _reduce @@ -18,3 +19,21 @@ def custom_dispose(fa): def custom_reduce(fa, inp, out): _all_reduce(fa, inp, out) + + +def moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_token_ids, + experts_ids, + num_tokens_post_pad, +): + _moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_token_ids, + experts_ids, + num_tokens_post_pad, + ) From b5dcc9d03a9ed49766d8f9ab2496213497ef083c Mon Sep 17 00:00:00 2001 From: ispobock Date: Wed, 25 Dec 2024 20:54:39 +0000 Subject: [PATCH 2/5] fix compile --- sgl-kernel/setup.py | 4 +- ...ign_sum_kernels.cu => moe_align_kernel.cu} | 84 ++++++++++++------- .../sgl-kernel/csrc/moe_align_sum_kernels.cc | 9 -- 3 files changed, 56 insertions(+), 41 deletions(-) rename sgl-kernel/src/sgl-kernel/csrc/{moe_align_sum_kernels.cu => moe_align_kernel.cu} (64%) delete mode 100644 sgl-kernel/src/sgl-kernel/csrc/moe_align_sum_kernels.cc diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 4050321fd2c..5b8da4b153c 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -112,7 +112,7 @@ def update_wheel_platform_tag(): CUDAExtension( "sgl_kernel.ops.moe_align_block_size", [ - "src/sgl-kernel/csrc/moe_align_sum_kernels.cu", + "src/sgl-kernel/csrc/moe_align_kernel.cu", ], extra_compile_args={ "nvcc": [ @@ -123,8 +123,6 @@ def update_wheel_platform_tag(): "-gencode=arch=compute_80,code=sm_80", 
"-gencode=arch=compute_89,code=sm_89", "-gencode=arch=compute_90,code=sm_90", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF2_OPERATORS__", ], "cxx": ["-O3"], }, diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe_align_sum_kernels.cu b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu similarity index 64% rename from sgl-kernel/src/sgl-kernel/csrc/moe_align_sum_kernels.cu rename to sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu index f54f8f08339..eae14b92ab3 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/moe_align_sum_kernels.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu @@ -2,11 +2,10 @@ #include #include -// #include #include #include +#include -// #include #include #ifdef USE_ROCM @@ -45,17 +44,17 @@ __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t } template -__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, - int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, - int32_t block_size, size_t numel) { +__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, + int32_t* sorted_token_ids, + int32_t* expert_ids, + int32_t* total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, size_t numel, + int32_t* tokens_cnts, + int32_t* cumsum) { const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); const size_t start_idx = threadIdx.x * tokens_per_thread; - extern __shared__ int32_t shared_mem[]; - - int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (blockDim.x + 1, num_experts) - int32_t* cumsum = shared_mem + (blockDim.x + 1) * num_experts; // 1d tensor with shape (num_experts + 1) - for (int i = 0; i < num_experts; ++i) { tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; } @@ -75,7 +74,8 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int if (threadIdx.x < num_experts) { tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; for (int i = 1; i <= blockDim.x; ++i) { - tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; + tokens_cnts[index(num_experts, i, threadIdx.x)] += + tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; } } @@ -85,7 +85,10 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int if (threadIdx.x == 0) { cumsum[0] = 0; for (int i = 1; i <= num_experts; ++i) { - cumsum[i] = cumsum[i - 1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size; + cumsum[i] = cumsum[i - 1] + + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], + block_size) * + block_size; } *total_tokens_post_pad = cumsum[num_experts]; } @@ -97,7 +100,8 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int * blocks and stores the corresponding expert_id for each block. */ if (threadIdx.x < num_experts) { - for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) { + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; + i += block_size) { expert_ids[i / block_size] = threadIdx.x; } } @@ -117,27 +121,49 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int * processed by the expert with expert_id within the current thread's token * shard. 
*/ - int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id]; + int32_t rank_post_pad = + tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + + cumsum[expert_id]; sorted_token_ids[rank_post_pad] = i; ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; } } -void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, - torch::Tensor sorted_token_ids, torch::Tensor experts_ids, +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - // calc needed amount of shared mem for `tokens_cnts` and `cumsum` - // tensors - const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); - const int32_t shared_mem = ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); - - // set dynamic shared mem - auto kernel = moe_align_block_size_kernel; - AT_CUDA_CHECK(DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void*)kernel, shared_mem)); - kernel<<<1, num_thread, shared_mem, stream>>>( - topk_ids.data_ptr(), sorted_token_ids.data_ptr(), experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), num_experts, block_size, topk_ids.numel()); - }); + DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + // calc needed amount of shared mem for `tokens_cnts` and `cumsum` + // tensors + const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); + + const int32_t mem_tokens_cnts = + ((num_experts + 1) * num_experts) * sizeof(int32_t); + const int32_t mem_cumsum = + (num_experts + 1) * sizeof(int32_t); + + // allocate global memory + int32_t* tokens_cnts; + int32_t* cumsum; + cudaMalloc(&tokens_cnts, mem_tokens_cnts); + cudaMalloc(&cumsum, mem_cumsum); + + // set dynamic shared mem + auto kernel = moe_align_block_size_kernel; + kernel<<<1, num_thread, 0, stream>>>( + topk_ids.data_ptr(), sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), num_experts, block_size, + topk_ids.numel(), tokens_cnts, cumsum); + + cudaFree(tokens_cnts); + cudaFree(cumsum); + }); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)"); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe_align_sum_kernels.cc b/sgl-kernel/src/sgl-kernel/csrc/moe_align_sum_kernels.cc deleted file mode 100644 index 6903a357648..00000000000 --- a/sgl-kernel/src/sgl-kernel/csrc/moe_align_sum_kernels.cc +++ /dev/null @@ -1,9 +0,0 @@ -#include -#include - -void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, - torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)"); -} \ No newline at end of file From 38a8df54abb5f259b1674c8ca5c6ad8856806332 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Wed, 25 Dec 2024 13:01:05 -0800 Subject: [PATCH 3/5] upd --- .../src/sgl-kernel/csrc/moe_align_kernel.cu | 84 ++++++++----------- 1 file changed, 33 insertions(+), 51 deletions(-) diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu index eae14b92ab3..15c6bf4710f 100644 --- 
a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu @@ -1,9 +1,9 @@ // Adapted from https://github.com/vllm-project/vllm/blob/v0.6.5/csrc/moe/moe_align_sum_kernels.cu #include -#include #include #include +#include #include #include @@ -44,14 +44,9 @@ __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t } template -__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, - int32_t* sorted_token_ids, - int32_t* expert_ids, - int32_t* total_tokens_post_pad, - int32_t num_experts, - int32_t block_size, size_t numel, - int32_t* tokens_cnts, - int32_t* cumsum) { +__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, + int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, + int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) { const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); const size_t start_idx = threadIdx.x * tokens_per_thread; @@ -74,8 +69,7 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, if (threadIdx.x < num_experts) { tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; for (int i = 1; i <= blockDim.x; ++i) { - tokens_cnts[index(num_experts, i, threadIdx.x)] += - tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; + tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; } } @@ -85,10 +79,7 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, if (threadIdx.x == 0) { cumsum[0] = 0; for (int i = 1; i <= num_experts; ++i) { - cumsum[i] = cumsum[i - 1] + - CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], - block_size) * - block_size; + cumsum[i] = cumsum[i - 1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size; } *total_tokens_post_pad = cumsum[num_experts]; } @@ -100,8 +91,7 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, * blocks and stores the corresponding expert_id for each block. */ if (threadIdx.x < num_experts) { - for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; - i += block_size) { + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) { expert_ids[i / block_size] = threadIdx.x; } } @@ -121,47 +111,39 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, * processed by the expert with expert_id within the current thread's token * shard. 
*/ - int32_t rank_post_pad = - tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + - cumsum[expert_id]; + int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id]; sorted_token_ids[rank_post_pad] = i; ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; } } -void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, - int64_t block_size, torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, + torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - DISPATCH_INTEGRAL_TYPES( - topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - // calc needed amount of shared mem for `tokens_cnts` and `cumsum` - // tensors - const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); - - const int32_t mem_tokens_cnts = - ((num_experts + 1) * num_experts) * sizeof(int32_t); - const int32_t mem_cumsum = - (num_experts + 1) * sizeof(int32_t); - - // allocate global memory - int32_t* tokens_cnts; - int32_t* cumsum; - cudaMalloc(&tokens_cnts, mem_tokens_cnts); - cudaMalloc(&cumsum, mem_cumsum); - - // set dynamic shared mem - auto kernel = moe_align_block_size_kernel; - kernel<<<1, num_thread, 0, stream>>>( - topk_ids.data_ptr(), sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), num_experts, block_size, - topk_ids.numel(), tokens_cnts, cumsum); - - cudaFree(tokens_cnts); - cudaFree(cumsum); - }); + DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + // calc needed amount of shared mem for `tokens_cnts` and `cumsum` + // tensors + const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); + + const int32_t mem_tokens_cnts = ((num_experts + 1) * num_experts) * sizeof(int32_t); + const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t); + + // allocate global memory + int32_t* tokens_cnts; + int32_t* cumsum; + cudaMalloc(&tokens_cnts, mem_tokens_cnts); + cudaMalloc(&cumsum, mem_cumsum); + + // set dynamic shared mem + auto kernel = moe_align_block_size_kernel; + kernel<<<1, num_thread, 0, stream>>>(topk_ids.data_ptr(), sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), num_tokens_post_pad.data_ptr(), + num_experts, block_size, topk_ids.numel(), tokens_cnts, cumsum); + + cudaFree(tokens_cnts); + cudaFree(cumsum); + }); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { From 68400b70ee96e4ff69dfaf2e24e4a1548a692b69 Mon Sep 17 00:00:00 2001 From: ispobock Date: Wed, 25 Dec 2024 13:13:24 -0800 Subject: [PATCH 4/5] add test --- sgl-kernel/tests/test_moe_align.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 sgl-kernel/tests/test_moe_align.py diff --git a/sgl-kernel/tests/test_moe_align.py b/sgl-kernel/tests/test_moe_align.py new file mode 100644 index 00000000000..5503cea0f3f --- /dev/null +++ b/sgl-kernel/tests/test_moe_align.py @@ -0,0 +1,26 @@ +import torch +from sgl_kernel import moe_align_block_size + + +def test_moe_align_block_size(): + num_experts = 256 + block_size = 128 + topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda") + + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + sorted_ids = torch.empty( + (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device + ) + sorted_ids.fill_(topk_ids.numel()) + max_num_m_blocks = max_num_tokens_padded // 
block_size + expert_ids = torch.empty( + (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device + ) + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) + + moe_align_block_size( + topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad + ) + + +test_moe_align_block_size() From 2031abecfd45553282af9945967eb44dcc2b68d5 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Wed, 25 Dec 2024 13:20:57 -0800 Subject: [PATCH 5/5] upd --- sgl-kernel/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index a9111119807..a93a6cc280d 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sgl-kernel" -version = "0.0.2.post5" +version = "0.0.2.post6" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.8"
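
For reviewers who want to exercise the new op end to end, the snippet below extends the test from PATCH 4/5 with a few illustrative invariant checks. It is a minimal sketch, not part of the series: the assertions at the end are assumptions about the expected layout (every token index placed exactly once, padded length rounded up to a multiple of block_size), derived from the kernel's own comments rather than from the committed test.

import torch

from sgl_kernel import moe_align_block_size

num_experts = 256
block_size = 128
topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")

# Worst case: every expert's token count gets rounded up to a full block.
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.full(
    (max_num_tokens_padded,), topk_ids.numel(), dtype=torch.int32, device="cuda"
)
expert_ids = torch.empty(
    (max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda"
)
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")

moe_align_block_size(
    topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
)

# Illustrative checks (assumptions, not asserted by the test in PATCH 4/5):
total = num_tokens_post_pad.item()
assert total % block_size == 0  # per-expert counts are padded to block_size multiples
placed = sorted_ids[sorted_ids < topk_ids.numel()]  # slots left at the fill value are padding
assert placed.numel() == topk_ids.numel()  # every token index written exactly once
expected = torch.arange(topk_ids.numel(), dtype=torch.int32, device="cuda")
assert torch.equal(placed.sort().values, expected)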