Skip to content

Commit

Permalink
Add utils.hpp for basic checks
Browse files Browse the repository at this point in the history
  • Loading branch information
yizhang2077 committed Dec 14, 2024
1 parent 3a0b36a commit ce283ff
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 37 deletions.
14 changes: 3 additions & 11 deletions sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu
Original file line number Diff line number Diff line change
Expand Up @@ -230,15 +230,11 @@ template <typename T>
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
void* buffer = reinterpret_cast<void*>(param.peer_comm_buffer_ptrs[param.rank]);
void* local_inp_buffer = param.local_input_buffer_ptr;
CHECKCUDA(
CHECK_CUDA_SUCCESS(
cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size, cudaMemcpyDeviceToDevice, stream));

assert(strat == AllReduceStrategyType::ONESHOT && "Custom allreduce only support oneshot");
auto last_error = cudaGetLastError();
if (last_error != cudaSuccess) {
printf("cuda error: %s\n", cudaGetErrorString(last_error));
assert(false && "Error before launching the kernel");
}
CHECK_CUDA_SUCCESS(cudaGetLastError());

size_t elts_per_thread = 16 / sizeof(T);
auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
Expand All @@ -258,11 +254,7 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategy
default:
break;
}
last_error = cudaGetLastError();
if (last_error != cudaSuccess) {
printf("cuda error: %s\n", cudaGetErrorString(last_error));
assert(false && "Error after launching the kernel");
}
CHECK_CUDA_SUCCESS(cudaGetLastError());
}

void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat,
Expand Down
20 changes: 1 addition & 19 deletions sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,7 @@
#include <cuda_fp16.h>
#include <stdint.h>
#include <torch/all.h>

#define FatalError(s) \
do { \
std::stringstream _where, _message; \
_where << __FILE__ << ':' << __LINE__; \
_message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \
std::cerr << _message.str() << "\nAborting...\n"; \
assert(false); \
exit(1); \
} while (0)

#define CHECKCUDA(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while (0)
#include "utils.hpp"

namespace trt_llm {
constexpr size_t WARP_SIZE = 32;
Expand Down
35 changes: 35 additions & 0 deletions sgl-kernel/src/sgl-kernel/csrc/utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once
#include <sstream>
#include <torch/extension.h>

/**
 * Exception type thrown when a CUDA runtime call fails.
 *
 * Derives from std::runtime_error so callers may catch it either as
 * `cuda_error` specifically or as a generic `std::runtime_error`.
 */
struct cuda_error : public std::runtime_error {
  /**
   * @brief Constructs a `cuda_error` from a `std::string` message.
   *
   * @param message The `std::string` error message
   */
  cuda_error(std::string const& message) : std::runtime_error(message) {}
  /**
   * @brief Constructs a `cuda_error` from a C-string message.
   *
   * @param message The null-terminated error message
   */
  cuda_error(const char* message) : std::runtime_error(message) {}
};

// Evaluates a CUDA runtime call and throws `cuda_error` when it does not
// return cudaSuccess. The message carries the CUDA error string followed by
// the failing file:line. Wrapped in do { } while (0) so the macro expands to
// a single statement and composes safely with if/else.
//
// Fixes over the previous version:
//  - `cmd` is parenthesized at the point of evaluation, so compound
//    expressions passed as the argument are evaluated as a unit;
//  - the status variable uses a collision-resistant name instead of `e`,
//    which could silently shadow (or be shadowed by) a caller's variable,
//    e.g. a `catch (... e)` binding at the expansion site.
#define CHECK_CUDA_SUCCESS(cmd)                  \
  do {                                           \
    cudaError_t status_ = (cmd);                 \
    if (status_ != cudaSuccess) {                \
      std::stringstream _message;                \
      _message << cudaGetErrorString(status_)    \
               << "\n" << __FILE__ << ':' << __LINE__; \
      throw cuda_error(_message.str());          \
    }                                            \
  } while (0)

// Torch tensor argument validation helpers.
// Fails via TORCH_CHECK (throws c10::Error) unless tensor `x` lives on a CUDA device.
#define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
// Fails via TORCH_CHECK unless tensor `x` has contiguous storage.
#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// Combined device + contiguity precondition check for kernel input tensors.
#define CHECK_CUDA_INPUT(x) \
  CHECK_IS_CUDA(x); \
  CHECK_IS_CONTIGUOUS(x)
9 changes: 2 additions & 7 deletions sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
#include <torch/extension.h>
#include "utils.hpp"

torch::Tensor warp_reduce_cuda(torch::Tensor input);

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)

// Python-facing entry point: validates that `input` is a contiguous CUDA
// tensor, then dispatches to the CUDA implementation.
torch::Tensor warp_reduce(torch::Tensor input) {
  CHECK_CUDA_INPUT(input);
  auto result = warp_reduce_cuda(input);
  return result;
}

Expand Down

0 comments on commit ce283ff

Please sign in to comment.