misc: attention kernel refactoring (#339)
guocuimi authored Oct 23, 2024
1 parent 0876f8a commit a7abc22
Showing 8 changed files with 971 additions and 615 deletions.
923 changes: 525 additions & 398 deletions src/kernels/attention/flash_infer/attention_kernel.h

Large diffs are not rendered by default.

39 changes: 16 additions & 23 deletions src/kernels/attention/flash_infer/attention_wrapper.cu
@@ -74,27 +74,18 @@ cudaError_t mha_varlen_wrapper_dispatch(BatchPrefillHandler* handler,
float sm_scale,
float* alibi_slopes,
cudaStream_t stream) {
- DTypeOut* tmp_v = nullptr;
- float* tmp_s = nullptr;
- IdType *request_indices = nullptr, *qo_tile_indices = nullptr,
-     *kv_tile_indices = nullptr, *o_indptr = nullptr,
-     *merge_indptr = nullptr, *kv_chunk_size_ptr = nullptr;
- bool* block_valid_mask = nullptr;
- WarpLayout warp_layout;
- uint32_t padded_batch_size = 0U;
- uint32_t total_num_rows = 0U;
- tmp_v = handler->GetTempV<DTypeOut>();
- tmp_s = handler->GetTempS();
- request_indices = handler->GetRequestIndices<IdType>();
- qo_tile_indices = handler->GetQOTileIndices<IdType>();
- kv_tile_indices = handler->GetKVTileIndices<IdType>();
- block_valid_mask = handler->GetBlockValidMask();
- o_indptr = handler->GetOIndptr<IdType>();
- merge_indptr = handler->GetMergeIndptr<IdType>();
- kv_chunk_size_ptr = handler->GetKVChunkSizePtr<IdType>();
- warp_layout = handler->GetWarpLayout();
- padded_batch_size = handler->GetPaddedBatchSize();
- total_num_rows = handler->GetTotalNumRows();
+ DTypeOut* tmp_v = handler->GetTempV<DTypeOut>();
+ float* tmp_s = handler->GetTempS();
+ IdType* request_indices = handler->GetRequestIndices<IdType>();
+ IdType* qo_tile_indices = handler->GetQOTileIndices<IdType>();
+ IdType* kv_tile_indices = handler->GetKVTileIndices<IdType>();
+ bool* block_valid_mask = handler->GetBlockValidMask();
+ IdType* o_indptr = handler->GetOIndptr<IdType>();
+ IdType* merge_indptr = handler->GetMergeIndptr<IdType>();
+ IdType* kv_chunk_size_ptr = handler->GetKVChunkSizePtr<IdType>();
+ WarpLayout warp_layout = handler->GetWarpLayout();
+ uint32_t padded_batch_size = handler->GetPaddedBatchSize();
+ uint32_t total_num_rows = handler->GetTotalNumRows();

DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, {
return mha_varlen_dispatch<WARP_LAYOUT,
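
For context on the call above: DISPATCH_WARP_LAYOUT is a dispatch macro that turns the runtime warp_layout value into a compile-time constant so the templated kernel entry point can be instantiated per layout. A minimal sketch of that pattern follows; the enum values and macro body are illustrative assumptions, not the actual FlashInfer definition.

// Illustrative sketch of a runtime-to-compile-time dispatch macro.
// The real DISPATCH_WARP_LAYOUT and WarpLayout live in the FlashInfer
// headers; the enum values below are placeholders.
enum class WarpLayout { k4x1x2, k4x1x1, k1x4x1 };

#define DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, ...)       \
  switch (warp_layout) {                                           \
    case WarpLayout::k4x1x2: {                                     \
      constexpr WarpLayout WARP_LAYOUT = WarpLayout::k4x1x2;       \
      __VA_ARGS__                                                   \
      break;                                                        \
    }                                                               \
    case WarpLayout::k4x1x1: {                                     \
      constexpr WarpLayout WARP_LAYOUT = WarpLayout::k4x1x1;       \
      __VA_ARGS__                                                   \
      break;                                                        \
    }                                                               \
    case WarpLayout::k1x4x1: {                                     \
      constexpr WarpLayout WARP_LAYOUT = WarpLayout::k1x4x1;       \
      __VA_ARGS__                                                   \
      break;                                                        \
    }                                                               \
  }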
@@ -145,7 +136,8 @@ void BatchPrefillWrapper::Plan(torch::Tensor float_workspace_buffer,
unsigned int num_kv_heads,
unsigned int head_dim,
unsigned int page_size,
- torch::Tensor empty_q_data) {
+ torch::Tensor empty_q_data,
+ int32_t num_sm) {
CHECK_INPUT(float_workspace_buffer);
CHECK_INPUT(int_workspace_buffer);
// NOTE(Zihao): not necessary to be a CUDA tensor
@@ -182,7 +174,8 @@ void BatchPrefillWrapper::Plan(torch::Tensor float_workspace_buffer,
num_qo_heads,
num_kv_heads,
head_dim,
- page_size);
+ page_size,
+ num_sm);
TORCH_CHECK(status == cudaSuccess,
"BatchPrefillWithPagedKVCache failed with error ",
cudaGetErrorString(status));
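The new num_sm parameter threaded through Plan() above lets the caller tell the planner how many streaming multiprocessors to schedule for. A plausible way for a caller to obtain that value with the CUDA runtime API is sketched below; the helper name is made up and the rest of the Plan() argument list is elided.

#include <cuda_runtime.h>

// Sketch: query the SM count of the current device so it can be forwarded
// as the num_sm argument of BatchPrefillWrapper::Plan(). The Plan() call
// itself is elided; only the query is shown.
int QueryNumSm() {
  int device = 0;
  cudaGetDevice(&device);
  int num_sm = 0;
  cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, device);
  return num_sm;
}

// Usage (other Plan() arguments omitted):
//   wrapper.Plan(..., empty_q_data, /*num_sm=*/QueryNumSm());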
9 changes: 4 additions & 5 deletions src/kernels/attention/flash_infer/attention_wrapper.h
@@ -3,16 +3,14 @@

#include <torch/torch.h>

- #include <flashinfer/attention/warp_layout.cuh>
-
#include "handler.h"

namespace flashinfer {

class BatchPrefillWrapper {
public:
BatchPrefillWrapper(bool enable_cuda_graph)
- : handler_(std::make_shared<flashinfer::BatchPrefillHandler>(
+ : handler_(std::make_unique<flashinfer::BatchPrefillHandler>(
enable_cuda_graph)) {}

void Plan(torch::Tensor float_workspace_buffer,
@@ -24,7 +22,8 @@ class BatchPrefillWrapper {
unsigned int num_kv_heads,
unsigned int head_dim,
unsigned page_size,
- torch::Tensor empty_q_data);
+ torch::Tensor empty_q_data,
+ int32_t num_sm);

bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }

@@ -43,7 +42,7 @@ class BatchPrefillWrapper {
std::optional<torch::Tensor> alibi_slopes);

private:
- std::shared_ptr<flashinfer::BatchPrefillHandler> handler_;
+ std::unique_ptr<flashinfer::BatchPrefillHandler> handler_;
};

} // namespace flashinfer
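
The last hunk swaps the handler from std::shared_ptr to std::unique_ptr, which matches the wrapper being the handler's sole owner. A minimal sketch of the ownership pattern, with simplified stand-in types rather than the actual flashinfer classes:

#include <memory>

// Simplified stand-ins for the real flashinfer types; only the ownership
// relationship is illustrated.
struct Handler {
  explicit Handler(bool enable_cuda_graph) : cuda_graph_(enable_cuda_graph) {}
  bool cuda_graph_;
};

class Wrapper {
 public:
  explicit Wrapper(bool enable_cuda_graph)
      : handler_(std::make_unique<Handler>(enable_cuda_graph)) {}

 private:
  // Exclusive ownership: no reference counting, and the wrapper is
  // implicitly non-copyable because unique_ptr is move-only.
  std::unique_ptr<Handler> handler_;
};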
