[Submodule] Change FlashInfer to import (sgl-project#156)
comaniac authored Feb 7, 2024
1 parent cb8e198 commit 26c3494
Showing 5 changed files with 17 additions and 24 deletions.
3 changes: 0 additions & 3 deletions .gitmodules
@@ -1,3 +0,0 @@
[submodule "3rdparty/flashinfer"]
path = 3rdparty/flashinfer
url = https://github.com/flashinfer-ai/flashinfer.git
1 change: 0 additions & 1 deletion 3rdparty/flashinfer
Submodule flashinfer deleted from 88b949
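
Since this commit drops the `3rdparty/flashinfer` submodule, an existing checkout may keep a stale copy of it on disk after pulling. A minimal cleanup sketch (not part of this commit; the paths mirror the removed submodule, and both removals are no-ops if the directories are already gone):

```bash
# Remove the leftover submodule working tree and its cached git metadata.
rm -rf 3rdparty/flashinfer
rm -rf .git/modules/3rdparty/flashinfer
```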
8 changes: 5 additions & 3 deletions docs/flashinfer.md
@@ -5,13 +5,15 @@ It can be used in SGLang runtime to accelerate attention computation.

### Install flashinfer

Note: The compilation can take a very long time.
You can install flashinfer via pip as follows for CUDA 12.1.

```bash
git submodule update --init --recursive
pip install 3rdparty/flashinfer/python
pip install flashinfer -i https://flashinfer.ai/whl/cu121/
```

You can look for other CUDA versions at https://github.com/flashinfer-ai/flashinfer?tab=readme-ov-file#installation. If there is no suitable version for your environment,
please build it from source (the compilation takes a long time).
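
For reference, a hedged sketch of a source build, mirroring the submodule-based install that this commit removes (the FlashInfer README is the authoritative guide; the clone location is arbitrary):

```bash
# Clone FlashInfer with its own submodules and install the Python package from source.
# Compiling the CUDA kernels can take a long time.
git clone --recursive https://github.com/flashinfer-ai/flashinfer.git
pip install flashinfer/python
```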

### Run a Server With Flashinfer Mode

Add `--model-mode flashinfer` argument to enable flashinfer when launching a server.
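
For example, a minimal launch sketch, assuming the usual `sglang.launch_server` entry point and a placeholder model path:

```bash
# Start an SGLang server with FlashInfer-backed attention.
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf \
    --port 30000 --model-mode flashinfer
```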
8 changes: 0 additions & 8 deletions python/sglang/srt/layers/radix_attention.py
@@ -98,12 +98,7 @@ def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):

o = input_metadata.prefill_wrapper.forward(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.qo_indptr,
input_metadata.token_to_kv_pool.kv_data[self.layer_id],
input_metadata.kv_indptr,
input_metadata.kv_indices,
input_metadata.kv_last_page_len,
allow_fp16_qk_reduction=True,
)

return o.view(-1, self.tp_q_head_num * self.head_dim)
@@ -114,9 +109,6 @@ def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
o = input_metadata.decode_wrapper.forward(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.kv_data[self.layer_id],
input_metadata.kv_indptr,
input_metadata.kv_indices,
input_metadata.kv_last_page_len,
)

return o.view(-1, self.tp_q_head_num * self.head_dim)
21 changes: 12 additions & 9 deletions python/sglang/srt/managers/router/model_runner.py
@@ -90,6 +90,11 @@ class InputMetadata:
decode_wrapper = None

def init_flashinfer_args(self, tp_size):
from flashinfer import (
BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
)

self.kv_indptr = torch.zeros(
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
@@ -107,11 +112,7 @@ def init_flashinfer_args(self, tp_size):
(self.batch_size,), dtype=torch.int32, device="cuda"
)

from flashinfer.ops import (
BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
)

workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
if (
self.forward_mode == ForwardMode.PREFILL
or self.forward_mode == ForwardMode.EXTEND
@@ -120,19 +121,21 @@ def init_flashinfer_args(self, tp_size):
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper()
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
self.prefill_wrapper.begin_forward(
self.qo_indptr,
self.batch_size,
self.kv_indptr,
self.kv_indices,
self.kv_last_page_len,
self.model_runner.model_config.num_attention_heads // tp_size,
self.model_runner.model_config.num_key_value_heads // tp_size,
)
else:
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper()
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
self.decode_wrapper.begin_forward(
self.kv_indptr,
self.kv_indices,
self.kv_last_page_len,
self.batch_size,
self.model_runner.model_config.num_attention_heads // tp_size,
self.model_runner.model_config.num_key_value_heads // tp_size,
self.model_runner.model_config.head_dim,
