Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

format: add clang-format for sgl-kernel #2483

Merged
merged 2 commits on Dec 14, 2024 (source and target branch names were not captured in this page snapshot)
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
8 changes: 8 additions & 0 deletions sgl-kernel/.clang-format
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
BasedOnStyle: Google
IndentWidth: 2
ColumnLimit: 120
AllowShortFunctionsOnASingleLine: Empty
DerivePointerAlignment: false
PointerAlignment: Left
NamespaceIndentation: None
SortIncludes: true
5 changes: 4 additions & 1 deletion sgl-kernel/Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.PHONY: tree ln install build clean test
.PHONY: tree ln install build clean test format

tree:
@tree --prune -I "__pycache__|*.egg-info|*.so|build"
Expand All @@ -17,3 +17,6 @@ clean:

test:
@pytest tests/

format:
@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black
10 changes: 4 additions & 6 deletions sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@

torch::Tensor warp_reduce_cuda(torch::Tensor input);

#define CHECK_CUDA(x) \
TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)

torch::Tensor warp_reduce(torch::Tensor input) {
Expand Down
32 changes: 12 additions & 20 deletions sgl-kernel/src/sgl-kernel/csrc/warp_reduce_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,34 +25,28 @@ __device__ __forceinline__ scalar_t blockReduceSum(scalar_t val) {
int lane = threadIdx.x % 32;
int wid = threadIdx.x / 32;

val = warpReduceSum(val); // First reduce within warp
val = warpReduceSum(val); // First reduce within warp

if (lane == 0)
shared[wid] = val; // Write reduced value to shared memory
if (lane == 0) shared[wid] = val; // Write reduced value to shared memory

__syncthreads(); // Wait for all partial reductions
__syncthreads(); // Wait for all partial reductions

// Read from shared memory only if that warp existed
val = (threadIdx.x < (blockDim.x / 32)) ? shared[lane] : 0;

if (wid == 0)
val = warpReduceSum(val); // Final reduce within first warp
if (wid == 0) val = warpReduceSum(val); // Final reduce within first warp

return val;
}

template <typename scalar_t>
__global__ void warp_reduce_cuda_kernel(
const torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits>
input,
torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> output,
int N) {

const torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> input,
torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> output, int N) {
scalar_t sum = 0;

// Grid-stride loop
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
sum += input[i];
}

Expand Down Expand Up @@ -84,13 +78,11 @@ torch::Tensor warp_reduce_cuda(torch::Tensor input) {
// Allocate output tensor for partial sums
auto output = torch::empty({blocks}, input.options());

AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "warp_reduce_cuda", ([&] {
warp_reduce_cuda_kernel<scalar_t><<<blocks, threads>>>(
input.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
output.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
N);
}));
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "warp_reduce_cuda", ([&] {
warp_reduce_cuda_kernel<scalar_t><<<blocks, threads>>>(
input.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
output.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(), N);
}));

// Sum the partial results
return output.sum();
Expand Down
Loading