Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
[Kernel] Fixup for CUTLASS kernels in CUDA graphs (vllm-project#4954)
Browse files Browse the repository at this point in the history
Pass the CUDA stream into the CUTLASS GEMMs, to avoid future issues with CUDA graphs
  • Loading branch information
tlrmchlsmth authored and robertgshaw2-redhat committed Jul 14, 2024
1 parent 09ba2c0 commit bb60970
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 2 deletions.
6 changes: 5 additions & 1 deletion csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include <stddef.h>
#include <torch/extension.h>

#include <ATen/cuda/CUDAContext.h>

// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
Expand Down Expand Up @@ -189,8 +191,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
size_t workspace_size = gemm_op.get_workspace_size(args);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

CUTLASS_CHECK(gemm_op.can_implement(args));
cutlass::Status status = gemm_op(args, workspace.get());
cutlass::Status status = gemm_op(args, workspace.get(), stream);
CUTLASS_CHECK(status);
}

Expand Down
5 changes: 4 additions & 1 deletion csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include <torch/extension.h>

#include <ATen/cuda/CUDAContext.h>

#include <iostream>
#include <sstream>
#include <vector>
Expand Down Expand Up @@ -178,7 +180,8 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
size_t workspace_size = gemm_op.get_workspace_size(args);
TORCH_CHECK(workspace_size == 0);

cutlass::Status status = gemm_op.run(args);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
cutlass::Status status = gemm_op.run(args, stream);
CUTLASS_CHECK(status);
}
} // namespace
Expand Down
41 changes: 41 additions & 0 deletions tests/kernels/test_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,44 @@ def test_cutlass_subset():
b.to(dtype=torch.float32)).to(dtype=torch.bfloat16)

assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)


# Test that the CUTLASS scaled-mm kernel works when captured in and replayed from a CUDA graph
class CutlassLayer(torch.nn.Module):
    """Minimal module whose forward pass is one ``ops.cutlass_scaled_mm_dq`` call.

    The right-hand operand, both scale values and the output dtype are fixed
    at construction time; only the left-hand operand is supplied per call.
    """

    def __init__(self, b, scale_a, scale_b, out_dtype):
        """Store the fixed arguments that every forward call reuses."""
        super().__init__()
        self.b = b
        self.scale_a = scale_a
        self.scale_b = scale_b
        self.out_dtype = out_dtype

    def forward(self, a):
        """Return ``ops.cutlass_scaled_mm_dq`` applied to ``a`` and the stored arguments."""
        result = ops.cutlass_scaled_mm_dq(
            a,
            self.b,
            self.scale_a,
            self.scale_b,
            self.out_dtype,
        )
        return result


def test_cutlass_cuda_graph():
    """Capture a CutlassLayer forward pass in a CUDA graph and check the replayed output.

    The layer is captured inside a non-default stream context. The output
    buffer is then zeroed before the graph is replayed, so a passing
    comparison shows that the replay itself wrote the result.
    """
    m, n, k = 512, 512, 512

    # Random int8 operands; b is created as (n, k) and then transposed.
    a = to_int8(torch.randn((m, k), device="cuda"))
    b = to_int8(torch.randn((n, k), device="cuda").t())

    scale_a = torch.randn((m, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10

    # Construct a trivial model with a single layer that calls a CUTLASS kernel
    model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16)

    # Record the forward pass into a CUDA graph from within a side stream context.
    side_stream = torch.cuda.Stream()
    with torch.cuda.stream(side_stream):
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            out = model(a)

    # Clear the output buffer, then replay the graph to repopulate it.
    out.zero_()
    graph.replay()

    # Reference: scale both operands in float32, multiply, cast to bfloat16.
    scaled_a = scale_a * a.to(dtype=torch.float32)
    scaled_b = scale_b * b.to(dtype=torch.float32)
    baseline = torch.mm(scaled_a, scaled_b).to(torch.bfloat16)
    assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)

0 comments on commit bb60970

Please sign in to comment.