[WIP] Deepseek V2 MLA #10927

Draft
wants to merge 16 commits into main
11 changes: 9 additions & 2 deletions examples/offline_inference.py
@@ -11,12 +11,19 @@
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat",
trust_remote_code=True,
max_model_len=16384,
# dtype="float16",
enforce_eager=True,
# max_num_seqs=1,
# block_size=128,
)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
215 changes: 142 additions & 73 deletions vllm/attention/backends/flashinfer.py
@@ -6,16 +6,17 @@
from vllm.multimodal import MultiModalPlaceholderMap

try:
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer import BatchDecodeWithPagedKVCacheWrapper, BatchDecodeMlaWithPagedKVCacheWrapper
from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper

from vllm.vllm_flash_attn import flash_attn_varlen_func
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
BatchDecodeWithPagedKVCacheWrapper = None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None
BatchDecodeMlaWithPagedKVCacheWrapper = None
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0

import torch
@@ -67,7 +68,9 @@
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
# IDEA(simon): We probably should create a new backend for MLA something like FLASHINFER_MLA.
return (num_blocks, 2, block_size, num_kv_heads, head_size)
# return (num_blocks, 1, block_size, num_kv_heads, head_size)

@staticmethod
def swap_blocks(
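A rough sketch of what the new cache shape means for MLA, using DeepSeek-V2 numbers that are assumed from the assertions later in this file (num_kv_heads == 1, head_size == 512, rope width 64) rather than stated here; the block count and block size are placeholders.

import torch

num_blocks, block_size = 1024, 16       # placeholder values
num_kv_heads, head_size = 1, 512        # single latent KV head of 512 dims

# (num_blocks, 2, block_size, num_kv_heads, head_size)
kv_cache = torch.zeros(num_blocks, 2, block_size, num_kv_heads, head_size)

ckv_cache = kv_cache[:, 0]   # compressed-KV latent, all 512 dims used
kpe_cache = kv_cache[:, 1]   # rope key: first 64 dims are real, the rest
                             # is padding up to head_size (see forward())
print(ckv_cache.shape, kpe_cache.shape)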
@@ -86,7 +89,7 @@

@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 128, 256]
return [256, 512] # [64, 128, 256, 512]

@staticmethod
def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
@@ -117,8 +120,9 @@

def _get_prefill_wrapper(self):
if self._prefill_wrapper is None:
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
self._get_workspace_buffer(), "NHD")
# self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
self._prefill_wrapper = BatchDecodeMlaWithPagedKVCacheWrapper(
self._get_workspace_buffer(), )
return self._prefill_wrapper

def _get_decode_wrapper(self):
@@ -127,12 +131,13 @@
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
num_qo_heads // num_kv_heads > 4)
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
"NHD",
use_tensor_cores=use_tensor_cores)
# self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._decode_wrapper = BatchDecodeMlaWithPagedKVCacheWrapper(
self._get_workspace_buffer())
# "NHD",
# use_tensor_cores=use_tensor_cores)
return self._decode_wrapper

@contextmanager
@@ -187,13 +192,20 @@
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
num_qo_heads // num_kv_heads > 4)
self._graph_decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
self._graph_decode_workspace_buffer, _indptr_buffer,
self._graph_indices_buffer, _last_page_len_buffer, "NHD",
use_tensor_cores)
assert torch.is_tensor(_indptr_buffer), f"{_indptr_buffer=}"
self._graph_decode_wrapper = (
# CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
BatchDecodeMlaWithPagedKVCacheWrapper(
self._graph_decode_workspace_buffer,
True,
_indptr_buffer,
self._graph_indices_buffer,
_last_page_len_buffer,
))
# "NHD",
# use_tensor_cores)
if self.runner.kv_cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.runner.kv_cache_dtype)
@@ -265,6 +277,7 @@
model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
model_input.attn_metadata.begin_forward()

import math

@dataclass
class FlashInferMetadata(AttentionMetadata):
@@ -279,8 +292,8 @@

use_cuda_graph: bool = True

prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
prefill_wrapper: Optional[BatchDecodeMlaWithPagedKVCacheWrapper] = None
decode_wrapper: Optional[BatchDecodeMlaWithPagedKVCacheWrapper] = None

# Metadata for the prefill stage
seq_start_loc: Optional[torch.Tensor] = None
@@ -356,14 +369,17 @@
self.block_table_bound = self.block_table_bound.to(self.device)
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.prefill_wrapper.end_forward()
self.prefill_wrapper.begin_forward(
self.query_start_loc,
# self.prefill_wrapper.end_forward()
self.prefill_wrapper.plan(
self.paged_kv_indptr[:self.num_prefills + 1],
self.paged_kv_indices,
self.paged_kv_last_page_len[:self.num_prefills],
self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size)
self.num_qo_heads,
self.head_dim,
self.page_size,
sm_scale=1.0 / math.sqrt(self.head_dim + self.head_dim//8), # TODO(simon): should we explicitly pass this in?
data_type=self.data_type,
q_data_type=self.q_data_type)
if self.num_decode_tokens > 0:
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
@@ -379,17 +395,19 @@
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)

assert self.decode_wrapper is not None
self.decode_wrapper.end_forward()
self.decode_wrapper.begin_forward(
# self.decode_wrapper.end_forward()

self.decode_wrapper.plan(
self.paged_kv_indptr[self.num_prefills:],
self.paged_kv_indices,
self.paged_kv_last_page_len[self.num_prefills:],
self.num_qo_heads,
self.num_kv_heads,
# self.num_kv_heads,
self.head_dim,
self.page_size,
sm_scale=1.0 / math.sqrt(self.head_dim + self.head_dim//8), # TODO(simon): should we explicitly pass this in?
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
# pos_encoding_mode="NONE",
# kv-cache data type.
data_type=self.data_type,
# query data type.
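As a sanity check on the sm_scale passed to plan() above: with head_dim taken to be 512 (the compressed-KV head size implied by the assertions in forward()), head_dim // 8 is the 64-dim rope width, so the value works out to 1/sqrt(576) = 1/24. This only restates the arithmetic in the diff; whether that is the right scale for MLA is exactly what the TODO leaves open.

import math

head_dim = 512                     # assumed compressed-KV head size
qk_rope_head_dim = head_dim // 8   # 64
sm_scale = 1.0 / math.sqrt(head_dim + qk_rope_head_dim)
print(sm_scale)                    # 1 / sqrt(576) = 1 / 24 ≈ 0.0417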
@@ -764,6 +782,8 @@
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

self.empty_tensor = torch.empty(0, device="cuda")

def forward(
self,
query: torch.Tensor,
@@ -781,25 +801,37 @@
"are not implemented for "
"FlashInferImpl")

key_rope = value
del value

num_tokens, N, LR = query.shape
assert N == self.num_heads
assert LR == self.head_size + self.head_size//8
qk_rope_head_dim = LR - self.head_size
assert qk_rope_head_dim == 64
# hidden_size = N * self.head_size

Check failure on line 812 in vllm/attention/backends/flashinfer.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (SIM300)

vllm/attention/backends/flashinfer.py:812:16: SIM300 Yoda condition detected

Check failure on line 813 in vllm/attention/backends/flashinfer.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (SIM300)

vllm/attention/backends/flashinfer.py:813:16: SIM300 Yoda condition detected
num_heads: int = self.num_heads
head_size: int = self.head_size
num_kv_heads: int = self.num_kv_heads
assert self.num_kv_heads == 1
kv_cache_dtype: str = self.kv_cache_dtype
softmax_scale: float = self.scale
window_size = self.sliding_window
alibi_slopes = self.alibi_slopes
logits_soft_cap = self.logits_soft_cap
# softmax_scale: float = self.scale
# window_size = self.sliding_window
# alibi_slopes = self.alibi_slopes
# logits_soft_cap = self.logits_soft_cap

num_tokens, hidden_size = query.shape
query = query.view(-1, num_heads, head_size)
# num_tokens, hidden_size = query.shape
query = query.view(-1, num_heads, LR)
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
key_rope = key_rope.view(-1, num_kv_heads,
head_size) # this is padded!

if kv_cache.numel() > 0:
# Use the same reshape and cache kernel as flash attention.
ops.reshape_and_cache_flash(
key,
value,
key_rope,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping.flatten(),
@@ -816,68 +848,102 @@

num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert num_prefill_tokens == 0 and num_decode_tokens > 0, "only mla decode"

assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
query = query.contiguous(
) # Flashinfer requires query to be contiguous
assert key_rope.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"value : {key_rope.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa

query = query.contiguous() # Flashinfer requires query to be contiguous
# Query for decode. KV is not needed because it is already cached.
# QKV for prefill.
decode_query = query[num_prefill_tokens:]
query = query[:num_prefill_tokens]

key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
# query = query[:num_prefill_tokens]
decode_query = query#[num_prefill_tokens:]
# assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens, f"{decode_query.shape=}, {num_decode_tokens=}"

assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
# query_nope = query[:, :, :head_size].contiguous()
# query_pe = query[:, :, head_size:].contiguous()
decode_query_nope = decode_query[:, :, :head_size].contiguous()
decode_query_pe = decode_query[:, :, head_size:].contiguous()

window_left = window_size[0] if window_size is not None else -1
# window_left = window_size[0] if window_size is not None else -1

prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
if prefill_meta := attn_metadata.prefill_metadata:
assert False
# We will use flash attention for prefill
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache.numel() == 0:
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
)
prefill_output = torch.empty(num_prefill_tokens,
N,
head_size,
device="cuda")
# key = key[:num_prefill_tokens]
# key_rope = key_rope[:num_prefill_tokens, :, :qk_rope_head_dim]
# prefill_output = flash_attn_varlen_func(
# q=query,
# k=key,
# v=key_rope,
# cu_seqlens_q=prefill_meta.seq_start_loc,
# cu_seqlens_k=prefill_meta.seq_start_loc,
# max_seqlen_q=prefill_meta.max_prefill_seq_len,
# max_seqlen_k=prefill_meta.max_prefill_seq_len,
# softmax_scale=softmax_scale,
# causal=True,
# window_size=window_size,
# alibi_slopes=alibi_slopes,
# )
else:
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
prefill_output = prefill_meta.prefill_wrapper.forward(
query,
kv_cache,
logits_soft_cap=logits_soft_cap,
causal=True,
# prefill_output = prefill_meta.prefill_wrapper.run(
# query,
# kv_cache,
# logits_soft_cap=logits_soft_cap,
# causal=True,
# k_scale=k_scale,
# v_scale=v_scale,
# window_left=window_left)
paged_kpe_cache, _ = kv_cache[:, 1].split(
[qk_rope_head_dim, head_size - qk_rope_head_dim], dim=-1)

prefill_output = prefill_meta.prefill_wrapper.run(
q_nope=query_nope,
q_pe=query_pe,
paged_ckv_cache=kv_cache[:, 0],
paged_kpe_cache=kv_cache[:, 1],
# paged_kpe_cache=paged_kpe_cache,
# sm_scale=softmax_scale,
# logits_soft_cap=logits_soft_cap,
k_scale=k_scale,
v_scale=v_scale,
window_left=window_left)
v_scale=None, # v_scale,
# window_left=window_left
)
if decode_meta := attn_metadata.decode_metadata:
assert decode_meta is not None
assert decode_meta.decode_wrapper is not None
decode_output = decode_meta.decode_wrapper.forward(
decode_query,
kv_cache,
sm_scale=softmax_scale,
logits_soft_cap=logits_soft_cap,
k_scale=k_scale,
v_scale=v_scale,
window_left=window_left)
# paged_kpe_cache, _ = kv_cache[:, 1].split(
# [qk_rope_head_dim, head_size - qk_rope_head_dim], dim=-1)
# paged_kpe_cache = paged_kpe_cache.contiguous() # this makes a copy of the entire KV cache noooo
# # note: this shouldn't matter b/c FI assumes head_dim_kpe == head_dim_ckv//8

decode_output = decode_meta.decode_wrapper.run(
q_nope=decode_query_nope,
q_pe=decode_query_pe,
paged_ckv_cache=kv_cache[:, 0],
paged_kpe_cache=kv_cache[:, 1],
# paged_kpe_cache=paged_kpe_cache,
# sm_scale=softmax_scale,
# logits_soft_cap=logits_soft_cap,
# k_scale=k_scale,
# v_scale=v_scale,
# window_left=window_left
)

if prefill_output is None and decode_output is not None:
# Decode only batch.
@@ -894,4 +960,7 @@
assert decode_meta.decode_query_len == 1
decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size)
assert output.shape == (
num_tokens, N,
head_size), f"{output.shape=}!={num_tokens=}, {N=}, {head_size=}"
return output
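To summarize the tensor bookkeeping forward() now relies on, here is a CPU-only sketch of the decode-path shapes. The head count is illustrative (num_heads = 16 is an assumption, not read from the model config); head_size = 512 and the 64-dim rope width follow the assertions above, and no FlashInfer call is made.

import torch

num_tokens, num_heads, head_size = 4, 16, 512   # illustrative sizes
qk_rope_head_dim = head_size // 8               # 64, as asserted in forward()

# Query arrives as (num_tokens, num_heads, head_size + head_size // 8).
query = torch.randn(num_tokens, num_heads, head_size + qk_rope_head_dim)

# forward() splits it before calling BatchDecodeMlaWithPagedKVCacheWrapper.run().
decode_query_nope = query[:, :, :head_size].contiguous()   # (4, 16, 512)
decode_query_pe = query[:, :, head_size:].contiguous()     # (4, 16, 64)

# The wrapper returns one head_size-dim value per head, which is what the
# final shape assertion in forward() checks before the output is returned.
output = torch.zeros(num_tokens, num_heads, head_size)
assert output.shape == (num_tokens, num_heads, head_size)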