diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu
index 7c17e71a9e2..04393c8e716 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu
@@ -230,15 +230,11 @@ template <typename T>
 void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
   void* buffer = reinterpret_cast<void*>(param.peer_comm_buffer_ptrs[param.rank]);
   void* local_inp_buffer = param.local_input_buffer_ptr;
-  CHECKCUDA(
+  CHECK_CUDA_SUCCESS(
       cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size, cudaMemcpyDeviceToDevice, stream));
 
   assert(strat == AllReduceStrategyType::ONESHOT && "Custom allreduce only support oneshot");
-  auto last_error = cudaGetLastError();
-  if (last_error != cudaSuccess) {
-    printf("cuda error: %s\n", cudaGetErrorString(last_error));
-    assert(false && "Error before launching the kernel");
-  }
+  CHECK_CUDA_SUCCESS(cudaGetLastError());
 
   size_t elts_per_thread = 16 / sizeof(T);
   auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
@@ -258,11 +254,7 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategy
     default:
       break;
   }
-  last_error = cudaGetLastError();
-  if (last_error != cudaSuccess) {
-    printf("cuda error: %s\n", cudaGetErrorString(last_error));
-    assert(false && "Error after launching the kernel");
-  }
+  CHECK_CUDA_SUCCESS(cudaGetLastError());
 }
 
 void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat,
diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
index 2d2767e6794..46f196c0447 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
+++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
@@ -20,25 +20,7 @@
 #include <cuda.h>
 #include <cuda_fp16.h>
 #include <cuda_runtime.h>
-
-#define FatalError(s)                                                 \
-  do {                                                                \
-    std::stringstream _where, _message;                               \
-    _where << __FILE__ << ':' << __LINE__;                            \
-    _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \
-    std::cerr << _message.str() << "\nAborting...\n";                 \
-    assert(false);                                                    \
-    exit(1);                                                          \
-  } while (0)
-
-#define CHECKCUDA(cmd)                                                                      \
-  do {                                                                                      \
-    cudaError_t e = cmd;                                                                    \
-    if (e != cudaSuccess) {                                                                 \
-      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
-      exit(EXIT_FAILURE);                                                                   \
-    }                                                                                       \
-  } while (0)
+#include "utils.hpp"
 
 namespace trt_llm {
 constexpr size_t WARP_SIZE = 32;
diff --git a/sgl-kernel/src/sgl-kernel/csrc/utils.hpp b/sgl-kernel/src/sgl-kernel/csrc/utils.hpp
new file mode 100644
index 00000000000..eefbee8ae56
--- /dev/null
+++ b/sgl-kernel/src/sgl-kernel/csrc/utils.hpp
@@ -0,0 +1,35 @@
+#pragma once
+#include <torch/extension.h>
+#include <sstream>
+
+struct cuda_error : public std::runtime_error {
+  /**
+   * @brief Constructs a `cuda_error` object with the given `message`.
+   *
+   * @param message The error char array used to construct `cuda_error`
+   */
+  cuda_error(const char* message) : std::runtime_error(message) {}
+  /**
+   * @brief Constructs a `cuda_error` object with the given `message` string.
+   *
+   * @param message The `std::string` used to construct `cuda_error`
+   */
+  cuda_error(std::string const& message) : cuda_error{message.c_str()} {}
+};
+
+#define CHECK_CUDA_SUCCESS(cmd)                                         \
+  do {                                                                  \
+    cudaError_t e = cmd;                                                \
+    if (e != cudaSuccess) {                                             \
+      std::stringstream _message;                                       \
+      auto s = cudaGetErrorString(e);                                   \
+      _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \
+      throw cuda_error(_message.str());                                 \
+    }                                                                   \
+  } while (0)
+
+#define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_CUDA_INPUT(x) \
+  CHECK_IS_CUDA(x);         \
+  CHECK_IS_CONTIGUOUS(x)
diff --git a/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc b/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc
index efc2f0cd951..6cc3ae152ca 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc
+++ b/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc
@@ -1,15 +1,10 @@
 #include <torch/extension.h>
+#include "utils.hpp"
 
 torch::Tensor warp_reduce_cuda(torch::Tensor input);
 
-#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT(x) \
-  CHECK_CUDA(x);       \
-  CHECK_CONTIGUOUS(x)
-
 torch::Tensor warp_reduce(torch::Tensor input) {
-  CHECK_INPUT(input);
+  CHECK_CUDA_INPUT(input);
   return warp_reduce_cuda(input);
 }
 