format: add clang-format for sgl-kernel (#2483)
zhyncs authored Dec 14, 2024
1 parent 2f9bd0f commit fccbfa3
Showing 4 changed files with 28 additions and 27 deletions.
8 changes: 8 additions & 0 deletions sgl-kernel/.clang-format
@@ -0,0 +1,8 @@
+BasedOnStyle: Google
+IndentWidth: 2
+ColumnLimit: 120
+AllowShortFunctionsOnASingleLine: Empty
+DerivePointerAlignment: false
+PointerAlignment: Left
+NamespaceIndentation: None
+SortIncludes: true
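For illustration only (not part of the commit), a hypothetical C++ snippet formatted under these settings would look roughly like this: Google style as the base, two-space indents, left-bound pointers, no indentation inside namespaces, only empty function bodies kept on one line, and each #include block sorted alphabetically.

// Hypothetical snippet showing the effect of the .clang-format above.
namespace sgl {  // NamespaceIndentation: None -- members stay at column 0

class Buffer {
 public:
  Buffer() {}  // AllowShortFunctionsOnASingleLine: Empty -- only empty bodies stay on one line

  int size() const {
    return size_;  // non-empty bodies are always broken onto their own lines
  }

 private:
  // PointerAlignment: Left with DerivePointerAlignment: false -- "float* data_", not "float *data_"
  float* data_ = nullptr;
  int size_ = 0;  // IndentWidth: 2; lines may run up to ColumnLimit: 120
};

}  // namespace sgl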
5 changes: 4 additions & 1 deletion sgl-kernel/Makefile
@@ -1,4 +1,4 @@
-.PHONY: tree ln install build clean test
+.PHONY: tree ln install build clean test format

tree:
@tree --prune -I "__pycache__|*.egg-info|*.so|build"
@@ -17,3 +17,6 @@ clean:

test:
@pytest tests/

+format:
+@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black
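In practice, running make format from the sgl-kernel directory applies clang-format in place (picking up the .clang-format added above) to every .cc/.cu/.cuh/.h file under src/ and tests/, then runs isort followed by black over the Python files in the same trees.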
10 changes: 4 additions & 6 deletions sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc
@@ -2,12 +2,10 @@

torch::Tensor warp_reduce_cuda(torch::Tensor input);

-#define CHECK_CUDA(x) \
-TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) \
-TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT(x) \
-CHECK_CUDA(x); \
+#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) \
+CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)

torch::Tensor warp_reduce(torch::Tensor input) {
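The body of warp_reduce is collapsed in this view. The macros above follow the usual PyTorch C++ extension pattern of validating the input before dispatching to the CUDA implementation; a sketch of how such a wrapper typically reads (an assumption, not the verbatim body in this file):

// Hypothetical wrapper body illustrating the CHECK_* macros; the real code may differ.
torch::Tensor warp_reduce(torch::Tensor input) {
  CHECK_INPUT(input);              // CHECK_CUDA(input); CHECK_CONTIGUOUS(input)
  return warp_reduce_cuda(input);  // hand off to the CUDA launcher declared above
}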
32 changes: 12 additions & 20 deletions sgl-kernel/src/sgl-kernel/csrc/warp_reduce_kernel.cu
@@ -25,34 +25,28 @@ __device__ __forceinline__ scalar_t blockReduceSum(scalar_t val) {
int lane = threadIdx.x % 32;
int wid = threadIdx.x / 32;

-val = warpReduceSum(val); // First reduce within warp
+val = warpReduceSum(val);  // First reduce within warp

-if (lane == 0)
-shared[wid] = val; // Write reduced value to shared memory
+if (lane == 0) shared[wid] = val;  // Write reduced value to shared memory

-__syncthreads(); // Wait for all partial reductions
+__syncthreads();  // Wait for all partial reductions

// Read from shared memory only if that warp existed
val = (threadIdx.x < (blockDim.x / 32)) ? shared[lane] : 0;

-if (wid == 0)
-val = warpReduceSum(val); // Final reduce within first warp
+if (wid == 0) val = warpReduceSum(val);  // Final reduce within first warp

return val;
}
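blockReduceSum relies on a warpReduceSum helper whose definition sits above this hunk and is not shown. A conventional shuffle-based version, offered only as a sketch of what that helper usually looks like, not the verbatim source:

// Typical warp-level sum via register shuffles; assumed shape of the elided helper.
template <typename scalar_t>
__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
  // Fold in the value held by the lane `offset` positions away, halving the stride each step.
  for (int offset = 16; offset > 0; offset /= 2) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  return val;
}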

template <typename scalar_t>
__global__ void warp_reduce_cuda_kernel(
-const torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits>
-input,
-torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> output,
-int N) {
-
+const torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> input,
+torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> output, int N) {
scalar_t sum = 0;

// Grid-stride loop
-for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
-i += blockDim.x * gridDim.x) {
+for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
sum += input[i];
}
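The tail of the kernel is collapsed here. Given that warp_reduce_cuda allocates one output slot per block and finishes with output.sum(), the elided part presumably reduces each block's per-thread sums and stores one partial result per block, along the lines of this assumed sketch:

// Assumed kernel tail: combine per-thread sums within the block,
// then let thread 0 publish the block's partial sum.
sum = blockReduceSum(sum);
if (threadIdx.x == 0) {
  output[blockIdx.x] = sum;
}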

@@ -84,13 +78,11 @@ torch::Tensor warp_reduce_cuda(torch::Tensor input) {
// Allocate output tensor for partial sums
auto output = torch::empty({blocks}, input.options());

-AT_DISPATCH_FLOATING_TYPES(
-input.scalar_type(), "warp_reduce_cuda", ([&] {
-warp_reduce_cuda_kernel<scalar_t><<<blocks, threads>>>(
-input.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
-output.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
-N);
-}));
+AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "warp_reduce_cuda", ([&] {
+warp_reduce_cuda_kernel<scalar_t><<<blocks, threads>>>(
+input.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
+output.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(), N);
+}));

// Sum the partial results
return output.sum();
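AT_DISPATCH_FLOATING_TYPES instantiates the lambda once per floating-point dtype (float and double), binding scalar_t to the tensor's actual element type before launching the templated kernel; the final output.sum() then reduces the per-block partial sums into a single scalar.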
