From 95f19910b6851c2cffd20e19bc12770dae83ac61 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Wed, 4 Dec 2024 03:32:41 +0000 Subject: [PATCH 01/14] Annotated MHA --- vllm/model_executor/models/deepseek_v2.py | 164 +++++++++++++++++++++- 1 file changed, 163 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 4cf4e6c358bf2..930fec66997e0 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -338,7 +338,8 @@ def __init__( # DecoderLayers are created with `make_layers` which passes the prefix # with the layer's index. layer_idx = int(prefix.split(sep='.')[-1]) - self.self_attn = DeepseekV2Attention( + # self.self_attn = DeepseekV2Attention( + self.self_attn = DeepseekV2MLAAttention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -404,6 +405,167 @@ def forward( hidden_states = self.mlp(hidden_states) return hidden_states, residual +class DeepseekV2MLAAttention(nn.Module): + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + # Note(simon): Added some symbols for shapes, hoping to help clarity. + self.hidden_size = hidden_size # H + self.qk_nope_head_dim = qk_nope_head_dim # P + self.qk_rope_head_dim = qk_rope_head_dim # R + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim # P + R + self.v_head_dim = v_head_dim # V + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank # L + + self.num_heads = num_heads # N + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size # N' + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + # NOTE(simon): This needs to implemented with the matrices absorption algorithm. 
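+        # For concreteness, with the DeepSeek-V2-Lite checkpoint the symbols above are roughly
+        # H = 2048, N = 16, P = 128, R = 64, V = 128, L = 512, and q_lora_rank = None
+        # (the full DeepSeek-V2 additionally uses q_lora_rank = 1536). These values are
+        # assumptions taken from the public configs and are shown only for orientation.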
+ assert q_lora_rank is None, "Currently not supported" + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") + else: + # (H -> N(P+R)) + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") + + # (H -> (L+R)) + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + # ((L -> (N(P+V))) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj") + # (NV -> H) + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + rope_scaling["rope_type"] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.attn = Attention(num_heads=self.num_local_heads, + head_size=256, + scale=self.scaling, + num_kv_heads=self.num_local_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + # BH -> B(N(P+R)) -> BN(P+R) + if self.q_lora_rank is not None: + q = self.q_a_proj(hidden_states)[0] + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) + else: + q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, self.qk_head_dim) + + # BN(P+R) -> BNP, BNR + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + # BH -> B(L+R) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + # B(L+R) -> BL, BR + kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + # B(L+R) -> B1(L+R) + latent_cache = latent_cache.unsqueeze(1) + # BL -> BL + kv_a = self.kv_a_layernorm(kv_a.contiguous()) + # BL -> B(N'(P+V)) + kv = self.kv_b_proj(kv_a)[0] + # B(N'(P+V)) -> BN'(P+V) + kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) + # BN'(P+V) -> BN'P, BN'V + k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + # B1(L+R) -> B1R + k_pe = latent_cache[:, :, self.kv_lora_rank:] + # BNR, B1R -> BNR, B1R + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + # BN(P+R) + q[..., self.qk_nope_head_dim:] = q_pe + # BN(P+R) + k = torch.empty_like(q) + k[..., :self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim:] = k_pe + + q = torch.nn.functional.pad(q, [0, 256 - 
self.qk_head_dim], value=0).view(-1, self.num_local_heads * 256) + k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim], value=0).view(-1, self.num_local_heads * 256) + v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view(-1, self.num_local_heads * 256) + + # B(N'V) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = attn_output.view(-1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(-1, self.num_local_heads * self.v_head_dim) + + # B(N'V) -> BH + output, _ = self.o_proj(attn_output) + return output @support_torch_compile class DeepseekV2Model(nn.Module): From e7a56dacf62863312320b7e1d569fc7abe01b486 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Thu, 5 Dec 2024 11:10:27 +0000 Subject: [PATCH 02/14] wip, cached latents, used mla decode, but it generate gibberish --- examples/offline_inference.py | 2 +- vllm/attention/backends/flashinfer.py | 172 +++++++----- vllm/config.py | 23 +- vllm/model_executor/models/deepseek_v2.py | 326 ++++++++++++++++------ vllm/worker/cache_engine.py | 4 + 5 files changed, 366 insertions(+), 161 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 23cc6e8539431..da4a070033d40 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -11,7 +11,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat", trust_remote_code=True, max_model_len=16384, dtype="float16", enforce_eager=True) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 1a2024705eb04..419a09cba889c 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -6,7 +6,7 @@ from vllm.multimodal import MultiModalPlaceholderMap try: - from flashinfer import BatchDecodeWithPagedKVCacheWrapper + from flashinfer import BatchDecodeWithPagedKVCacheWrapper, BatchDecodeMlaWithPagedKVCacheWrapper from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper @@ -16,6 +16,7 @@ BatchDecodeWithPagedKVCacheWrapper = None CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None BatchPrefillWithPagedKVCacheWrapper = None + BatchDecodeMlaWithPagedKVCacheWrapper = None FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 import torch @@ -67,7 +68,9 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: + # IDEA(simon): We probably should create a new backend for MLA something like FLASHINFER_MLA. 
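+        # Under MLA the two planes of this shape are repurposed rather than holding
+        # separate K and V: kv_cache[:, 0] stores the compressed KV latent
+        # (kv_lora_rank wide) and kv_cache[:, 1] stores the rope key, padded from
+        # qk_rope_head_dim up to the same width so both planes share one layout.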
return (num_blocks, 2, block_size, num_kv_heads, head_size) + # return (num_blocks, 1, block_size, num_kv_heads, head_size) @staticmethod def swap_blocks( @@ -86,7 +89,7 @@ def copy_blocks( @staticmethod def get_supported_head_sizes() -> List[int]: - return [64, 128, 256] + return [256, 512] # [64, 128, 256, 512] @staticmethod def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: @@ -117,8 +120,9 @@ def _get_workspace_buffer(self): def _get_prefill_wrapper(self): if self._prefill_wrapper is None: - self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( - self._get_workspace_buffer(), "NHD") + # self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( + self._prefill_wrapper = BatchDecodeMlaWithPagedKVCacheWrapper( + self._get_workspace_buffer(),) return self._prefill_wrapper def _get_decode_wrapper(self): @@ -129,10 +133,11 @@ def _get_decode_wrapper(self): self.runner.parallel_config) use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( num_qo_heads // num_kv_heads > 4) - self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( - self._get_workspace_buffer(), - "NHD", - use_tensor_cores=use_tensor_cores) + # self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self._decode_wrapper = BatchDecodeMlaWithPagedKVCacheWrapper( + self._get_workspace_buffer()) + # "NHD", + # use_tensor_cores=use_tensor_cores) return self._decode_wrapper @contextmanager @@ -189,11 +194,13 @@ def graph_capture_get_metadata_for_batch( self.runner.parallel_config) use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( num_qo_heads // num_kv_heads > 4) - self._graph_decode_wrapper = \ - CUDAGraphBatchDecodeWithPagedKVCacheWrapper( - self._graph_decode_workspace_buffer, _indptr_buffer, - self._graph_indices_buffer, _last_page_len_buffer, "NHD", - use_tensor_cores) + self._graph_decode_wrapper = ( + # CUDAGraphBatchDecodeWithPagedKVCacheWrapper( + BatchDecodeMlaWithPagedKVCacheWrapper( + self._graph_decode_workspace_buffer, True, _indptr_buffer, + self._graph_indices_buffer, _last_page_len_buffer, )) + # "NHD", + # use_tensor_cores) if self.runner.kv_cache_dtype.startswith("fp8"): kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( self.runner.kv_cache_dtype) @@ -279,8 +286,8 @@ class FlashInferMetadata(AttentionMetadata): use_cuda_graph: bool = True - prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None - decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None + prefill_wrapper: Optional[BatchDecodeMlaWithPagedKVCacheWrapper] = None + decode_wrapper: Optional[BatchDecodeMlaWithPagedKVCacheWrapper] = None # Metadata for the prefill stage seq_start_loc: Optional[torch.Tensor] = None @@ -356,14 +363,17 @@ def begin_forward(self): self.block_table_bound = self.block_table_bound.to(self.device) self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) self.paged_kv_indices = self.paged_kv_indices.to(self.device) - self.prefill_wrapper.end_forward() - self.prefill_wrapper.begin_forward( - self.query_start_loc, + # self.prefill_wrapper.end_forward() + self.prefill_wrapper.plan( self.paged_kv_indptr[:self.num_prefills + 1], self.paged_kv_indices, self.paged_kv_last_page_len[:self.num_prefills], - self.num_qo_heads, self.num_kv_heads, self.head_dim, - self.page_size) + self.num_qo_heads, + self.head_dim, + self.page_size, + sm_scale=self.head_dim**-0.5, # TODO(simon): should we explicitly pass this in? 
+ data_type=self.data_type, + q_data_type=self.q_data_type) if self.num_decode_tokens > 0: assert self.paged_kv_indices is not None assert self.paged_kv_indptr is not None @@ -379,17 +389,18 @@ def begin_forward(self): self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) assert self.decode_wrapper is not None - self.decode_wrapper.end_forward() - self.decode_wrapper.begin_forward( + # self.decode_wrapper.end_forward() + self.decode_wrapper.plan( self.paged_kv_indptr[self.num_prefills:], self.paged_kv_indices, self.paged_kv_last_page_len[self.num_prefills:], self.num_qo_heads, - self.num_kv_heads, + # self.num_kv_heads, self.head_dim, self.page_size, + sm_scale=self.head_dim**-0.5, # TODO(simon): should we explicitly pass this in? # Disable flashinfer's pos encoding and use vllm's rope. - pos_encoding_mode="NONE", + # pos_encoding_mode="NONE", # kv-cache data type. data_type=self.data_type, # query data type. @@ -764,6 +775,8 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.empty_tensor = torch.empty(0, device="cuda") + def forward( self, query: torch.Tensor, @@ -781,6 +794,15 @@ def forward( "are not implemented for " "FlashInferImpl") + key_rope = value + del value + + num_tokens, N, L_R = query.shape + qk_rope_head_dim = L_R - self.head_size + hidden_size = N * self.head_size + assert N == self.num_heads + assert qk_rope_head_dim == 64 + num_heads: int = self.num_heads head_size: int = self.head_size num_kv_heads: int = self.num_kv_heads @@ -790,16 +812,16 @@ def forward( alibi_slopes = self.alibi_slopes logits_soft_cap = self.logits_soft_cap - num_tokens, hidden_size = query.shape - query = query.view(-1, num_heads, head_size) + # num_tokens, hidden_size = query.shape + query = query.view(-1, num_heads, L_R) key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) + key_rope = key_rope.view(-1, num_kv_heads, head_size) # this is padded! if kv_cache.numel() > 0: # Use the same reshape and cache kernel as flash attention. ops.reshape_and_cache_flash( key, - value, + key_rope, kv_cache[:, 0], kv_cache[:, 1], attn_metadata.slot_mapping.flatten(), @@ -818,21 +840,23 @@ def forward( num_decode_tokens = attn_metadata.num_decode_tokens assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa - assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa + assert key_rope.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"value : {key_rope.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa query = query.contiguous( ) # Flashinfer requires query to be contiguous # Query for decode. KV is not needed because it is already cached. # QKV for prefill. 
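+        # MLA calling convention for this backend: `query` packs [q_nope | q_pe]
+        # along the last dim (head_size + qk_rope_head_dim wide), `key` is the
+        # compressed KV latent, and the original `value` argument was repurposed
+        # above as the padded rope key (`key_rope`). Below we split prefill vs.
+        # decode tokens and the nope vs. pe halves before calling the MLA kernels.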
decode_query = query[num_prefill_tokens:] query = query[:num_prefill_tokens] - - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - assert query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens + query_nope = query[:, :, :head_size] + query_pe = query[:, :, head_size:] + + decode_query_nope = decode_query[:, :, :head_size] + decode_query_pe = decode_query[:, :, head_size:] + window_left = window_size[0] if window_size is not None else -1 prefill_output: Optional[torch.Tensor] = None @@ -843,41 +867,62 @@ def forward( # This happens when vllm runs the profiling to # determine the number of blocks. if kv_cache.numel() == 0: - prefill_output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - ) + prefill_output = torch.empty(num_prefill_tokens, N, head_size, device="cuda") + # key = key[:num_prefill_tokens] + # key_rope = key_rope[:num_prefill_tokens, :, :qk_rope_head_dim] + # prefill_output = flash_attn_varlen_func( + # q=query, + # k=key, + # v=key_rope, + # cu_seqlens_q=prefill_meta.seq_start_loc, + # cu_seqlens_k=prefill_meta.seq_start_loc, + # max_seqlen_q=prefill_meta.max_prefill_seq_len, + # max_seqlen_k=prefill_meta.max_prefill_seq_len, + # softmax_scale=softmax_scale, + # causal=True, + # window_size=window_size, + # alibi_slopes=alibi_slopes, + # ) else: assert prefill_meta is not None assert prefill_meta.prefill_wrapper is not None - prefill_output = prefill_meta.prefill_wrapper.forward( - query, - kv_cache, - logits_soft_cap=logits_soft_cap, - causal=True, + # prefill_output = prefill_meta.prefill_wrapper.run( + # query, + # kv_cache, + # logits_soft_cap=logits_soft_cap, + # causal=True, + # k_scale=k_scale, + # v_scale=v_scale, + # window_left=window_left) + paged_kpe_cache, _ = kv_cache[:, 1].split([qk_rope_head_dim, head_size - qk_rope_head_dim], dim=-1) + + prefill_output = prefill_meta.prefill_wrapper.run( + q_nope=query_nope, + q_pe=query_pe, + paged_ckv_cache=kv_cache[:, 0], + paged_kpe_cache=paged_kpe_cache, + # sm_scale=softmax_scale, + # logits_soft_cap=logits_soft_cap, k_scale=k_scale, - v_scale=v_scale, - window_left=window_left) + v_scale=None, # v_scale, + # window_left=window_left + ) if decode_meta := attn_metadata.decode_metadata: assert decode_meta is not None assert decode_meta.decode_wrapper is not None - decode_output = decode_meta.decode_wrapper.forward( - decode_query, - kv_cache, - sm_scale=softmax_scale, - logits_soft_cap=logits_soft_cap, - k_scale=k_scale, - v_scale=v_scale, - window_left=window_left) + paged_kpe_cache, _ = kv_cache[:, 1].split([qk_rope_head_dim, head_size - qk_rope_head_dim], dim=-1) + + decode_output = decode_meta.decode_wrapper.run( + q_nope=decode_query_nope, + q_pe=decode_query_pe, + paged_ckv_cache=kv_cache[:, 0], + paged_kpe_cache=paged_kpe_cache, + # sm_scale=softmax_scale, + # logits_soft_cap=logits_soft_cap, + k_scale=k_scale, + v_scale=None, # v_scale, # NOTE(simon): there's a bug in FI now. https://github.com/flashinfer-ai/flashinfer/pull/650 + # window_left=window_left + ) if prefill_output is None and decode_output is not None: # Decode only batch. 
@@ -894,4 +939,5 @@ def forward( assert decode_meta.decode_query_len == 1 decode_output = decode_output.squeeze(1) output = torch.cat([prefill_output, decode_output], dim=0) - return output.view(num_tokens, hidden_size) + assert output.shape == (num_tokens, N, head_size), f"{output.shape=}!={num_tokens=}, {N=}, {head_size=}" + return output diff --git a/vllm/config.py b/vllm/config.py index c87feaec3e5f6..7ab4bf7c8a445 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -596,13 +596,18 @@ def get_vocab_size(self) -> int: def get_hidden_size(self) -> int: return self.hf_text_config.hidden_size + @property + def _is_deepseek_v2(self) -> bool: + return hasattr(self.hf_text_config, "model_type") and self.hf_text_config.model_type == 'deepseek_v2' + def get_head_size(self) -> int: # TODO remove hard code - if hasattr(self.hf_text_config, "model_type" - ) and self.hf_text_config.model_type == 'deepseek_v2': + if self._is_deepseek_v2: # FlashAttention supports only head_size 32, 64, 128, 256, # we need to pad head_size 192 to 256 - return 256 + # return 256 + # TODO(simon): feature flag MLA + return self.hf_text_config.kv_lora_rank # + self.hf_text_config.qk_rope_head_dim if self.is_attention_free: return 0 @@ -661,6 +666,10 @@ def get_total_num_kv_heads(self) -> int: def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: """Returns the number of KV heads per GPU.""" + if self._is_deepseek_v2: + # TODO(simon): feature flag MLA + return 1 + total_num_kv_heads = self.get_total_num_kv_heads() # If tensor parallelism is used, we divide the number of KV heads by # the tensor parallel size. We will replicate the KV heads in the @@ -1788,15 +1797,15 @@ class PoolerConfig: step_tag_id: Optional[int] = None """ - If set, only the score corresponding to the ``step_tag_id`` in the + If set, only the score corresponding to the ``step_tag_id`` in the generated sentence should be returned. Otherwise, the scores for all tokens are returned. """ returned_token_ids: Optional[List[int]] = None """ - A list of indices for the vocabulary dimensions to be extracted, - such as the token IDs of ``good_token`` and ``bad_token`` in the + A list of indices for the vocabulary dimensions to be extracted, + such as the token IDs of ``good_token`` and ``bad_token`` in the ``math-shepherd-mistral-7b-prm`` model. """ @@ -2139,7 +2148,7 @@ class CompilationConfig(BaseModel): from Python, functions can also be passed directly via Python object constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` - custom inductor passes: see PassConfig for more details - + Why we have different sizes for cudagraph and inductor: - cudagraph: a cudagraph captured for a specific size can only be used for the same size. We need to capture all the sizes we want to use. 
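A quick sanity check of what the get_head_size() / get_num_kv_heads() overrides above imply for cache sizing. The concrete numbers below are assumptions taken from the public DeepSeek-V2-Lite config (the model used in examples/offline_inference.py), not from this diff:

# Per-token, per-layer KV-cache footprint in element counts (dtype-agnostic).
num_heads = 16           # assumed: DeepSeek-V2-Lite attention heads
kv_lora_rank = 512       # assumed: compressed latent width (L)
qk_rope_head_dim = 64    # assumed: rope slice width (R)
padded_head_size = 256   # head size used by the padded MHA fallback path

mha_padded = 2 * num_heads * padded_head_size    # separate K and V planes, all heads
mla_latent = kv_lora_rank + qk_rope_head_dim     # one shared latent + one rope slice

print(mha_padded, mla_latent)  # 8192 vs. 576 elements per token per layer

# Note: this WIP still allocates the "value" plane at the full latent width (see the
# cache_engine.py note later in this patch), so what is actually stored today is
# 2 * kv_lora_rank = 1024 elements per token until that is tightened.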
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 930fec66997e0..9056a5b032fb7 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -319,93 +319,59 @@ def forward( output, _ = self.o_proj(attn_output) return output - -class DeepseekV2DecoderLayer(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - prefix: str, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - # DecoderLayers are created with `make_layers` which passes the prefix - # with the layer's index. - layer_idx = int(prefix.split(sep='.')[-1]) - # self.self_attn = DeepseekV2Attention( - self.self_attn = DeepseekV2MLAAttention( - config=config, - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - qk_nope_head_dim=config.qk_nope_head_dim, - qk_rope_head_dim=config.qk_rope_head_dim, - v_head_dim=config.v_head_dim, - q_lora_rank=config.q_lora_rank - if hasattr(config, "q_lora_rank") else None, - kv_lora_rank=config.kv_lora_rank, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - if (config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0): - self.mlp = DeepseekV2MoE( - config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - else: - self.mlp = DeepseekV2MLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - residual: Optional[torch.Tensor], - ) -> torch.Tensor: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.mlp(hidden_states) - return hidden_states, residual - class DeepseekV2MLAAttention(nn.Module): + """ + Main reference: DeepseekV2 paper, and FlashInfer Implementation https://github.com/flashinfer-ai/flashinfer/pull/551. + + Deepseek's MLA attention works the following way: + * The key idea is to use a single latent vector to represent the entire KV cache. + * The attention should simulate a multi-head attention, while the compute is similar to multi-query attention. 
+    * The dataflow is as follows:
+
+    * B: batch/sequence length
+    * H: hidden size
+    * N: number of attention heads
+    * Lq: latent dimension for Q
+    * Lkv: latent dimension for K/V
+    * P: nope dimension, P+R is the actual head_dim in common attention.
+    * R: rope dimension, this slice of the head_dim goes through rotary embeddings.
+    * V: V head dim.
+
+    # The reconstructed way, as implemented in DeepseekV2Attention:
+    1. The hidden states (B, H) are projected down into q_latent (B, Lq) and kv_latent (B, Lkv+R).
+    2. The kv_latent is split into kv_a (B, Lkv) and k_pe (B, R). q_latent and kv_a are normalized.
+    3. The q_latent and kv_a are then projected up into the multi-head version.
+       q_latent goes from (B, Lq) to (B, N(P+R)), including the rope dimension,
+       and is split into q_nope (B, N, P) and q_pe (B, N, R).
+       kv_a goes from (B, Lkv) to (B, N(P+V)), which holds the nope dimensions for K and V,
+       and is split into k_nope (B, N, P) and v (B, N, V).
+    4. q_pe and k_pe are then passed through rotary embeddings.
+    5. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from q_nope, q_pe, k_nope, k_pe.
+    6. Attention is computed with q, k, v.
+    7. The KV cache is updated with the new entries k (B, N, (P+R)) and v (B, N, V); we pad the head dim
+       to 256 so that the KV cache has a consistent shape and works with a typical cache implementation.
+    8. The attention computation returns (B, N, V), which is projected back to (B, H) using the out projection.
+
+    # The recommended way, as described in the paper:
+    1. The hidden states (B, H) are projected down into q_latent (B, Lq) and kv_latent (B, Lkv+R).
+    2. The kv_latent is split into kv_a (B, Lkv) and k_pe (B, R). q_latent and kv_a are normalized.
+    3. Here is the change: we do not perform the full up-projection of q_latent, and there is no
+       up-projection at all for kv_a. This is achieved with the "weight absorption" technique. The paper says
+       "Fortunately, due to the associative law of matrix multiplication, we can absorb WUK into WUQ, and WUV into WO":
+       * The q up-projection turns (B, Lq) into (B, N(P+R)); we split its weight into W_UQ (Lq, N, P) and W_QR (Lq, N, R).
+       * The kv_a up-projection turns (B, Lkv) into (B, N(P+V)); we split its weight into W_UK (Lkv, N, P) and W_UV (Lkv, N, V).
+       * The out projection turns (B, N, V) into (B, H); its weight W_O has shape (NV, H), i.e. (V, H) per head.
+       * We can precompute the product of W_UQ and W_UK into W_UQ_UK (Lq, N, Lkv), which is possible because of the QK^T product in attention.
+       * We can precompute the product of W_UV and W_O into W_UV_O (N, Lkv, H), which is possible because V @ W_O is the "epilogue" of attention.
+    4. We still need to compute q_pe (B, N, R) by applying W_QR to q_latent, and rotary embeddings still need to be applied to q_pe and k_pe.
+    5. By applying W_UQ_UK to q_latent, we get the new q_nope of shape (B, N, Lkv).
+    6. q (B, N, (Lkv+R)) and k (B, (Lkv+R)) are assembled from q_nope, q_pe, kv_a, k_pe. v (B, Lkv) is exactly the same vector as kv_a.
+    7. The attention is computed with q, k, v. Note that this is effectively an MQA attention with (Lkv+R) as the head dim.
+    8. The KV cache is updated with the new entry k (B, (Lkv+R)), which already contains both the latent (v) and the rope values.
+    9. The attention computation returns (B, N, Lkv), which is projected back to (B, H) using W_UV_O.
+
+    From @tsu-bin's calculation, we only want to use the absorption technique for decode.
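+
+    A rough numerical sketch of the absorption identity for a single head, with the
+    rope slice (R) left out because it is handled separately. The shapes are made up
+    and the W_* names follow the description above, not the actual attributes below:
+
+        import torch
+        B, S, Lq, Lkv, P, V, H = 2, 3, 8, 16, 4, 5, 6
+        q_latent = torch.randn(B, Lq)   # normalized q latent
+        kv_a = torch.randn(S, Lkv)      # normalized kv latent, one row per cached token
+        W_UQ, W_UK = torch.randn(Lq, P), torch.randn(Lkv, P)
+        W_UV, W_O = torch.randn(Lkv, V), torch.randn(V, H)
+
+        # Reconstructed: up-project both sides, then attend.
+        scores = (q_latent @ W_UQ) @ (kv_a @ W_UK).T                # (B, S)
+        out = torch.softmax(scores, -1) @ (kv_a @ W_UV) @ W_O       # (B, H)
+
+        # Absorbed: fold W_UK into W_UQ and W_O into W_UV, attend on the latents.
+        W_UQ_UK, W_UV_O = W_UQ @ W_UK.T, W_UV @ W_O                 # (Lq, Lkv), (Lkv, H)
+        scores_abs = (q_latent @ W_UQ_UK) @ kv_a.T                  # (B, S)
+        out_abs = torch.softmax(scores_abs, -1) @ kv_a @ W_UV_O     # (B, H)
+
+        assert torch.allclose(out, out_abs, atol=1e-4)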
+ """ def __init__( self, config: PretrainedConfig, @@ -443,7 +409,7 @@ def __init__( self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - # NOTE(simon): This needs to implemented with the matrices absorption algorithm. + # TODO(simon): implement matrix absorption for this, needed for deepseek v2.5 assert q_lora_rank is None, "Currently not supported" if self.q_lora_rank is not None: @@ -505,13 +471,40 @@ def __init__( mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale - self.attn = Attention(num_heads=self.num_local_heads, + # The prefill attention will compute a multi-headed attention by up-projecting the latents. + # TODO(simon): enable this for prefill, and save only the latents. + self.prefill_attn = Attention(num_heads=self.num_local_heads, head_size=256, scale=self.scaling, num_kv_heads=self.num_local_heads, cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.attn") + prefix=f"{prefix}.prefill_attn") + # The decode attention will compute a multi-query attention by directly operating on the latent. + self.decode_attn = Attention(num_heads=self.num_local_heads, + head_size=self.kv_lora_rank, # + self.qk_rope_head_dim, # TODO(simon): pass in qk_rope_head_dim? but i don't think + scale=self.scaling, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.decode_attn") + + # To be computed during weight loading + # self.W_QR = None + # self.W_UQ_UK = None + # self.W_UV_O = None + + kv_b_proj_weight = self.kv_b_proj.weight.T + assert kv_b_proj_weight.shape == (self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), f"{kv_b_proj_weight.shape} != {(self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim))}" + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_local_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + self.W_UK, self.W_UV = kv_b_proj_weight.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + # self.W_UK = self.W_UK.view(self.kv_lora_rank, self.num_local_heads * self.qk_nope_head_dim) + # self.W_UV = self.W_UV.view(self.kv_lora_rank, self.num_local_heads * self.v_head_dim) def forward( self, @@ -519,6 +512,17 @@ def forward( hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + # TODO(simon): add prefill attn + # return self.forward_prefill(positions, hidden_states, kv_cache, attn_metadata) + return self.forward_decode(positions, hidden_states, kv_cache, attn_metadata) + + def forward_prefill( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: # BH -> B(N(P+R)) -> BN(P+R) if self.q_lora_rank is not None: @@ -560,13 +564,155 @@ def forward( v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view(-1, self.num_local_heads * 256) # B(N'V) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.prefill_attn(q, k, v, kv_cache, attn_metadata) attn_output = attn_output.view(-1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(-1, self.num_local_heads * self.v_head_dim) # B(N'V) -> BH output, _ = self.o_proj(attn_output) return output + def forward_decode( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + # Let's implement the matrix absorption dataflow. 
+ # We will start with applying the projection instead of fusing them. + B = hidden_states.shape[0] + + # Apply UQ and QR. + if self.q_lora_rank is not None: + q = self.q_a_proj(hidden_states)[0] + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) + else: + q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, self.qk_head_dim) + + + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_a, k_pe = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_a = self.kv_a_layernorm(kv_a.contiguous()) + # print(f"{q.shape=}, {q_nope.shape=}, {q_pe.shape=}, {k_pe.shape=}, {kv_a.shape=}, {latent_cache.shape=}") + k_pe = k_pe.unsqueeze(1) + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + # Apply UK, q_nope (B, N, P) @ W_UK (L, N, P) -> (B, N, L) + q_nope = torch.einsum("bnp,lnp->bnl", q_nope, self.W_UK) + # essemble q, k, and v; here v is repurposed to represent k_pe + + q = torch.empty((B, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim), dtype=q.dtype, device=q.device) + q[..., :self.kv_lora_rank] = q_nope + q[..., self.kv_lora_rank:] = q_pe + # q = q.view(B, self.num_local_heads * (self.kv_lora_rank + self.qk_rope_head_dim)) + + k = kv_a + # The padding is only used for kv storage. + v = torch.nn.functional.pad(k_pe, [0, self.kv_lora_rank - self.qk_rope_head_dim], value=0).squeeze(1) + assert k.numel() == v.numel(), f"{k.numel()=} != {v.numel()=}" + + attn_output = self.decode_attn(q, k, v, kv_cache, attn_metadata) + + assert attn_output.shape == (B, self.num_local_heads, self.kv_lora_rank), f"{attn_output.shape=}!={B=}, {self.num_local_heads=}, {self.v_head_dim=}" + # idk why but the attn_output is fp32 + attn_output = attn_output.to(q.dtype) + # Apply UV, (B, N, L) @ W_UV (L, N, V) -> (B, N, V) + attn_output = torch.einsum("bnl,lnv->bnv", attn_output, self.W_UV) + attn_output = attn_output.reshape(B, self.num_local_heads * self.v_head_dim) + + output, _ = self.o_proj(attn_output) + return output + + + + +class DeepseekV2DecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. 
+ layer_idx = int(prefix.split(sep='.')[-1]) + # self.self_attn = DeepseekV2Attention( + self.self_attn = DeepseekV2MLAAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank + if hasattr(config, "q_lora_rank") else None, + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = DeepseekV2MoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + @support_torch_compile class DeepseekV2Model(nn.Module): diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index ac3270d1c9909..d226f96c8b418 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -109,6 +109,10 @@ def get_cache_block_size( parallel_config) key_cache_block = cache_config.block_size * num_heads * head_size + # if model_config._is_deepseek_v2: # MLA share the K and V cache in one latent vector. + # value_cache_block = 0 + # else: + # TODO(simon): for MLA, this is repurpose for rope cache (64) but it is smaller than key cache (512). value_cache_block = key_cache_block total = num_attention_layers * (key_cache_block + value_cache_block) if cache_config.cache_dtype == "auto": From 5bb75864450fd83e7db6dd36954775092ef0e223 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Fri, 6 Dec 2024 06:37:07 +0000 Subject: [PATCH 03/14] impl vannilla prefill w/kv cache, now it generate with matabsob, debugging fi mla kernel issue --- examples/offline_inference.py | 2 +- vllm/attention/backends/flashinfer.py | 7 +- vllm/model_executor/models/deepseek_v2.py | 97 ++++++++++++++++++----- 3 files changed, 84 insertions(+), 22 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index da4a070033d40..5454c7d072280 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -11,7 +11,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. 
-llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat", trust_remote_code=True, max_model_len=16384, dtype="float16", enforce_eager=True) +llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat", trust_remote_code=True, max_model_len=16384, dtype="float16", enforce_eager=True, max_num_seqs=1, block_size=128) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 419a09cba889c..e676e8a155587 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -838,6 +838,8 @@ def forward( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens + assert num_prefill_tokens == 0 and num_decode_tokens > 0, "only mla decode" + assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa assert key_rope.shape[0] == num_prefill_tokens + num_decode_tokens, \ @@ -900,7 +902,8 @@ def forward( q_nope=query_nope, q_pe=query_pe, paged_ckv_cache=kv_cache[:, 0], - paged_kpe_cache=paged_kpe_cache, + paged_kpe_cache=kv_cache[:, 1], + # paged_kpe_cache=paged_kpe_cache, # sm_scale=softmax_scale, # logits_soft_cap=logits_soft_cap, k_scale=k_scale, @@ -916,7 +919,7 @@ def forward( q_nope=decode_query_nope, q_pe=decode_query_pe, paged_ckv_cache=kv_cache[:, 0], - paged_kpe_cache=paged_kpe_cache, + paged_kpe_cache=kv_cache[:, 1], # sm_scale=softmax_scale, # logits_soft_cap=logits_soft_cap, k_scale=k_scale, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 9056a5b032fb7..d94e4e4d1b705 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -319,6 +319,9 @@ def forward( output, _ = self.o_proj(attn_output) return output +from vllm.attention.backends.flash_attn import flash_attn_varlen_func, _get_query_key_seq_metadata, AttentionType +from vllm import _custom_ops as ops + class DeepseekV2MLAAttention(nn.Module): """ Main reference: DeepseekV2 paper, and FlashInfer Implementation https://github.com/flashinfer-ai/flashinfer/pull/551. @@ -473,13 +476,13 @@ def __init__( # The prefill attention will compute a multi-headed attention by up-projecting the latents. # TODO(simon): enable this for prefill, and save only the latents. - self.prefill_attn = Attention(num_heads=self.num_local_heads, - head_size=256, - scale=self.scaling, - num_kv_heads=self.num_local_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.prefill_attn") + # self.prefill_attn = Attention(num_heads=self.num_local_heads, + # head_size=256, + # scale=self.scaling, + # num_kv_heads=self.num_local_heads, + # cache_config=cache_config, + # quant_config=quant_config, + # prefix=f"{prefix}.prefill_attn") # The decode attention will compute a multi-query attention by directly operating on the latent. self.decode_attn = Attention(num_heads=self.num_local_heads, head_size=self.kv_lora_rank, # + self.qk_rope_head_dim, # TODO(simon): pass in qk_rope_head_dim? 
but i don't think @@ -513,9 +516,12 @@ def forward( kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: - # TODO(simon): add prefill attn - # return self.forward_prefill(positions, hidden_states, kv_cache, attn_metadata) - return self.forward_decode(positions, hidden_states, kv_cache, attn_metadata) + # TODO(simon): support append/chunked prefill by two kernels, or using the decode kernel somehow. + + if attn_metadata.prefill_metadata: + return self.forward_prefill(positions, hidden_states, kv_cache, attn_metadata) + if attn_metadata.decode_metadata: + return self.forward_decode(positions, hidden_states, kv_cache, attn_metadata) def forward_prefill( self, @@ -559,12 +565,38 @@ def forward_prefill( k[..., :self.qk_nope_head_dim] = k_nope k[..., self.qk_nope_head_dim:] = k_pe - q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view(-1, self.num_local_heads * 256) - k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim], value=0).view(-1, self.num_local_heads * 256) - v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view(-1, self.num_local_heads * 256) + # write the latent and rope to kv cache + to_cache_key = kv_a.unsqueeze(1) + to_cache_key_rope = torch.nn.functional.pad(k_pe, [0, self.kv_lora_rank - self.qk_rope_head_dim], value=0) + if kv_cache.numel() > 0: + ops.reshape_and_cache_flash( + to_cache_key, + to_cache_key_rope, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype="auto", # TODO: remove hard code + k_scale=1.0, + v_scale=1.0, + ) - # B(N'V) - attn_output = self.prefill_attn(q, k, v, kv_cache, attn_metadata) + # run the prefill kernels + q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0) + k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim], value=0) + v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0) + + prefill_meta = attn_metadata.prefill_metadata + q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = _get_query_key_seq_metadata(prefill_meta, True, AttentionType.DECODER) + attn_output = flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=q_seq_start_loc, + cu_seqlens_k=k_seq_start_loc, + max_seqlen_q=q_seq_len, + max_seqlen_k=k_seq_len, + causal=True, + ) attn_output = attn_output.view(-1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(-1, self.num_local_heads * self.v_head_dim) # B(N'V) -> BH @@ -590,7 +622,6 @@ def forward_decode( else: q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, self.qk_head_dim) - q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] kv_a, k_pe = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) @@ -612,9 +643,37 @@ def forward_decode( v = torch.nn.functional.pad(k_pe, [0, self.kv_lora_rank - self.qk_rope_head_dim], value=0).squeeze(1) assert k.numel() == v.numel(), f"{k.numel()=} != {v.numel()=}" - attn_output = self.decode_attn(q, k, v, kv_cache, attn_metadata) - - assert attn_output.shape == (B, self.num_local_heads, self.kv_lora_rank), f"{attn_output.shape=}!={B=}, {self.num_local_heads=}, {self.v_head_dim=}" + # attn_output = self.decode_attn(q, k, v, kv_cache, attn_metadata) + + # i just want to manually verify MLA is doing the right thing + # let's get all the previous kv cache and copy them here, run the MLA manually + paged_kv_indptr = attn_metadata.decode_metadata.paged_kv_indptr + paged_kv_indices = attn_metadata.decode_metadata.paged_kv_indices + 
paged_kv_last_page_len = attn_metadata.decode_metadata.paged_kv_last_page_len + + # debug: we always have batch size 1 and one page + assert paged_kv_indptr.cpu().tolist() == [0, 1], f"{paged_kv_indptr.cpu().tolist()=}" + paged_idx = paged_kv_indices[0] + full_latent_cache = kv_cache[paged_idx, 0] + full_rope_cache = kv_cache[paged_idx, 1] + # let's write k and v into the full cache at paged_kv_last_page_len-1 + full_latent_cache[paged_kv_last_page_len-1, :, :] = k + full_rope_cache[paged_kv_last_page_len-1, :, :] = v + full_latent_cache = full_latent_cache[:paged_kv_last_page_len, :, :] + full_rope_cache = full_rope_cache[:paged_kv_last_page_len, :, :self.qk_rope_head_dim] + full_kv_cache = torch.cat([full_latent_cache, full_rope_cache], dim=-1) + + # now let's run the MLA manually + q_B_N_LR = q + k_S_1_LR = full_kv_cache + v_S_1_L = full_latent_cache + import math + scale = 1.0/math.sqrt(self.kv_lora_rank + self.qk_rope_head_dim) + attn_scores = torch.einsum("bnl,snl->nbs", q_B_N_LR, k_S_1_LR) * scale + attn_probs = torch.nn.functional.softmax(attn_scores, dim=-1) + attn_output = torch.einsum("nbs,snl->bnl", attn_probs, v_S_1_L) + + assert attn_output.shape == (B, self.num_local_heads, self.kv_lora_rank), f"{attn_output.shape=}!={B=}, {self.num_local_heads=}, {self.kv_lora_rank=}" # idk why but the attn_output is fp32 attn_output = attn_output.to(q.dtype) # Apply UV, (B, N, L) @ W_UV (L, N, V) -> (B, N, V) From feb6ba3699ff9d3a66687365a339799c52c8fef0 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Fri, 6 Dec 2024 06:40:10 +0000 Subject: [PATCH 04/14] lint --- examples/offline_inference.py | 10 +- vllm/attention/backends/flashinfer.py | 65 ++++++---- vllm/config.py | 6 +- vllm/model_executor/models/deepseek_v2.py | 141 +++++++++++++--------- 4 files changed, 139 insertions(+), 83 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 5454c7d072280..15b8c0af7146a 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -11,7 +11,13 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat", trust_remote_code=True, max_model_len=16384, dtype="float16", enforce_eager=True, max_num_seqs=1, block_size=128) +llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat", + trust_remote_code=True, + max_model_len=16384, + dtype="float16", + enforce_eager=True, + max_num_seqs=1, + block_size=128) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. 
outputs = llm.generate(prompts, sampling_params) @@ -19,4 +25,4 @@ for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index e676e8a155587..2746334e5843f 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -89,7 +89,7 @@ def copy_blocks( @staticmethod def get_supported_head_sizes() -> List[int]: - return [256, 512] # [64, 128, 256, 512] + return [256, 512] # [64, 128, 256, 512] @staticmethod def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: @@ -122,7 +122,7 @@ def _get_prefill_wrapper(self): if self._prefill_wrapper is None: # self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( self._prefill_wrapper = BatchDecodeMlaWithPagedKVCacheWrapper( - self._get_workspace_buffer(),) + self._get_workspace_buffer(), ) return self._prefill_wrapper def _get_decode_wrapper(self): @@ -136,8 +136,8 @@ def _get_decode_wrapper(self): # self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( self._decode_wrapper = BatchDecodeMlaWithPagedKVCacheWrapper( self._get_workspace_buffer()) - # "NHD", - # use_tensor_cores=use_tensor_cores) + # "NHD", + # use_tensor_cores=use_tensor_cores) return self._decode_wrapper @contextmanager @@ -197,10 +197,14 @@ def graph_capture_get_metadata_for_batch( self._graph_decode_wrapper = ( # CUDAGraphBatchDecodeWithPagedKVCacheWrapper( BatchDecodeMlaWithPagedKVCacheWrapper( - self._graph_decode_workspace_buffer, True, _indptr_buffer, - self._graph_indices_buffer, _last_page_len_buffer, )) - # "NHD", - # use_tensor_cores) + self._graph_decode_workspace_buffer, + True, + _indptr_buffer, + self._graph_indices_buffer, + _last_page_len_buffer, + )) + # "NHD", + # use_tensor_cores) if self.runner.kv_cache_dtype.startswith("fp8"): kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( self.runner.kv_cache_dtype) @@ -371,7 +375,8 @@ def begin_forward(self): self.num_qo_heads, self.head_dim, self.page_size, - sm_scale=self.head_dim**-0.5, # TODO(simon): should we explicitly pass this in? + sm_scale=self.head_dim** + -0.5, # TODO(simon): should we explicitly pass this in? data_type=self.data_type, q_data_type=self.q_data_type) if self.num_decode_tokens > 0: @@ -398,7 +403,8 @@ def begin_forward(self): # self.num_kv_heads, self.head_dim, self.page_size, - sm_scale=self.head_dim**-0.5, # TODO(simon): should we explicitly pass this in? + sm_scale=self.head_dim** + -0.5, # TODO(simon): should we explicitly pass this in? # Disable flashinfer's pos encoding and use vllm's rope. # pos_encoding_mode="NONE", # kv-cache data type. @@ -815,7 +821,8 @@ def forward( # num_tokens, hidden_size = query.shape query = query.view(-1, num_heads, L_R) key = key.view(-1, num_kv_heads, head_size) - key_rope = key_rope.view(-1, num_kv_heads, head_size) # this is padded! + key_rope = key_rope.view(-1, num_kv_heads, + head_size) # this is padded! if kv_cache.numel() > 0: # Use the same reshape and cache kernel as flash attention. @@ -869,7 +876,10 @@ def forward( # This happens when vllm runs the profiling to # determine the number of blocks. 
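+            # Profiling run: the KV cache is empty while vLLM sizes the cache and the
+            # varlen prefill call is commented out in this WIP, so this branch only
+            # returns an uninitialized placeholder of the right shape.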
if kv_cache.numel() == 0: - prefill_output = torch.empty(num_prefill_tokens, N, head_size, device="cuda") + prefill_output = torch.empty(num_prefill_tokens, + N, + head_size, + device="cuda") # key = key[:num_prefill_tokens] # key_rope = key_rope[:num_prefill_tokens, :, :qk_rope_head_dim] # prefill_output = flash_attn_varlen_func( @@ -896,7 +906,8 @@ def forward( # k_scale=k_scale, # v_scale=v_scale, # window_left=window_left) - paged_kpe_cache, _ = kv_cache[:, 1].split([qk_rope_head_dim, head_size - qk_rope_head_dim], dim=-1) + paged_kpe_cache, _ = kv_cache[:, 1].split( + [qk_rope_head_dim, head_size - qk_rope_head_dim], dim=-1) prefill_output = prefill_meta.prefill_wrapper.run( q_nope=query_nope, @@ -907,24 +918,26 @@ def forward( # sm_scale=softmax_scale, # logits_soft_cap=logits_soft_cap, k_scale=k_scale, - v_scale=None, # v_scale, + v_scale=None, # v_scale, # window_left=window_left ) if decode_meta := attn_metadata.decode_metadata: assert decode_meta is not None assert decode_meta.decode_wrapper is not None - paged_kpe_cache, _ = kv_cache[:, 1].split([qk_rope_head_dim, head_size - qk_rope_head_dim], dim=-1) + paged_kpe_cache, _ = kv_cache[:, 1].split( + [qk_rope_head_dim, head_size - qk_rope_head_dim], dim=-1) decode_output = decode_meta.decode_wrapper.run( - q_nope=decode_query_nope, - q_pe=decode_query_pe, - paged_ckv_cache=kv_cache[:, 0], - paged_kpe_cache=kv_cache[:, 1], - # sm_scale=softmax_scale, - # logits_soft_cap=logits_soft_cap, - k_scale=k_scale, - v_scale=None, # v_scale, # NOTE(simon): there's a bug in FI now. https://github.com/flashinfer-ai/flashinfer/pull/650 - # window_left=window_left + q_nope=decode_query_nope, + q_pe=decode_query_pe, + paged_ckv_cache=kv_cache[:, 0], + paged_kpe_cache=kv_cache[:, 1], + # sm_scale=softmax_scale, + # logits_soft_cap=logits_soft_cap, + k_scale=k_scale, + v_scale= + None, # v_scale, # NOTE(simon): there's a bug in FI now. 
https://github.com/flashinfer-ai/flashinfer/pull/650 + # window_left=window_left ) if prefill_output is None and decode_output is not None: @@ -942,5 +955,7 @@ def forward( assert decode_meta.decode_query_len == 1 decode_output = decode_output.squeeze(1) output = torch.cat([prefill_output, decode_output], dim=0) - assert output.shape == (num_tokens, N, head_size), f"{output.shape=}!={num_tokens=}, {N=}, {head_size=}" + assert output.shape == ( + num_tokens, N, + head_size), f"{output.shape=}!={num_tokens=}, {N=}, {head_size=}" return output diff --git a/vllm/config.py b/vllm/config.py index 7ab4bf7c8a445..c188b00516f3c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -598,7 +598,9 @@ def get_hidden_size(self) -> int: @property def _is_deepseek_v2(self) -> bool: - return hasattr(self.hf_text_config, "model_type") and self.hf_text_config.model_type == 'deepseek_v2' + return hasattr( + self.hf_text_config, + "model_type") and self.hf_text_config.model_type == 'deepseek_v2' def get_head_size(self) -> int: # TODO remove hard code @@ -607,7 +609,7 @@ def get_head_size(self) -> int: # we need to pad head_size 192 to 256 # return 256 # TODO(simon): feature flag MLA - return self.hf_text_config.kv_lora_rank # + self.hf_text_config.qk_rope_head_dim + return self.hf_text_config.kv_lora_rank # + self.hf_text_config.qk_rope_head_dim if self.is_attention_free: return 0 diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index d94e4e4d1b705..e5e1933a72ae1 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -319,9 +319,11 @@ def forward( output, _ = self.o_proj(attn_output) return output + from vllm.attention.backends.flash_attn import flash_attn_varlen_func, _get_query_key_seq_metadata, AttentionType from vllm import _custom_ops as ops + class DeepseekV2MLAAttention(nn.Module): """ Main reference: DeepseekV2 paper, and FlashInfer Implementation https://github.com/flashinfer-ai/flashinfer/pull/551. @@ -345,9 +347,9 @@ class DeepseekV2MLAAttention(nn.Module): 2. The kv_latent is split into kv_a (B, Lkv) and k_pe (B, R). q_latent and kv_a are normalized. 3. The q_latent and kv_a are then projected up into the multi-head version. q_latent goes from (B, Lq) to (B, N(P+R)) included the rope dimension, - which is splited into q_nope (B, N, P) and q_pe (B, N, R). + which is split into q_nope (B, N, P) and q_pe (B, N, R). kv_a goes from (B, Lkv) to (B, N(P+V)) which has the nope dimensions for K and V, - which is splited into k_nope (B, N, P) and v (B, N, V). + which is split into k_nope (B, N, P) and v (B, N, V). 3. q_pe, k_pe are then passed through rotary embeddings. 4. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from q_nope, q_pe, k_nope, k_pe. 5. Attention is computued with q, k, v. @@ -375,6 +377,7 @@ class DeepseekV2MLAAttention(nn.Module): From @tsu-bin's calculation, we only want to use the absorption technique for decode. """ + def __init__( self, config: PretrainedConfig, @@ -394,19 +397,19 @@ def __init__( ) -> None: super().__init__() # Note(simon): Added some symbols for shapes, hoping to help clarity. 
- self.hidden_size = hidden_size # H - self.qk_nope_head_dim = qk_nope_head_dim # P - self.qk_rope_head_dim = qk_rope_head_dim # R - self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim # P + R - self.v_head_dim = v_head_dim # V + self.hidden_size = hidden_size # H + self.qk_nope_head_dim = qk_nope_head_dim # P + self.qk_rope_head_dim = qk_rope_head_dim # R + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim # P + R + self.v_head_dim = v_head_dim # V self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank # L + self.kv_lora_rank = kv_lora_rank # L - self.num_heads = num_heads # N + self.num_heads = num_heads # N tp_size = get_tensor_model_parallel_world_size() assert num_heads % tp_size == 0 - self.num_local_heads = num_heads // tp_size # N' + self.num_local_heads = num_heads // tp_size # N' self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta @@ -484,13 +487,15 @@ def __init__( # quant_config=quant_config, # prefix=f"{prefix}.prefill_attn") # The decode attention will compute a multi-query attention by directly operating on the latent. - self.decode_attn = Attention(num_heads=self.num_local_heads, - head_size=self.kv_lora_rank, # + self.qk_rope_head_dim, # TODO(simon): pass in qk_rope_head_dim? but i don't think - scale=self.scaling, - num_kv_heads=1, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.decode_attn") + self.decode_attn = Attention( + num_heads=self.num_local_heads, + head_size=self. + kv_lora_rank, # + self.qk_rope_head_dim, # TODO(simon): pass in qk_rope_head_dim? but i don't think + scale=self.scaling, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.decode_attn") # To be computed during weight loading # self.W_QR = None @@ -498,14 +503,17 @@ def __init__( # self.W_UV_O = None kv_b_proj_weight = self.kv_b_proj.weight.T - assert kv_b_proj_weight.shape == (self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), f"{kv_b_proj_weight.shape} != {(self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim))}" + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim) + ), f"{kv_b_proj_weight.shape} != {(self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim))}" kv_b_proj_weight = kv_b_proj_weight.view( self.kv_lora_rank, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim, ) - self.W_UK, self.W_UV = kv_b_proj_weight.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + self.W_UK, self.W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) # self.W_UK = self.W_UK.view(self.kv_lora_rank, self.num_local_heads * self.qk_nope_head_dim) # self.W_UV = self.W_UV.view(self.kv_lora_rank, self.num_local_heads * self.v_head_dim) @@ -519,9 +527,11 @@ def forward( # TODO(simon): support append/chunked prefill by two kernels, or using the decode kernel somehow. 
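+        # A possible shape for that support (sketch only, not what this WIP does):
+        # split the flattened batch at attn_metadata.num_prefill_tokens, run
+        # forward_prefill on the leading slice and forward_decode on the rest, and
+        # concatenate the two outputs. Today a batch is assumed to be either pure
+        # prefill or pure decode, so the dispatch below returns from one branch.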
if attn_metadata.prefill_metadata: - return self.forward_prefill(positions, hidden_states, kv_cache, attn_metadata) + return self.forward_prefill(positions, hidden_states, kv_cache, + attn_metadata) if attn_metadata.decode_metadata: - return self.forward_decode(positions, hidden_states, kv_cache, attn_metadata) + return self.forward_decode(positions, hidden_states, kv_cache, + attn_metadata) def forward_prefill( self, @@ -534,16 +544,20 @@ def forward_prefill( if self.q_lora_rank is not None: q = self.q_a_proj(hidden_states)[0] q = self.q_a_layernorm(q) - q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, + self.qk_head_dim) else: - q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, self.qk_head_dim) + q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, + self.qk_head_dim) # BN(P+R) -> BNP, BNR - q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) # BH -> B(L+R) latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] # B(L+R) -> BL, BR - kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_a, _ = latent_cache.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) # B(L+R) -> B1(L+R) latent_cache = latent_cache.unsqueeze(1) # BL -> BL @@ -551,7 +565,8 @@ def forward_prefill( # BL -> B(N'(P+V)) kv = self.kv_b_proj(kv_a)[0] # B(N'(P+V)) -> BN'(P+V) - kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) + kv = kv.view(-1, self.num_local_heads, + self.qk_nope_head_dim + self.v_head_dim) # BN'(P+V) -> BN'P, BN'V k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) # B1(L+R) -> B1R @@ -567,17 +582,18 @@ def forward_prefill( # write the latent and rope to kv cache to_cache_key = kv_a.unsqueeze(1) - to_cache_key_rope = torch.nn.functional.pad(k_pe, [0, self.kv_lora_rank - self.qk_rope_head_dim], value=0) + to_cache_key_rope = torch.nn.functional.pad( + k_pe, [0, self.kv_lora_rank - self.qk_rope_head_dim], value=0) if kv_cache.numel() > 0: ops.reshape_and_cache_flash( - to_cache_key, - to_cache_key_rope, - kv_cache[:, 0], - kv_cache[:, 1], - attn_metadata.slot_mapping.flatten(), - kv_cache_dtype="auto", # TODO: remove hard code - k_scale=1.0, - v_scale=1.0, + to_cache_key, + to_cache_key_rope, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype="auto", # TODO: remove hard code + k_scale=1.0, + v_scale=1.0, ) # run the prefill kernels @@ -586,7 +602,8 @@ def forward_prefill( v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0) prefill_meta = attn_metadata.prefill_metadata - q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = _get_query_key_seq_metadata(prefill_meta, True, AttentionType.DECODER) + q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = _get_query_key_seq_metadata( + prefill_meta, True, AttentionType.DECODER) attn_output = flash_attn_varlen_func( q=q, k=k, @@ -596,8 +613,10 @@ def forward_prefill( max_seqlen_q=q_seq_len, max_seqlen_k=k_seq_len, causal=True, - ) - attn_output = attn_output.view(-1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(-1, self.num_local_heads * self.v_head_dim) + ) + attn_output = attn_output.view( + -1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape( + -1, self.num_local_heads * self.v_head_dim) # B(N'V) -> BH output, _ = self.o_proj(attn_output) @@ -618,13 +637,17 @@ def 
forward_decode( if self.q_lora_rank is not None: q = self.q_a_proj(hidden_states)[0] q = self.q_a_layernorm(q) - q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, + self.qk_head_dim) else: - q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, self.qk_head_dim) + q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, + self.qk_head_dim) - q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] - kv_a, k_pe = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_a, k_pe = latent_cache.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv_a = self.kv_a_layernorm(kv_a.contiguous()) # print(f"{q.shape=}, {q_nope.shape=}, {q_pe.shape=}, {k_pe.shape=}, {kv_a.shape=}, {latent_cache.shape=}") k_pe = k_pe.unsqueeze(1) @@ -633,14 +656,19 @@ def forward_decode( q_nope = torch.einsum("bnp,lnp->bnl", q_nope, self.W_UK) # essemble q, k, and v; here v is repurposed to represent k_pe - q = torch.empty((B, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim), dtype=q.dtype, device=q.device) + q = torch.empty((B, self.num_local_heads, + self.kv_lora_rank + self.qk_rope_head_dim), + dtype=q.dtype, + device=q.device) q[..., :self.kv_lora_rank] = q_nope q[..., self.kv_lora_rank:] = q_pe # q = q.view(B, self.num_local_heads * (self.kv_lora_rank + self.qk_rope_head_dim)) k = kv_a # The padding is only used for kv storage. - v = torch.nn.functional.pad(k_pe, [0, self.kv_lora_rank - self.qk_rope_head_dim], value=0).squeeze(1) + v = torch.nn.functional.pad( + k_pe, [0, self.kv_lora_rank - self.qk_rope_head_dim], + value=0).squeeze(1) assert k.numel() == v.numel(), f"{k.numel()=} != {v.numel()=}" # attn_output = self.decode_attn(q, k, v, kv_cache, attn_metadata) @@ -652,15 +680,18 @@ def forward_decode( paged_kv_last_page_len = attn_metadata.decode_metadata.paged_kv_last_page_len # debug: we always have batch size 1 and one page - assert paged_kv_indptr.cpu().tolist() == [0, 1], f"{paged_kv_indptr.cpu().tolist()=}" + assert paged_kv_indptr.cpu().tolist() == [ + 0, 1 + ], f"{paged_kv_indptr.cpu().tolist()=}" paged_idx = paged_kv_indices[0] full_latent_cache = kv_cache[paged_idx, 0] full_rope_cache = kv_cache[paged_idx, 1] # let's write k and v into the full cache at paged_kv_last_page_len-1 - full_latent_cache[paged_kv_last_page_len-1, :, :] = k - full_rope_cache[paged_kv_last_page_len-1, :, :] = v + full_latent_cache[paged_kv_last_page_len - 1, :, :] = k + full_rope_cache[paged_kv_last_page_len - 1, :, :] = v full_latent_cache = full_latent_cache[:paged_kv_last_page_len, :, :] - full_rope_cache = full_rope_cache[:paged_kv_last_page_len, :, :self.qk_rope_head_dim] + full_rope_cache = full_rope_cache[:paged_kv_last_page_len, :, :self. 
+ qk_rope_head_dim] full_kv_cache = torch.cat([full_latent_cache, full_rope_cache], dim=-1) # now let's run the MLA manually @@ -668,24 +699,25 @@ def forward_decode( k_S_1_LR = full_kv_cache v_S_1_L = full_latent_cache import math - scale = 1.0/math.sqrt(self.kv_lora_rank + self.qk_rope_head_dim) + scale = 1.0 / math.sqrt(self.kv_lora_rank + self.qk_rope_head_dim) attn_scores = torch.einsum("bnl,snl->nbs", q_B_N_LR, k_S_1_LR) * scale attn_probs = torch.nn.functional.softmax(attn_scores, dim=-1) attn_output = torch.einsum("nbs,snl->bnl", attn_probs, v_S_1_L) - assert attn_output.shape == (B, self.num_local_heads, self.kv_lora_rank), f"{attn_output.shape=}!={B=}, {self.num_local_heads=}, {self.kv_lora_rank=}" + assert attn_output.shape == ( + B, self.num_local_heads, self.kv_lora_rank + ), f"{attn_output.shape=}!={B=}, {self.num_local_heads=}, {self.kv_lora_rank=}" # idk why but the attn_output is fp32 attn_output = attn_output.to(q.dtype) # Apply UV, (B, N, L) @ W_UV (L, N, V) -> (B, N, V) attn_output = torch.einsum("bnl,lnv->bnv", attn_output, self.W_UV) - attn_output = attn_output.reshape(B, self.num_local_heads * self.v_head_dim) + attn_output = attn_output.reshape( + B, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output - - class DeepseekV2DecoderLayer(nn.Module): def __init__( @@ -995,3 +1027,4 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params + \ No newline at end of file From f585a3f63434a4b4b3e8c279b1306057d97de029 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Fri, 6 Dec 2024 07:34:35 +0000 Subject: [PATCH 05/14] eager mode works --- examples/offline_inference.py | 7 ++- vllm/attention/backends/flashinfer.py | 63 +++++++++++---------- vllm/model_executor/models/deepseek_v2.py | 67 ++++++++++++----------- 3 files changed, 72 insertions(+), 65 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 15b8c0af7146a..a51885fec78be 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -14,10 +14,11 @@ llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat", trust_remote_code=True, max_model_len=16384, - dtype="float16", + # dtype="float16", enforce_eager=True, - max_num_seqs=1, - block_size=128) + # max_num_seqs=1, + # block_size=128, + ) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. 
outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 2746334e5843f..5303b66069f38 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -194,6 +194,7 @@ def graph_capture_get_metadata_for_batch( self.runner.parallel_config) use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( num_qo_heads // num_kv_heads > 4) + assert torch.is_tensor(_indptr_buffer), f"{_indptr_buffer=}" self._graph_decode_wrapper = ( # CUDAGraphBatchDecodeWithPagedKVCacheWrapper( BatchDecodeMlaWithPagedKVCacheWrapper( @@ -276,6 +277,7 @@ def begin_forward(self, model_input): model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper() model_input.attn_metadata.begin_forward() +import math @dataclass class FlashInferMetadata(AttentionMetadata): @@ -375,8 +377,7 @@ def begin_forward(self): self.num_qo_heads, self.head_dim, self.page_size, - sm_scale=self.head_dim** - -0.5, # TODO(simon): should we explicitly pass this in? + sm_scale=1.0 / math.sqrt(self.head_dim + self.head_dim//8), # TODO(simon): should we explicitly pass this in? data_type=self.data_type, q_data_type=self.q_data_type) if self.num_decode_tokens > 0: @@ -395,6 +396,7 @@ def begin_forward(self): assert self.decode_wrapper is not None # self.decode_wrapper.end_forward() + self.decode_wrapper.plan( self.paged_kv_indptr[self.num_prefills:], self.paged_kv_indices, @@ -403,8 +405,7 @@ def begin_forward(self): # self.num_kv_heads, self.head_dim, self.page_size, - sm_scale=self.head_dim** - -0.5, # TODO(simon): should we explicitly pass this in? + sm_scale=1.0 / math.sqrt(self.head_dim + self.head_dim//8), # TODO(simon): should we explicitly pass this in? # Disable flashinfer's pos encoding and use vllm's rope. # pos_encoding_mode="NONE", # kv-cache data type. @@ -803,23 +804,25 @@ def forward( key_rope = value del value - num_tokens, N, L_R = query.shape - qk_rope_head_dim = L_R - self.head_size - hidden_size = N * self.head_size + num_tokens, N, LR = query.shape assert N == self.num_heads + assert LR == self.head_size + self.head_size//8 + qk_rope_head_dim = LR - self.head_size assert qk_rope_head_dim == 64 + # hidden_size = N * self.head_size num_heads: int = self.num_heads head_size: int = self.head_size num_kv_heads: int = self.num_kv_heads + assert self.num_kv_heads == 1 kv_cache_dtype: str = self.kv_cache_dtype - softmax_scale: float = self.scale - window_size = self.sliding_window - alibi_slopes = self.alibi_slopes - logits_soft_cap = self.logits_soft_cap + # softmax_scale: float = self.scale + # window_size = self.sliding_window + # alibi_slopes = self.alibi_slopes + # logits_soft_cap = self.logits_soft_cap # num_tokens, hidden_size = query.shape - query = query.view(-1, num_heads, L_R) + query = query.view(-1, num_heads, LR) key = key.view(-1, num_kv_heads, head_size) key_rope = key_rope.view(-1, num_kv_heads, head_size) # this is padded! @@ -851,26 +854,26 @@ def forward( f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa assert key_rope.shape[0] == num_prefill_tokens + num_decode_tokens, \ f"value : {key_rope.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa - query = query.contiguous( - ) # Flashinfer requires query to be contiguous + + query = query.contiguous() # Flashinfer requires query to be contiguous # Query for decode. KV is not needed because it is already cached. # QKV for prefill. 
- decode_query = query[num_prefill_tokens:] - query = query[:num_prefill_tokens] - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens - - query_nope = query[:, :, :head_size] - query_pe = query[:, :, head_size:] + # query = query[:num_prefill_tokens] + decode_query = query#[num_prefill_tokens:] + # assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens, f"{decode_query.shape=}, {num_decode_tokens=}" - decode_query_nope = decode_query[:, :, :head_size] - decode_query_pe = decode_query[:, :, head_size:] + # query_nope = query[:, :, :head_size].contiguous() + # query_pe = query[:, :, head_size:].contiguous() + decode_query_nope = decode_query[:, :, :head_size].contiguous() + decode_query_pe = decode_query[:, :, head_size:].contiguous() - window_left = window_size[0] if window_size is not None else -1 + # window_left = window_size[0] if window_size is not None else -1 prefill_output: Optional[torch.Tensor] = None decode_output: Optional[torch.Tensor] = None if prefill_meta := attn_metadata.prefill_metadata: + assert False # We will use flash attention for prefill # when kv_cache is not provided. # This happens when vllm runs the profiling to @@ -924,19 +927,21 @@ def forward( if decode_meta := attn_metadata.decode_metadata: assert decode_meta is not None assert decode_meta.decode_wrapper is not None - paged_kpe_cache, _ = kv_cache[:, 1].split( - [qk_rope_head_dim, head_size - qk_rope_head_dim], dim=-1) + # paged_kpe_cache, _ = kv_cache[:, 1].split( + # [qk_rope_head_dim, head_size - qk_rope_head_dim], dim=-1) + # paged_kpe_cache = paged_kpe_cache.contiguous() # this is making of entire KV cache noooo + # # note: this shouldn't matter b/c FI assumes head_dim_kpe == head_dim_ckv//8 decode_output = decode_meta.decode_wrapper.run( q_nope=decode_query_nope, q_pe=decode_query_pe, paged_ckv_cache=kv_cache[:, 0], paged_kpe_cache=kv_cache[:, 1], + # paged_kpe_cache=paged_kpe_cache, # sm_scale=softmax_scale, # logits_soft_cap=logits_soft_cap, - k_scale=k_scale, - v_scale= - None, # v_scale, # NOTE(simon): there's a bug in FI now. 
https://github.com/flashinfer-ai/flashinfer/pull/650 + # k_scale=k_scale, + # v_scale=v_scale, # window_left=window_left ) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index e5e1933a72ae1..45dd4221972af 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -671,38 +671,40 @@ def forward_decode( value=0).squeeze(1) assert k.numel() == v.numel(), f"{k.numel()=} != {v.numel()=}" - # attn_output = self.decode_attn(q, k, v, kv_cache, attn_metadata) - - # i just want to manually verify MLA is doing the right thing - # let's get all the previous kv cache and copy them here, run the MLA manually - paged_kv_indptr = attn_metadata.decode_metadata.paged_kv_indptr - paged_kv_indices = attn_metadata.decode_metadata.paged_kv_indices - paged_kv_last_page_len = attn_metadata.decode_metadata.paged_kv_last_page_len - - # debug: we always have batch size 1 and one page - assert paged_kv_indptr.cpu().tolist() == [ - 0, 1 - ], f"{paged_kv_indptr.cpu().tolist()=}" - paged_idx = paged_kv_indices[0] - full_latent_cache = kv_cache[paged_idx, 0] - full_rope_cache = kv_cache[paged_idx, 1] - # let's write k and v into the full cache at paged_kv_last_page_len-1 - full_latent_cache[paged_kv_last_page_len - 1, :, :] = k - full_rope_cache[paged_kv_last_page_len - 1, :, :] = v - full_latent_cache = full_latent_cache[:paged_kv_last_page_len, :, :] - full_rope_cache = full_rope_cache[:paged_kv_last_page_len, :, :self. - qk_rope_head_dim] - full_kv_cache = torch.cat([full_latent_cache, full_rope_cache], dim=-1) - - # now let's run the MLA manually - q_B_N_LR = q - k_S_1_LR = full_kv_cache - v_S_1_L = full_latent_cache - import math - scale = 1.0 / math.sqrt(self.kv_lora_rank + self.qk_rope_head_dim) - attn_scores = torch.einsum("bnl,snl->nbs", q_B_N_LR, k_S_1_LR) * scale - attn_probs = torch.nn.functional.softmax(attn_scores, dim=-1) - attn_output = torch.einsum("nbs,snl->bnl", attn_probs, v_S_1_L) + attn_output = self.decode_attn(q, k, v, kv_cache, attn_metadata) + + # # debug: i just want to manually verify MLA is doing the right thing + # # let's get all the previous kv cache and copy them here, run the MLA manually + # paged_kv_indptr = attn_metadata.decode_metadata.paged_kv_indptr + # paged_kv_indices = attn_metadata.decode_metadata.paged_kv_indices + # paged_kv_last_page_len = attn_metadata.decode_metadata.paged_kv_last_page_len + + # # debug: we always have batch size 1 and one page + # assert paged_kv_indptr.cpu().tolist() == [ + # 0, 1 + # ], f"{paged_kv_indptr.cpu().tolist()=}" + # paged_idx = paged_kv_indices[0] + # full_latent_cache = kv_cache[paged_idx, 0] + # full_rope_cache = kv_cache[paged_idx, 1] + # # let's write k and v into the full cache at paged_kv_last_page_len-1 + # full_latent_cache[paged_kv_last_page_len - 1, :, :] = k + # full_rope_cache[paged_kv_last_page_len - 1, :, :] = v + # full_latent_cache = full_latent_cache[:paged_kv_last_page_len, :, :] + # full_rope_cache = full_rope_cache[:paged_kv_last_page_len, :, :self. 
+ # qk_rope_head_dim] + # full_kv_cache = torch.cat([full_latent_cache, full_rope_cache], dim=-1) + + # # now let's run the MLA manually + # q_B_N_LR = q + # k_S_1_LR = full_kv_cache + # v_S_1_L = full_latent_cache + # import math + # scale = 1.0 / math.sqrt(self.kv_lora_rank + self.qk_rope_head_dim) + # attn_scores = torch.einsum("bnl,snl->nbs", q_B_N_LR, k_S_1_LR) * scale + # attn_probs = torch.nn.functional.softmax(attn_scores, dim=-1) + # attn_output_ref = torch.einsum("nbs,snl->bnl", attn_probs, v_S_1_L) + + # # assert torch.allclose(attn_output.sum(), attn_output_ref.sum()), f"{attn_output.sum()=}\n{attn_output_ref.sum()=}" assert attn_output.shape == ( B, self.num_local_heads, self.kv_lora_rank @@ -1027,4 +1029,3 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params - \ No newline at end of file From 611acaaba149bd5930b48ddf3c69241f30752122 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Wed, 25 Dec 2024 03:44:50 +0000 Subject: [PATCH 06/14] reformat and add FLASHINFER_MLA --- examples/offline_inference.py | 18 +- vllm/attention/backends/flashinfer.py | 215 +++---- vllm/attention/backends/flashinfer_mla.py | 751 ++++++++++++++++++++++ vllm/attention/layer.py | 11 +- vllm/attention/selector.py | 15 +- vllm/config.py | 10 +- vllm/model_executor/models/deepseek_v2.py | 19 +- vllm/platforms/interface.py | 1 + vllm/worker/cache_engine.py | 14 +- vllm/worker/model_runner.py | 1 + 10 files changed, 880 insertions(+), 175 deletions(-) create mode 100644 vllm/attention/backends/flashinfer_mla.py diff --git a/examples/offline_inference.py b/examples/offline_inference.py index a51885fec78be..7c00321224360 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -11,14 +11,16 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat", - trust_remote_code=True, - max_model_len=16384, - # dtype="float16", - enforce_eager=True, - # max_num_seqs=1, - # block_size=128, - ) +llm = LLM( + model="deepseek-ai/DeepSeek-V2-Lite-Chat", + # model="deepseek-ai/DeepSeek-V2.5", tensor_parallel_size=8, + trust_remote_code=True, + max_model_len=16384, + # dtype="float16", + enforce_eager=True, + # max_num_seqs=1, + # block_size=128, +) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. 
outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 5303b66069f38..1a2024705eb04 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -6,7 +6,7 @@ from vllm.multimodal import MultiModalPlaceholderMap try: - from flashinfer import BatchDecodeWithPagedKVCacheWrapper, BatchDecodeMlaWithPagedKVCacheWrapper + from flashinfer import BatchDecodeWithPagedKVCacheWrapper from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper @@ -16,7 +16,6 @@ BatchDecodeWithPagedKVCacheWrapper = None CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None BatchPrefillWithPagedKVCacheWrapper = None - BatchDecodeMlaWithPagedKVCacheWrapper = None FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 import torch @@ -68,9 +67,7 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - # IDEA(simon): We probably should create a new backend for MLA something like FLASHINFER_MLA. return (num_blocks, 2, block_size, num_kv_heads, head_size) - # return (num_blocks, 1, block_size, num_kv_heads, head_size) @staticmethod def swap_blocks( @@ -89,7 +86,7 @@ def copy_blocks( @staticmethod def get_supported_head_sizes() -> List[int]: - return [256, 512] # [64, 128, 256, 512] + return [64, 128, 256] @staticmethod def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: @@ -120,9 +117,8 @@ def _get_workspace_buffer(self): def _get_prefill_wrapper(self): if self._prefill_wrapper is None: - # self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( - self._prefill_wrapper = BatchDecodeMlaWithPagedKVCacheWrapper( - self._get_workspace_buffer(), ) + self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( + self._get_workspace_buffer(), "NHD") return self._prefill_wrapper def _get_decode_wrapper(self): @@ -133,11 +129,10 @@ def _get_decode_wrapper(self): self.runner.parallel_config) use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( num_qo_heads // num_kv_heads > 4) - # self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( - self._decode_wrapper = BatchDecodeMlaWithPagedKVCacheWrapper( - self._get_workspace_buffer()) - # "NHD", - # use_tensor_cores=use_tensor_cores) + self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self._get_workspace_buffer(), + "NHD", + use_tensor_cores=use_tensor_cores) return self._decode_wrapper @contextmanager @@ -194,18 +189,11 @@ def graph_capture_get_metadata_for_batch( self.runner.parallel_config) use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( num_qo_heads // num_kv_heads > 4) - assert torch.is_tensor(_indptr_buffer), f"{_indptr_buffer=}" - self._graph_decode_wrapper = ( - # CUDAGraphBatchDecodeWithPagedKVCacheWrapper( - BatchDecodeMlaWithPagedKVCacheWrapper( - self._graph_decode_workspace_buffer, - True, - _indptr_buffer, - self._graph_indices_buffer, - _last_page_len_buffer, - )) - # "NHD", - # use_tensor_cores) + self._graph_decode_wrapper = \ + CUDAGraphBatchDecodeWithPagedKVCacheWrapper( + self._graph_decode_workspace_buffer, _indptr_buffer, + self._graph_indices_buffer, _last_page_len_buffer, "NHD", + use_tensor_cores) if self.runner.kv_cache_dtype.startswith("fp8"): kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( self.runner.kv_cache_dtype) @@ -277,7 +265,6 @@ def begin_forward(self, model_input): model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper() 
model_input.attn_metadata.begin_forward() -import math @dataclass class FlashInferMetadata(AttentionMetadata): @@ -292,8 +279,8 @@ class FlashInferMetadata(AttentionMetadata): use_cuda_graph: bool = True - prefill_wrapper: Optional[BatchDecodeMlaWithPagedKVCacheWrapper] = None - decode_wrapper: Optional[BatchDecodeMlaWithPagedKVCacheWrapper] = None + prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None + decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None # Metadata for the prefill stage seq_start_loc: Optional[torch.Tensor] = None @@ -369,17 +356,14 @@ def begin_forward(self): self.block_table_bound = self.block_table_bound.to(self.device) self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) self.paged_kv_indices = self.paged_kv_indices.to(self.device) - # self.prefill_wrapper.end_forward() - self.prefill_wrapper.plan( + self.prefill_wrapper.end_forward() + self.prefill_wrapper.begin_forward( + self.query_start_loc, self.paged_kv_indptr[:self.num_prefills + 1], self.paged_kv_indices, self.paged_kv_last_page_len[:self.num_prefills], - self.num_qo_heads, - self.head_dim, - self.page_size, - sm_scale=1.0 / math.sqrt(self.head_dim + self.head_dim//8), # TODO(simon): should we explicitly pass this in? - data_type=self.data_type, - q_data_type=self.q_data_type) + self.num_qo_heads, self.num_kv_heads, self.head_dim, + self.page_size) if self.num_decode_tokens > 0: assert self.paged_kv_indices is not None assert self.paged_kv_indptr is not None @@ -395,19 +379,17 @@ def begin_forward(self): self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) assert self.decode_wrapper is not None - # self.decode_wrapper.end_forward() - - self.decode_wrapper.plan( + self.decode_wrapper.end_forward() + self.decode_wrapper.begin_forward( self.paged_kv_indptr[self.num_prefills:], self.paged_kv_indices, self.paged_kv_last_page_len[self.num_prefills:], self.num_qo_heads, - # self.num_kv_heads, + self.num_kv_heads, self.head_dim, self.page_size, - sm_scale=1.0 / math.sqrt(self.head_dim + self.head_dim//8), # TODO(simon): should we explicitly pass this in? # Disable flashinfer's pos encoding and use vllm's rope. - # pos_encoding_mode="NONE", + pos_encoding_mode="NONE", # kv-cache data type. data_type=self.data_type, # query data type. 
@@ -782,8 +764,6 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - self.empty_tensor = torch.empty(0, device="cuda") - def forward( self, query: torch.Tensor, @@ -801,37 +781,25 @@ def forward( "are not implemented for " "FlashInferImpl") - key_rope = value - del value - - num_tokens, N, LR = query.shape - assert N == self.num_heads - assert LR == self.head_size + self.head_size//8 - qk_rope_head_dim = LR - self.head_size - assert qk_rope_head_dim == 64 - # hidden_size = N * self.head_size - num_heads: int = self.num_heads head_size: int = self.head_size num_kv_heads: int = self.num_kv_heads - assert self.num_kv_heads == 1 kv_cache_dtype: str = self.kv_cache_dtype - # softmax_scale: float = self.scale - # window_size = self.sliding_window - # alibi_slopes = self.alibi_slopes - # logits_soft_cap = self.logits_soft_cap + softmax_scale: float = self.scale + window_size = self.sliding_window + alibi_slopes = self.alibi_slopes + logits_soft_cap = self.logits_soft_cap - # num_tokens, hidden_size = query.shape - query = query.view(-1, num_heads, LR) + num_tokens, hidden_size = query.shape + query = query.view(-1, num_heads, head_size) key = key.view(-1, num_kv_heads, head_size) - key_rope = key_rope.view(-1, num_kv_heads, - head_size) # this is padded! + value = value.view(-1, num_kv_heads, head_size) if kv_cache.numel() > 0: # Use the same reshape and cache kernel as flash attention. ops.reshape_and_cache_flash( key, - key_rope, + value, kv_cache[:, 0], kv_cache[:, 1], attn_metadata.slot_mapping.flatten(), @@ -848,102 +816,68 @@ def forward( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - assert num_prefill_tokens == 0 and num_decode_tokens > 0, "only mla decode" - assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa - assert key_rope.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"value : {key_rope.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa - - query = query.contiguous() # Flashinfer requires query to be contiguous + assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa + query = query.contiguous( + ) # Flashinfer requires query to be contiguous # Query for decode. KV is not needed because it is already cached. # QKV for prefill. 
- # query = query[:num_prefill_tokens] - decode_query = query#[num_prefill_tokens:] - # assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens, f"{decode_query.shape=}, {num_decode_tokens=}" + decode_query = query[num_prefill_tokens:] + query = query[:num_prefill_tokens] + + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] - # query_nope = query[:, :, :head_size].contiguous() - # query_pe = query[:, :, head_size:].contiguous() - decode_query_nope = decode_query[:, :, :head_size].contiguous() - decode_query_pe = decode_query[:, :, head_size:].contiguous() + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens - # window_left = window_size[0] if window_size is not None else -1 + window_left = window_size[0] if window_size is not None else -1 prefill_output: Optional[torch.Tensor] = None decode_output: Optional[torch.Tensor] = None if prefill_meta := attn_metadata.prefill_metadata: - assert False # We will use flash attention for prefill # when kv_cache is not provided. # This happens when vllm runs the profiling to # determine the number of blocks. if kv_cache.numel() == 0: - prefill_output = torch.empty(num_prefill_tokens, - N, - head_size, - device="cuda") - # key = key[:num_prefill_tokens] - # key_rope = key_rope[:num_prefill_tokens, :, :qk_rope_head_dim] - # prefill_output = flash_attn_varlen_func( - # q=query, - # k=key, - # v=key_rope, - # cu_seqlens_q=prefill_meta.seq_start_loc, - # cu_seqlens_k=prefill_meta.seq_start_loc, - # max_seqlen_q=prefill_meta.max_prefill_seq_len, - # max_seqlen_k=prefill_meta.max_prefill_seq_len, - # softmax_scale=softmax_scale, - # causal=True, - # window_size=window_size, - # alibi_slopes=alibi_slopes, - # ) + prefill_output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + ) else: assert prefill_meta is not None assert prefill_meta.prefill_wrapper is not None - # prefill_output = prefill_meta.prefill_wrapper.run( - # query, - # kv_cache, - # logits_soft_cap=logits_soft_cap, - # causal=True, - # k_scale=k_scale, - # v_scale=v_scale, - # window_left=window_left) - paged_kpe_cache, _ = kv_cache[:, 1].split( - [qk_rope_head_dim, head_size - qk_rope_head_dim], dim=-1) - - prefill_output = prefill_meta.prefill_wrapper.run( - q_nope=query_nope, - q_pe=query_pe, - paged_ckv_cache=kv_cache[:, 0], - paged_kpe_cache=kv_cache[:, 1], - # paged_kpe_cache=paged_kpe_cache, - # sm_scale=softmax_scale, - # logits_soft_cap=logits_soft_cap, + prefill_output = prefill_meta.prefill_wrapper.forward( + query, + kv_cache, + logits_soft_cap=logits_soft_cap, + causal=True, k_scale=k_scale, - v_scale=None, # v_scale, - # window_left=window_left - ) + v_scale=v_scale, + window_left=window_left) if decode_meta := attn_metadata.decode_metadata: assert decode_meta is not None assert decode_meta.decode_wrapper is not None - # paged_kpe_cache, _ = kv_cache[:, 1].split( - # [qk_rope_head_dim, head_size - qk_rope_head_dim], dim=-1) - # paged_kpe_cache = paged_kpe_cache.contiguous() # this is making of entire KV cache noooo - # # note: this shouldn't matter b/c FI assumes head_dim_kpe == head_dim_ckv//8 - - decode_output = decode_meta.decode_wrapper.run( - q_nope=decode_query_nope, - 
q_pe=decode_query_pe, - paged_ckv_cache=kv_cache[:, 0], - paged_kpe_cache=kv_cache[:, 1], - # paged_kpe_cache=paged_kpe_cache, - # sm_scale=softmax_scale, - # logits_soft_cap=logits_soft_cap, - # k_scale=k_scale, - # v_scale=v_scale, - # window_left=window_left - ) + decode_output = decode_meta.decode_wrapper.forward( + decode_query, + kv_cache, + sm_scale=softmax_scale, + logits_soft_cap=logits_soft_cap, + k_scale=k_scale, + v_scale=v_scale, + window_left=window_left) if prefill_output is None and decode_output is not None: # Decode only batch. @@ -960,7 +894,4 @@ def forward( assert decode_meta.decode_query_len == 1 decode_output = decode_output.squeeze(1) output = torch.cat([prefill_output, decode_output], dim=0) - assert output.shape == ( - num_tokens, N, - head_size), f"{output.shape=}!={num_tokens=}, {N=}, {head_size=}" - return output + return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/flashinfer_mla.py b/vllm/attention/backends/flashinfer_mla.py new file mode 100644 index 0000000000000..f03a8220d9fee --- /dev/null +++ b/vllm/attention/backends/flashinfer_mla.py @@ -0,0 +1,751 @@ +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type +import math +from functools import cached_property + +from vllm.multimodal import MultiModalPlaceholderMap + +try: + from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper + from vllm.vllm_flash_attn import flash_attn_varlen_func + FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 +except ImportError: + BatchDecodeMlaWithPagedKVCacheWrapper = None + FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 + +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionState, AttentionType) +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, + make_tensor_with_pad) + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + + +class FlashInferMLABackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "FLASHINFER_MLA" + + @staticmethod + def get_impl_cls() -> Type["FlashInferMLAImpl"]: + return FlashInferMLAImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return FlashInferMLAMetadata + + @staticmethod + def get_builder_cls() -> Type["FlashInferMLAMetadataBuilder"]: + return FlashInferMLAMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["FlashInferMLAState"]: + return FlashInferMLAState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + # NOTE(simon): we repurpose the "key" cache for latent, + # and "value" cache for rope. Until we have hybrid memory + # allocate, we are living with some memory waste. 
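Concretely, with head_size reported as kv_lora_rank (512 for DeepSeek-V2), the note above means each page holds two planes of width 512: plane 0 stores the compressed latent c_kv in full, while plane 1 stores the 64-dim rope key k_pe zero-padded to 512. A toy sketch of that layout (shapes only; sizes other than 512/64 are made up):

    import torch

    num_blocks, block_size, num_kv_heads = 8, 16, 1
    kv_lora_rank, qk_rope_head_dim = 512, 64        # head_size == kv_lora_rank for this backend

    # get_kv_cache_shape(...) below -> (num_blocks, 2, block_size, num_kv_heads, head_size)
    kv_cache = torch.zeros(num_blocks, 2, block_size, num_kv_heads, kv_lora_rank)

    c_kv = torch.randn(block_size, num_kv_heads, kv_lora_rank)       # latent: fills plane 0
    k_pe = torch.randn(block_size, num_kv_heads, qk_rope_head_dim)   # rope key: 64 of 512 dims used

    kv_cache[0, 0] = c_kv
    kv_cache[0, 1, :, :, :qk_rope_head_dim] = k_pe   # rest of plane 1 stays zero -> the wasted memory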
+ return (num_blocks, 2, block_size, num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [512] + + @staticmethod + def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + return torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + return torch.float8_e5m2 + else: + raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + + +class FlashInferMLAState(AttentionState): + + def __init__(self, runner): + self.runner = runner + + @cached_property + def _workspace_buffer(self): + return torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.runner.device) + + @cached_property + def _decode_wrapper(self): + return BatchDecodeMlaWithPagedKVCacheWrapper( + self._workspace_buffer) + + @contextmanager + def graph_capture(self, max_batch_size: int): + raise NotImplementedError("FlashInferMLAState does not support graph capture") + + def graph_clone(self, batch_size: int): + raise NotImplementedError("FlashInferMLAState does not support graph capture") + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + raise NotImplementedError("FlashInferMLAState does not support graph capture") + + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + raise NotImplementedError("FlashInferMLAState does not support graph capture") + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + raise NotImplementedError("FlashInferMLAState does not support graph capture") + + def begin_forward(self, model_input): + model_input.attn_metadata.decode_wrapper = self._decode_wrapper + model_input.attn_metadata.begin_forward() + + +@dataclass +class FlashInferMLAMetadata(AttentionMetadata): + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + + # Number of query tokens for each request in the batch. + # Currently, we require that all requests have the same number of query + # tokens during the decoding phase. When speculavie decoding is enabled, + # decode_query_len might be greater than 1. In all other cases, it is 1. 
+ decode_query_len: Optional[int] = 1 + + use_cuda_graph: bool = True + + decode_wrapper: Optional[BatchDecodeMlaWithPagedKVCacheWrapper] = None + + # Metadata for the prefill stage + seq_start_loc: Optional[torch.Tensor] = None + query_start_loc: Optional[torch.Tensor] = None + block_tables: Optional[torch.Tensor] = None + + # used for GPU in-place advance_step + seq_lens_tensor: Optional[torch.Tensor] = None + block_table_bound: Optional[torch.Tensor] = None + + # An example for paged_kv_indices, paged_kv_indptr: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: Optional[torch.Tensor] = None + # The page indices of the paged kv cache + paged_kv_indices: Optional[torch.Tensor] = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len: Optional[torch.Tensor] = None + # The number of query/output heads + num_qo_heads: Optional[int] = None + # The number of key/value heads + num_kv_heads: Optional[int] = None + # The dimension of the attention heads + head_dim: Optional[int] = None + # Block size of vllm + page_size: Optional[int] = None + # The data type of the paged kv cache + data_type: torch.dtype = None + # The data type of the query + q_data_type: torch.dtype = None + device: torch.device = torch.device("cuda") + is_profile_run: bool = False + + def __post_init__(self): + supported_head_sizes = FlashInferMLABackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f"received {self.head_dim}.") + + def begin_forward(self): + if self.num_prefill_tokens > 0: + # assert NotImplementedError("FlashInferMLAState does not support prefill") + # NOTE: only in profiling, treating it as noop + return + + if self.paged_kv_indices is None: + return + + assert self.prefill_wrapper is not None + assert self.query_start_loc is not None + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None + assert self.block_table_bound is not None + assert self.seq_lens_tensor is not None + self.query_start_loc = self.query_start_loc[:self.num_prefills + 1] + batch_size = self.query_start_loc.shape[0] - 1 + assert batch_size >= 0 + # We will use flash attention for profiling to + # determine the number of blocks. Therefore, + # we don't need to prepare the input for flashinfer for profile run. + if not self.is_profile_run: + self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + self.device) + self.block_table_bound = self.block_table_bound.to(self.device) + self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) + self.paged_kv_indices = self.paged_kv_indices.to(self.device) + self.prefill_wrapper.plan( + self.paged_kv_indptr[:self.num_prefills + 1], + self.paged_kv_indices, + self.paged_kv_last_page_len[:self.num_prefills], + self.num_qo_heads, + self.head_dim, + self.page_size, + sm_scale=1.0 / math.sqrt(self.head_dim + self.head_dim//8), # TODO(simon): should we explicitly pass this in? 
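            # Annotation on the sm_scale above: head_dim here is kv_lora_rank (512), and
            # FlashInfer's MLA decode assumes head_dim_kpe == head_dim_ckv // 8 == 64, so this
            # evaluates to 1 / sqrt(512 + 64) = 1 / sqrt(576) -- the softmax scale over the
            # concatenated latent + rope dimension rather than the usual 1 / sqrt(head_dim).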
+ data_type=self.data_type, + q_data_type=self.q_data_type) + + if self.num_decode_tokens > 0: + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None + self.paged_kv_indices = self.paged_kv_indices.to(self.device) + self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + self.device) + # handle model warmup path + if self.block_table_bound is not None: + self.block_table_bound = self.block_table_bound.to(self.device) + if self.seq_lens_tensor is not None: + self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) + + assert self.decode_wrapper is not None + # self.decode_wrapper.end_forward() + + self.decode_wrapper.plan( + self.paged_kv_indptr[self.num_prefills:], + self.paged_kv_indices, + self.paged_kv_last_page_len[self.num_prefills:], + self.num_qo_heads, + # self.num_kv_heads, + self.head_dim, + self.page_size, + sm_scale=1.0 / math.sqrt(self.head_dim + self.head_dim//8), # TODO(simon): should we explicitly pass this in? + # Disable flashinfer's pos encoding and use vllm's rope. + # pos_encoding_mode="NONE", + # kv-cache data type. + data_type=self.data_type, + # query data type. + q_data_type=self.q_data_type) + + def asdict_zerocopy(self, + skip_fields: Optional[Set[str]] = None + ) -> Dict[str, Any]: + if skip_fields is None: + skip_fields = set() + # We need to skip the prefill/decode_wrapper field since it cannot be + # broadcasted with nccl when TP is enabled. + skip_fields.add('decode_wrapper') + return super().asdict_zerocopy(skip_fields) + + @property + def prefill_metadata(self) -> Optional["FlashInferMLAMetadata"]: + if self.num_prefills == 0: + return None + return self + + @property + def decode_metadata(self) -> Optional["FlashInferMLAMetadata"]: + if self.num_decode_tokens == 0: + return None + return self + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + raise NotImplementedError("FlashInferMLAMetadata does not support multi-step") + + +class FlashInferMLAMetadataBuilder(AttentionMetadataBuilder[FlashInferMLAMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + self.input_builder = input_builder + self.runner = input_builder.runner + + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + + # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout + # for the precise definition of the following fields. 
+ # An example: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + self.paged_kv_indices: List[int] = [] + # 0 at the beginning of paged_kv_indptr indicates the start of the + # first request’s page indices in the paged_kv_indices list. + self.paged_kv_indptr: List[int] = [0] + # paged_kv_last_page_len is the length of the last page of each request + self.paged_kv_last_page_len: List[int] = [] + self.total_blocks = 0 + self.is_profile_run: bool = False + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + computed_block_nums = inter_data.computed_block_nums + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if inter_data.prefix_cache_hit: + block_table = computed_block_nums + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + block_table = block_tables[seq_id][-curr_sliding_window_block:] + self.block_tables.append(block_table) + + is_profile_run = is_block_tables_empty(block_tables) + + # Compute slot mapping. + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + # It is not necessary to add paged_kv_indices, paged_kv_indptr, + # and paged_kv_last_page_len for profile run because we will + # create dummy inputs. + if is_profile_run: + self.is_profile_run = is_profile_run + return + + block_table = block_tables[seq_id] + self._update_paged_kv_tensors(block_table, seq_len) + + def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int): + # Get the number of valid blocks based on sequence length. + # If seq_len = 16, block_size = 16, + # block_table_bound is 1 with 1 valid block. + # If seq_len = 15, block_size = 16, + # block_table_bound is 0 + 1 with 1 valid block. 
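Both the paged-KV example in the comments above (page indices [0, 5, 8], [1, 6, 7], [3, 4] flattening to indices [0, 5, 8, 1, 6, 7, 3, 4] with indptr [0, 3, 6, 8]) and the block_table_bound / last-page arithmetic in _update_paged_kv_tensors are easy to reproduce standalone. A small sketch with the same page indices (the sequence lengths are made up, chosen only to be consistent with 3/3/2 occupied pages at block_size = 16):

    block_size = 16
    page_tables = [[0, 5, 8], [1, 6, 7], [3, 4]]   # per-request page indices from the example
    seq_lens    = [40, 33, 17]                     # made-up lengths needing 3, 3 and 2 pages

    paged_kv_indices = []
    paged_kv_indptr = [0]
    paged_kv_last_page_len = []
    for block_table, seq_len in zip(page_tables, seq_lens):
        # number of pages actually holding tokens: ceil(seq_len / block_size)
        block_table_bound = seq_len // block_size + (1 if seq_len % block_size else 0)
        paged_kv_indices.extend(block_table[:block_table_bound])
        paged_kv_indptr.append(paged_kv_indptr[-1] + block_table_bound)
        # tokens in the last, possibly partial page; a full last page counts as block_size
        paged_kv_last_page_len.append(seq_len % block_size or block_size)

    print(paged_kv_indices)        # [0, 5, 8, 1, 6, 7, 3, 4]
    print(paged_kv_indptr)         # [0, 3, 6, 8]
    print(paged_kv_last_page_len)  # [8, 1, 1]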
+ self.total_blocks += len(block_table) + block_table_bound = seq_len // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else seq_len // self.block_size + self.paged_kv_indices.extend(block_table[:block_table_bound]) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + block_table_bound) + + last_page_len = seq_len % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + self.paged_kv_last_page_len.append(last_page_len) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + decode_query_len = max(query_lens[self.num_prefills:], default=1) + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size - self.num_prefill_tokens + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = self.runner.graph_block_tables[:batch_size] + max_blocks = input_block_tables.shape[1] + for i, block_table in enumerate(self.block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + input_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. 
+ input_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + block_tables = torch.from_numpy(input_block_tables).to( + device, non_blocking=True) + + last_paged_kv_indptr = self.paged_kv_indptr[-1] + self.paged_kv_indptr.extend([last_paged_kv_indptr] * + cuda_graph_pad_size) + self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + + assert device is not None + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + if len(self.paged_kv_indptr) > 0: + # extend to the maximum number of blocks as returned by the + # scheduler + self.paged_kv_indices.extend( + [0] * (self.total_blocks - len(self.paged_kv_indices))) + paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, + device="cpu", + dtype=torch.int) + paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, + device="cpu", + dtype=torch.int) + paged_kv_last_page_len_tensor = torch.tensor( + self.paged_kv_last_page_len, device="cpu", dtype=torch.int) + block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) - + 1, + device="cpu", + dtype=torch.int) + else: + paged_kv_indices_tensor = None + paged_kv_indptr_tensor = None + paged_kv_last_page_len_tensor = None + block_table_bound_tensor = None + + if self.runner.kv_cache_dtype.startswith("fp8"): + kv_cache_dtype = FlashInferMLABackend.get_fp8_dtype_for_flashinfer( + self.runner.kv_cache_dtype) + else: + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) + + return FlashInferMLAMetadata( + decode_query_len=decode_query_len, + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + multi_modal_placeholder_index_maps=placeholder_index_maps, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + max_prefill_seq_len=max_prefill_seq_len, + block_tables=block_tables, + paged_kv_indptr=paged_kv_indptr_tensor, + paged_kv_indices=paged_kv_indices_tensor, + paged_kv_last_page_len=paged_kv_last_page_len_tensor, + block_table_bound=block_table_bound_tensor, + seq_lens_tensor=seq_lens_tensor, + num_qo_heads=self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config), + num_kv_heads=self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config), + head_dim=self.runner.model_config.get_head_size(), + page_size=self.block_size, + seq_start_loc=seq_start_loc, + query_start_loc=query_start_loc, + device=device, + data_type=kv_cache_dtype, + q_data_type=self.runner.model_config.dtype, + use_cuda_graph=use_captured_graph, + is_profile_run=self.is_profile_run) + + +class FlashInferMLAImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + 
head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + unsupported_features = [alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap] + if any(unsupported_features): + raise NotImplementedError(f"FlashInferMLAImpl does not support {unsupported_features}") + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashInferMLAMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: str = AttentionType.DECODER, + ) -> torch.Tensor: + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferMLAImpl") + key_rope = value + del value + + num_tokens, N, LR = query.shape + assert N == self.num_heads + assert LR == self.head_size + self.head_size//8 + qk_rope_head_dim = LR - self.head_size + assert qk_rope_head_dim == 64 + # hidden_size = N * self.head_size + + num_heads: int = self.num_heads + head_size: int = self.head_size + num_kv_heads: int = self.num_kv_heads + assert self.num_kv_heads == 1 + kv_cache_dtype: str = self.kv_cache_dtype + # softmax_scale: float = self.scale + # window_size = self.sliding_window + # alibi_slopes = self.alibi_slopes + # logits_soft_cap = self.logits_soft_cap + + # num_tokens, hidden_size = query.shape + query = query.view(-1, num_heads, LR) + key = key.view(-1, num_kv_heads, head_size) + key_rope = key_rope.view(-1, num_kv_heads, + head_size) # this is padded! + + if kv_cache.numel() > 0: + # Use the same reshape and cache kernel as flash attention. + ops.reshape_and_cache_flash( + key, + key_rope, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 + # to process the cache when the kv_cache_dtype is fp8 + if kv_cache_dtype.startswith("fp8"): + torch_dtype = FlashInferMLABackend.get_fp8_dtype_for_flashinfer( + kv_cache_dtype) + kv_cache = kv_cache.view(torch_dtype) + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert num_prefill_tokens == 0 and num_decode_tokens > 0, "only mla decode" + + assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa + assert key_rope.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"value : {key_rope.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa + + query = query.contiguous() # Flashinfer requires query to be contiguous + # Query for decode. KV is not needed because it is already cached. + # QKV for prefill. 
+ # query = query[:num_prefill_tokens] + decode_query = query#[num_prefill_tokens:] + # assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens, f"{decode_query.shape=}, {num_decode_tokens=}" + + # query_nope = query[:, :, :head_size].contiguous() + # query_pe = query[:, :, head_size:].contiguous() + decode_query_nope = decode_query[:, :, :head_size].contiguous() + decode_query_pe = decode_query[:, :, head_size:].contiguous() + + # window_left = window_size[0] if window_size is not None else -1 + + prefill_output: Optional[torch.Tensor] = None + decode_output: Optional[torch.Tensor] = None + if prefill_meta := attn_metadata.prefill_metadata: + # only in prefill, use run FA + prefill_output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + causal=True, + ) + if decode_meta := attn_metadata.decode_metadata: + assert decode_meta is not None + assert decode_meta.decode_wrapper is not None + # paged_kpe_cache, _ = kv_cache[:, 1].split( + # [qk_rope_head_dim, head_size - qk_rope_head_dim], dim=-1) + # paged_kpe_cache = paged_kpe_cache.contiguous() # this is making of entire KV cache noooo + # # note: this shouldn't matter b/c FI assumes head_dim_kpe == head_dim_ckv//8 + + decode_output = decode_meta.decode_wrapper.run( + q_nope=decode_query_nope, + q_pe=decode_query_pe, + paged_ckv_cache=kv_cache[:, 0], + paged_kpe_cache=kv_cache[:, 1], + # paged_kpe_cache=paged_kpe_cache, + # sm_scale=softmax_scale, + # logits_soft_cap=logits_soft_cap, + # k_scale=k_scale, + # v_scale=v_scale, + # window_left=window_left + ) + + if prefill_output is None and decode_output is not None: + # Decode only batch. + output, num_tokens = decode_output, num_decode_tokens + elif decode_output is None and prefill_output is not None: + # Prefill only batch. + output, num_tokens = prefill_output, num_prefill_tokens + else: + # Chunked prefill batch does not work with speculative decoding in + # FlashInfer backend, so the query length for decode should be 1. + assert prefill_output is not None + assert decode_output is not None + assert decode_meta is not None + assert decode_meta.decode_query_len == 1 + decode_output = decode_output.squeeze(1) + output = torch.cat([prefill_output, decode_output], dim=0) + assert output.shape == ( + num_tokens, N, + head_size), f"{output.shape=}!={num_tokens=}, {N=}, {head_size=}" + return output diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 17157617248f7..c0ff223c14617 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -40,6 +40,7 @@ def __init__( blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, per_layer_sliding_window: Optional[int] = None, + use_mla: bool = False, prefix: str = "", ) -> None: super().__init__() @@ -90,9 +91,13 @@ def __init__( # During model initialization, the default dtype is set as the model # weight and activation dtype. 
dtype = torch.get_default_dtype() - attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype, - block_size, is_attention_free, - blocksparse_params is not None) + attn_backend = get_attn_backend(head_size, + dtype, + kv_cache_dtype, + block_size, + is_attention_free, + blocksparse_params is not None, + use_mla=use_mla) impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index d263839705690..618ecb0b7ef11 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -81,6 +81,7 @@ def get_attn_backend( block_size: int, is_attention_free: bool, is_blocksparse: bool = False, + use_mla: bool = False, ) -> Type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" # Accessing envs.* behind an @lru_cache decorator can cause the wrong @@ -95,6 +96,7 @@ def get_attn_backend( is_attention_free=is_attention_free, is_blocksparse=is_blocksparse, use_v1=envs.VLLM_USE_V1, + use_mla=use_mla, ) @@ -107,6 +109,7 @@ def _cached_get_attn_backend( is_attention_free: bool, is_blocksparse: bool = False, use_v1: bool = False, + use_mla: bool = False, ) -> Type[AttentionBackend]: if is_blocksparse: logger.info("Using BlocksparseFlashAttention backend.") @@ -115,7 +118,7 @@ def _cached_get_attn_backend( return BlocksparseFlashAttentionBackend backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size, - is_attention_free, use_v1) + is_attention_free, use_v1, use_mla) if backend == _Backend.FLASH_ATTN: logger.info("Using Flash Attention backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 @@ -155,6 +158,10 @@ def _cached_get_attn_backend( logger.info("Using Flashinfer backend.") from vllm.attention.backends.flashinfer import FlashInferBackend return FlashInferBackend + elif backend == _Backend.FLASHINFER_MLA: + logger.info("Using Flashinfer MLA backend.") + from vllm.attention.backends.flashinfer_mla import FlashInferMLABackend + return FlashInferMLABackend elif backend == _Backend.HPU_ATTN: logger.info("Using HPUAttention backend.") from vllm.attention.backends.hpu_attn import HPUAttentionBackend @@ -176,7 +183,8 @@ def which_attn_to_use(head_size: int, kv_cache_dtype: Optional[str], block_size: int, is_attention_free: bool, - use_v1: bool = False) -> _Backend: + use_v1: bool = False, + use_mla: bool = False) -> _Backend: """Returns which flash attention backend to use.""" # Default case. selected_backend = _Backend.FLASH_ATTN @@ -210,6 +218,9 @@ def which_attn_to_use(head_size: int, if use_v1: return _Backend.FLASH_ATTN_VLLM_V1 + if use_mla: + return _Backend.FLASHINFER_MLA + # FlashAttn in NVIDIA GPUs. 
if selected_backend == _Backend.FLASH_ATTN: if not current_platform.has_device_capability(80): diff --git a/vllm/config.py b/vllm/config.py index c188b00516f3c..20d864869ef8d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -597,14 +597,16 @@ def get_hidden_size(self) -> int: return self.hf_text_config.hidden_size @property - def _is_deepseek_v2(self) -> bool: - return hasattr( + def is_deepseek_v2(self) -> bool: + result = hasattr( self.hf_text_config, "model_type") and self.hf_text_config.model_type == 'deepseek_v2' + assert result + return result def get_head_size(self) -> int: # TODO remove hard code - if self._is_deepseek_v2: + if self.is_deepseek_v2: # FlashAttention supports only head_size 32, 64, 128, 256, # we need to pad head_size 192 to 256 # return 256 @@ -668,7 +670,7 @@ def get_total_num_kv_heads(self) -> int: def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: """Returns the number of KV heads per GPU.""" - if self._is_deepseek_v2: + if self.is_deepseek_v2: # TODO(simon): feature flag MLA return 1 diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 45dd4221972af..886d0ef8aae9c 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -487,15 +487,14 @@ def __init__( # quant_config=quant_config, # prefix=f"{prefix}.prefill_attn") # The decode attention will compute a multi-query attention by directly operating on the latent. - self.decode_attn = Attention( - num_heads=self.num_local_heads, - head_size=self. - kv_lora_rank, # + self.qk_rope_head_dim, # TODO(simon): pass in qk_rope_head_dim? but i don't think - scale=self.scaling, - num_kv_heads=1, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.decode_attn") + self.attn = Attention(num_heads=self.num_local_heads, + head_size=self.kv_lora_rank, + scale=self.scaling, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True) # To be computed during weight loading # self.W_QR = None @@ -671,7 +670,7 @@ def forward_decode( value=0).squeeze(1) assert k.numel() == v.numel(), f"{k.numel()=} != {v.numel()=}" - attn_output = self.decode_attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) # # debug: i just want to manually verify MLA is doing the right thing # # let's get all the previous kv cache and copy them here, run the MLA manually diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 3328665029039..3d6cd7418bb1d 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -19,6 +19,7 @@ class _Backend(enum.Enum): TORCH_SDPA = enum.auto() OPENVINO = enum.auto() FLASHINFER = enum.auto() + FLASHINFER_MLA = enum.auto() HPU_ATTN = enum.auto() PALLAS = enum.auto() IPEX = enum.auto() diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index d226f96c8b418..f6e1725aa194c 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -52,11 +52,13 @@ def __init__( self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] # Get attention backend. 
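# Note on the wiring here (an annotation, not part of the patch): once
# `use_mla=True` reaches get_attn_backend, the selector change above returns
# _Backend.FLASHINFER_MLA, and the cache engine below then sizes its blocks
# with the MLA view of the model config, i.e. get_num_kv_heads() == 1 and
# get_head_size() == kv_lora_rank instead of the padded 256 MHA head size.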
- self.attn_backend = get_attn_backend(self.head_size, - model_config.dtype, - cache_config.cache_dtype, - self.block_size, - model_config.is_attention_free) + self.attn_backend = get_attn_backend( + self.head_size, + model_config.dtype, + cache_config.cache_dtype, + self.block_size, + model_config.is_attention_free, + use_mla=model_config.is_deepseek_v2) # Initialize the cache. self.gpu_cache = self._allocate_kv_cache( @@ -109,7 +111,7 @@ def get_cache_block_size( parallel_config) key_cache_block = cache_config.block_size * num_heads * head_size - # if model_config._is_deepseek_v2: # MLA share the K and V cache in one latent vector. + # if model_config.is_deepseek_v2: # MLA share the K and V cache in one latent vector. # value_cache_block = 0 # else: # TODO(simon): for MLA, this is repurpose for rope cache (64) but it is smaller than key cache (512). diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1f654a9cce465..9419d90aea9ae 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1060,6 +1060,7 @@ def __init__( self.kv_cache_dtype, self.block_size, self.model_config.is_attention_free, + use_mla=self.model_config.is_deepseek_v2, ) if needs_attn_backend else None if self.attn_backend: self.attn_state = self.attn_backend.get_state_cls()( From 2ccc97d827d50298d838add2480c6137a6f6c4ec Mon Sep 17 00:00:00 2001 From: simon-mo Date: Wed, 25 Dec 2024 06:33:07 +0000 Subject: [PATCH 07/14] moved prefill FA into FLASHINFER_MLA backend --- vllm/attention/backends/flashinfer_mla.py | 320 ++++++++++------------ vllm/model_executor/models/deepseek_v2.py | 193 +++++-------- 2 files changed, 209 insertions(+), 304 deletions(-) diff --git a/vllm/attention/backends/flashinfer_mla.py b/vllm/attention/backends/flashinfer_mla.py index f03a8220d9fee..cf960cf56a772 100644 --- a/vllm/attention/backends/flashinfer_mla.py +++ b/vllm/attention/backends/flashinfer_mla.py @@ -9,15 +9,15 @@ try: from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper - from vllm.vllm_flash_attn import flash_attn_varlen_func FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 except ImportError: BatchDecodeMlaWithPagedKVCacheWrapper = None FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 +from vllm_flash_attn import flash_attn_varlen_func + import torch -import vllm.envs as envs from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, @@ -27,6 +27,7 @@ compute_slot_mapping_start_idx, is_block_tables_empty) from vllm.attention.ops.paged_attn import PagedAttention + from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, make_tensor_with_pad) @@ -105,37 +106,40 @@ def __init__(self, runner): @cached_property def _workspace_buffer(self): - return torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=self.runner.device) + return torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.runner.device) @cached_property def _decode_wrapper(self): - return BatchDecodeMlaWithPagedKVCacheWrapper( - self._workspace_buffer) + return BatchDecodeMlaWithPagedKVCacheWrapper(self._workspace_buffer) @contextmanager def graph_capture(self, max_batch_size: int): - raise NotImplementedError("FlashInferMLAState does not support graph capture") + raise NotImplementedError( + "FlashInferMLAState does not support graph capture") def graph_clone(self, batch_size: int): - raise NotImplementedError("FlashInferMLAState does not support graph capture") + raise NotImplementedError( + 
"FlashInferMLAState does not support graph capture") def graph_capture_get_metadata_for_batch( self, batch_size: int, is_encoder_decoder_model: bool = False): - raise NotImplementedError("FlashInferMLAState does not support graph capture") + raise NotImplementedError( + "FlashInferMLAState does not support graph capture") def get_graph_input_buffers(self, attn_metadata, is_encoder_decoder_model: bool = False): - raise NotImplementedError("FlashInferMLAState does not support graph capture") + raise NotImplementedError( + "FlashInferMLAState does not support graph capture") def prepare_graph_input_buffers(self, input_buffers, attn_metadata, is_encoder_decoder_model: bool = False): - raise NotImplementedError("FlashInferMLAState does not support graph capture") + raise NotImplementedError( + "FlashInferMLAState does not support graph capture") def begin_forward(self, model_input): model_input.attn_metadata.decode_wrapper = self._decode_wrapper @@ -156,6 +160,9 @@ class FlashInferMLAMetadata(AttentionMetadata): use_cuda_graph: bool = True + # Note(simon): we are using Flash Attention for prefill so we don't need a + # wrapper. However, it can be replaced with a + # BatchPrefillWithRaggedKVCacheWrapper implementation. decode_wrapper: Optional[BatchDecodeMlaWithPagedKVCacheWrapper] = None # Metadata for the prefill stage @@ -197,6 +204,9 @@ class FlashInferMLAMetadata(AttentionMetadata): device: torch.device = torch.device("cuda") is_profile_run: bool = False + sm_scale: float = 0.0 + extras: Dict[str, torch.Tensor] = {} + def __post_init__(self): supported_head_sizes = FlashInferMLABackend.get_supported_head_sizes() if self.head_dim is not None and self.head_dim \ @@ -205,45 +215,14 @@ def __post_init__(self): f"Only {supported_head_sizes} are supported for head_dim,", f"received {self.head_dim}.") + # Note(simon): for MLA: soft max scale needs to be + # `1 / sqrt(qk_nope_head_dim + qk_rope_head_dim)`. + assert self.head_dim is not None + self.sm_scale = 1.0 / math.sqrt(self.head_dim + self.head_dim // 8) + def begin_forward(self): if self.num_prefill_tokens > 0: - # assert NotImplementedError("FlashInferMLAState does not support prefill") - # NOTE: only in profiling, treating it as noop - return - - if self.paged_kv_indices is None: - return - - assert self.prefill_wrapper is not None - assert self.query_start_loc is not None - assert self.paged_kv_indices is not None - assert self.paged_kv_indptr is not None - assert self.paged_kv_last_page_len is not None - assert self.block_table_bound is not None - assert self.seq_lens_tensor is not None - self.query_start_loc = self.query_start_loc[:self.num_prefills + 1] - batch_size = self.query_start_loc.shape[0] - 1 - assert batch_size >= 0 - # We will use flash attention for profiling to - # determine the number of blocks. Therefore, - # we don't need to prepare the input for flashinfer for profile run. 
- if not self.is_profile_run: - self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) - self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( - self.device) - self.block_table_bound = self.block_table_bound.to(self.device) - self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) - self.paged_kv_indices = self.paged_kv_indices.to(self.device) - self.prefill_wrapper.plan( - self.paged_kv_indptr[:self.num_prefills + 1], - self.paged_kv_indices, - self.paged_kv_last_page_len[:self.num_prefills], - self.num_qo_heads, - self.head_dim, - self.page_size, - sm_scale=1.0 / math.sqrt(self.head_dim + self.head_dim//8), # TODO(simon): should we explicitly pass this in? - data_type=self.data_type, - q_data_type=self.q_data_type) + pass if self.num_decode_tokens > 0: assert self.paged_kv_indices is not None @@ -260,22 +239,16 @@ def begin_forward(self): self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) assert self.decode_wrapper is not None - # self.decode_wrapper.end_forward() self.decode_wrapper.plan( self.paged_kv_indptr[self.num_prefills:], self.paged_kv_indices, self.paged_kv_last_page_len[self.num_prefills:], self.num_qo_heads, - # self.num_kv_heads, self.head_dim, self.page_size, - sm_scale=1.0 / math.sqrt(self.head_dim + self.head_dim//8), # TODO(simon): should we explicitly pass this in? - # Disable flashinfer's pos encoding and use vllm's rope. - # pos_encoding_mode="NONE", - # kv-cache data type. + sm_scale=self.sm_scale, data_type=self.data_type, - # query data type. q_data_type=self.q_data_type) def asdict_zerocopy(self, @@ -310,10 +283,12 @@ def advance_step(self, """ Update metadata in-place to advance one decode step. """ - raise NotImplementedError("FlashInferMLAMetadata does not support multi-step") + raise NotImplementedError( + "FlashInferMLAMetadata does not support multi-step") -class FlashInferMLAMetadataBuilder(AttentionMetadataBuilder[FlashInferMLAMetadata]): +class FlashInferMLAMetadataBuilder( + AttentionMetadataBuilder[FlashInferMLAMetadata]): def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.slot_mapping: List[int] = [] @@ -603,12 +578,14 @@ def __init__( self.num_kv_heads = num_kv_heads self.kv_cache_dtype = kv_cache_dtype - unsupported_features = [alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap] + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] if any(unsupported_features): - raise NotImplementedError(f"FlashInferMLAImpl does not support {unsupported_features}") - - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = self.num_heads // self.num_kv_heads + raise NotImplementedError( + "FlashInferMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") def forward( self, @@ -626,126 +603,113 @@ def forward( "encoder/decoder cross-attention " "are not implemented for " "FlashInferMLAImpl") - key_rope = value - del value - - num_tokens, N, LR = query.shape - assert N == self.num_heads - assert LR == self.head_size + self.head_size//8 - qk_rope_head_dim = LR - self.head_size - assert qk_rope_head_dim == 64 - # hidden_size = N * self.head_size - - num_heads: int = self.num_heads - head_size: int = self.head_size - num_kv_heads: int = self.num_kv_heads - assert self.num_kv_heads == 1 - kv_cache_dtype: str = self.kv_cache_dtype - # softmax_scale: float = self.scale - # window_size = self.sliding_window - # alibi_slopes = self.alibi_slopes - # logits_soft_cap = self.logits_soft_cap - 
- # num_tokens, hidden_size = query.shape - query = query.view(-1, num_heads, LR) - key = key.view(-1, num_kv_heads, head_size) - key_rope = key_rope.view(-1, num_kv_heads, - head_size) # this is padded! + if attn_metadata.prefill_metadata is not None: + return self._forward_prefill(query, key, value, kv_cache, + attn_metadata, k_scale, v_scale) + + if attn_metadata.decode_metadata is not None: + return self._forward_decode(query, key, value, kv_cache, + attn_metadata, k_scale, v_scale) + + def _forward_prefill( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashInferMLAMetadata, + k_scale: float, + v_scale: float, + ) -> torch.Tensor: + + kv_a = attn_metadata.extras["kv_a"] + k_pe = attn_metadata.extras["k_pe"] + + # write the latent and rope to kv cache + # TODO(simon): remove the hard code, k_pe is assumed to be 1/8 of the + # latent size. + to_cache_key_rope = torch.nn.functional.pad( + k_pe, [0, self.head_size - self.head_size // 8], value=0) if kv_cache.numel() > 0: - # Use the same reshape and cache kernel as flash attention. ops.reshape_and_cache_flash( - key, - key_rope, + kv_a, + to_cache_key_rope, kv_cache[:, 0], kv_cache[:, 1], attn_metadata.slot_mapping.flatten(), - kv_cache_dtype, - k_scale, - v_scale, + kv_cache_dtype=self.kv_cache_dtype, + k_scale=k_scale, + v_scale=v_scale, ) - # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 - # to process the cache when the kv_cache_dtype is fp8 - if kv_cache_dtype.startswith("fp8"): - torch_dtype = FlashInferMLABackend.get_fp8_dtype_for_flashinfer( - kv_cache_dtype) - kv_cache = kv_cache.view(torch_dtype) - - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert num_prefill_tokens == 0 and num_decode_tokens > 0, "only mla decode" - - assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa - assert key_rope.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"value : {key_rope.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa - - query = query.contiguous() # Flashinfer requires query to be contiguous - # Query for decode. KV is not needed because it is already cached. - # QKV for prefill. - # query = query[:num_prefill_tokens] - decode_query = query#[num_prefill_tokens:] - # assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens, f"{decode_query.shape=}, {num_decode_tokens=}" - - # query_nope = query[:, :, :head_size].contiguous() - # query_pe = query[:, :, head_size:].contiguous() - decode_query_nope = decode_query[:, :, :head_size].contiguous() - decode_query_pe = decode_query[:, :, head_size:].contiguous() - - # window_left = window_size[0] if window_size is not None else -1 - - prefill_output: Optional[torch.Tensor] = None + + # run prefill without paged kv cache. 
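# Shape bookkeeping for the padded prefill below (a sketch of the intended
# dims, assuming DeepSeek-V2's qk_nope_head_dim=128, qk_rope_head_dim=64 and
# v_head_dim=128): q and k arrive with head_dim 128 + 64 = 192 and v with
# 128. The flash-attention kernel only supports head sizes 32/64/128/256, so
# all three are padded up to 256, and the output is sliced back to the first
# v_head_dim channels in the view/reshape that follows the kernel call.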
+ q = torch.nn.functional.pad(query, [0, 256 - query.shape[-1]], value=0) + k = torch.nn.functional.pad(key, [0, 256 - key.shape[-1]], value=0) + v = torch.nn.functional.pad(value, [0, 256 - value.shape[-1]], value=0) + + attn_output = flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=attn_metadata.seq_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_prefill_seq_len, + max_seqlen_k=attn_metadata.max_prefill_seq_len, + causal=True, + ) + attn_output = attn_output.view(-1, self.num_heads, + 256)[..., :value.shape[-1]].reshape( + -1, + self.num_heads * value.shape[-1]) + + return attn_output + + def _forward_decode( + self, + query: torch.Tensor, + key: torch.Tensor, + rope: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashInferMLAMetadata, + k_scale: float, + v_scale: float, + ) -> torch.Tensor: + assert kv_cache.numel() > 0 + # Use the same reshape and cache kernel as flash attention. + ops.reshape_and_cache_flash( + key, + rope, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping.flatten(), + self.kv_cache_dtype, + k_scale, + v_scale, + ) + # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 + # to process the cache when the kv_cache_dtype is fp8 + if self.kv_cache_dtype.startswith("fp8"): + torch_dtype = FlashInferMLABackend.get_fp8_dtype_for_flashinfer( + self.kv_cache_dtype) + kv_cache = kv_cache.view(torch_dtype) + + decode_query_nope = query[:, :, :self.head_size].contiguous() + decode_query_pe = query[:, :, self.head_size:].contiguous() + decode_output: Optional[torch.Tensor] = None - if prefill_meta := attn_metadata.prefill_metadata: - # only in prefill, use run FA - prefill_output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, - causal=True, - ) - if decode_meta := attn_metadata.decode_metadata: - assert decode_meta is not None - assert decode_meta.decode_wrapper is not None - # paged_kpe_cache, _ = kv_cache[:, 1].split( - # [qk_rope_head_dim, head_size - qk_rope_head_dim], dim=-1) - # paged_kpe_cache = paged_kpe_cache.contiguous() # this is making of entire KV cache noooo - # # note: this shouldn't matter b/c FI assumes head_dim_kpe == head_dim_ckv//8 - - decode_output = decode_meta.decode_wrapper.run( - q_nope=decode_query_nope, - q_pe=decode_query_pe, - paged_ckv_cache=kv_cache[:, 0], - paged_kpe_cache=kv_cache[:, 1], - # paged_kpe_cache=paged_kpe_cache, - # sm_scale=softmax_scale, - # logits_soft_cap=logits_soft_cap, - # k_scale=k_scale, - # v_scale=v_scale, - # window_left=window_left - ) - if prefill_output is None and decode_output is not None: - # Decode only batch. - output, num_tokens = decode_output, num_decode_tokens - elif decode_output is None and prefill_output is not None: - # Prefill only batch. - output, num_tokens = prefill_output, num_prefill_tokens - else: - # Chunked prefill batch does not work with speculative decoding in - # FlashInfer backend, so the query length for decode should be 1. 
- assert prefill_output is not None - assert decode_output is not None - assert decode_meta is not None - assert decode_meta.decode_query_len == 1 - decode_output = decode_output.squeeze(1) - output = torch.cat([prefill_output, decode_output], dim=0) - assert output.shape == ( - num_tokens, N, - head_size), f"{output.shape=}!={num_tokens=}, {N=}, {head_size=}" - return output + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + assert decode_meta.decode_wrapper is not None + + # NOTE(simon): FI assumes head_dim_kpe == head_dim_ckv//8, + # and it ignores our padding for the kpe cache. + decode_output = decode_meta.decode_wrapper.run( + q_nope=decode_query_nope, + q_pe=decode_query_pe, + paged_ckv_cache=kv_cache[:, 0], + paged_kpe_cache=kv_cache[:, 1], + ) + + return decode_output diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 886d0ef8aae9c..2667fd1fde2ca 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -320,17 +320,15 @@ def forward( return output -from vllm.attention.backends.flash_attn import flash_attn_varlen_func, _get_query_key_seq_metadata, AttentionType -from vllm import _custom_ops as ops - - class DeepseekV2MLAAttention(nn.Module): """ - Main reference: DeepseekV2 paper, and FlashInfer Implementation https://github.com/flashinfer-ai/flashinfer/pull/551. + Main reference: DeepseekV2 paper, and FlashInfer Implementation + (https://github.com/flashinfer-ai/flashinfer/pull/551). Deepseek's MLA attention works the following way: - * The key idea is to use a single latent vector to represent the entire KV cache. - * The attention should simulate a multi-head attention, while the compute is similar to multi-query attention. + * Use a single latent vector to represent the entire KV cache. + * The attention "simulates" a multi-head attention, while the compute is + similar to multi-query attention. * The dataflow is as follows, * B: batch/sequence length @@ -339,43 +337,66 @@ class DeepseekV2MLAAttention(nn.Module): * Lq: latent dimension for Q * Lkv: latent dimension for K/V * P: nope dimension, P+R is the actual head_dim in common attention. - * R: rope dimension, this slide of the head_dim goes through rotary embeddings. + * R: rope dimension, this slide of the head_dim goes through rope. * V: V head dim. # The reconstructed way, as implemented in DeepseekV2Attention: - 1. The hidden states (B, H) are projected down into q_latent (B, Lq) and kv_latent (B, Lkv+R). - 2. The kv_latent is split into kv_a (B, Lkv) and k_pe (B, R). q_latent and kv_a are normalized. - 3. The q_latent and kv_a are then projected up into the multi-head version. - q_latent goes from (B, Lq) to (B, N(P+R)) included the rope dimension, - which is split into q_nope (B, N, P) and q_pe (B, N, R). - kv_a goes from (B, Lkv) to (B, N(P+V)) which has the nope dimensions for K and V, - which is split into k_nope (B, N, P) and v (B, N, V). + 1. The hidden states (B, H) are projected down into q_latent (B, Lq) and + kv_latent (B, Lkv+R). + 2. The kv_latent is split into kv_a (B, Lkv) and k_pe (B, R). q_latent + and kv_a are normalized. + 3. The q_latent and kv_a are then projected up into the multi-head + version. q_latent goes from (B, Lq) to (B, N(P+R)) included the rope + dimension, which is split into q_nope (B, N, P) and q_pe (B, N, R). + kv_a goes from (B, Lkv) to (B, N(P+V)) which has the nope dimensions + for K and V, which is split into k_nope (B, N, P) and v (B, N, V). 3. 
q_pe, k_pe are then passed through rotary embeddings. - 4. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from q_nope, q_pe, k_nope, k_pe. + 4. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from + q_nope, q_pe, k_nope, k_pe. 5. Attention is computued with q, k, v. - 6. The KV cache is updated with the new entries k (B, N, (P+R)) and v (B, N, V), we pad the head dim to 256 - so that the KV cache has consistent shape and works with a typical cache implementation. - 7. The attention computation returns (B, N, V), which is projected back to (B, H) using out projection. + 6. The KV cache is updated with the new entries k (B, N, (P+R)) and v + (B, N, V), we pad the head dim to 256 so that the KV cache has + consistent shape and works with a typical cache implementation. + 7. The attention computation returns (B, N, V), which is projected back + to (B, H) using out projection. # The recommended way, as described in the paper: - 1. The hidden states (B, H) are projected down into q_latent (B, Lq) and kv_latent (B, Lkv+R). - 2. The kv_latent is split into kv_a (B, Lkv) and k_pe (B, R). q_latent and kv_a are normalized. - 3. Here's the change, we do not perform up the full up projection for q_latent, and there is no - up projection at all for kv_a. This is achieved by the technique of "weight absorption". The paper says - "Fortunately, due to the associative law of matrix multiplication, we can absorb WUK into WUQ, and WUV into WO" - * The q up projection turns (B, Lq) into (B, N(P+R)), we split it into W_UQ (Lq, N, P) and W_QR (Lq, N, R). - * The kv_a up projection turns (B, Lkv) into (B, N(P+V)), we split it into W_UK (Lkv, N, P) and W_UV (Lkv, N, V). - * The out projection turns (B, N, V) into (B, H), has shape W_O (V, H) - * We can precompute the product of W_UQ and W_UK into W_UQ_UK (Lq, N, Lkv), which is possible due to QK^T operation in attention. - * We can precompute the product of W_UV and W_O into W_UV_O (N, Lkv, H), which is possible due to V@O as the "epilogue" of attention - 4. We still need to compute q_pe (B, N, R) by applying W_QR to q_latent. The rotary embeddingss still need to be applied to q_pe and k_pe. - 5. By applying W_UQ_UK to q_latent, we have the new q_nope of shape (B, N, Lkv). - 6. q (B, N, (Lkv+R)), k (B, (Lkv+R)) are assembled from q_nope, q_pe, kv_a, k_pe. v (B, Lkv) is exactly the same vector as kv_a. - 6. The attention is computed with q, k, v. Note that we just performed a MQA attention with (LKv+R) as our head dim. - 7. The KV cache is updated using the new entries k (B, N, (Lkv+R)), which included the v and rope values. - 8. The attention computation returns (B, N, Lkv), which is projected back to (B, H) using W_UV_O. - - From @tsu-bin's calculation, we only want to use the absorption technique for decode. + 1. The hidden states (B, H) are projected down into q_latent (B, Lq) and + kv_latent (B, Lkv+R). + 2. The kv_latent is split into kv_a (B, Lkv) and k_pe (B, R). q_latent + and kv_a are normalized. + 3. Here's the change, we do not perform up the full up projection for + q_latent, and there is no up projection at all for kv_a. This is + achieved by the technique of "weight absorption". The paper says + "Fortunately, due to the associative law of matrix multiplication, + we can absorb WUK into WUQ, and WUV into WO" + * The q up projection turns (B, Lq) into (B, N(P+R)), we split it + into W_UQ (Lq, N, P) and W_QR (Lq, N, R). 
+ * The kv_a up projection turns (B, Lkv) into (B, N(P+V)), we split it + into W_UK (Lkv, N, P) and W_UV (Lkv, N, V). + * The out projection shape W_O (V, H)turns (B, N, V) into (B, H). + * We can precompute the product of W_UQ and W_UK into + W_UQ_UK (Lq, N, Lkv), which is possible due to QK^T operation in + attention. + * We can precompute the product of W_UV and W_O into + W_UV_O (N, Lkv, H), which is possible due to V@O as the + "epilogue" of attention + 4. We still need to compute q_pe (B, N, R) by applying W_QR to q_latent. + The rotary embeddingss still need to be applied to q_pe and k_pe. + 5. By applying W_UQ_UK to q_latent, we have the new q_nope of shape + (B, N, Lkv). + 6. q (B, N, (Lkv+R)), k (B, (Lkv+R)) are assembled from q_nope, q_pe, + kv_a, k_pe. v (B, Lkv) is exactly the same vector as kv_a. + 6. The attention is computed with q, k, v. Note that we just performed + a MQA attention with (LKv+R) as our head dim. + 7. The KV cache is updated using the new entries k (B, N, (Lkv+R)), + which included the v and rope values. + 8. The attention computation returns (B, N, Lkv), which is projected + back to (B, H) using W_UV_O. + + From @tsu-bin's calculation, we only want to use the absorption technique + for decode. The prefill algorithm should still use the up-projected MHA + for less flops and momory usage. """ def __init__( @@ -477,16 +498,6 @@ def __init__( mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale - # The prefill attention will compute a multi-headed attention by up-projecting the latents. - # TODO(simon): enable this for prefill, and save only the latents. - # self.prefill_attn = Attention(num_heads=self.num_local_heads, - # head_size=256, - # scale=self.scaling, - # num_kv_heads=self.num_local_heads, - # cache_config=cache_config, - # quant_config=quant_config, - # prefix=f"{prefix}.prefill_attn") - # The decode attention will compute a multi-query attention by directly operating on the latent. self.attn = Attention(num_heads=self.num_local_heads, head_size=self.kv_lora_rank, scale=self.scaling, @@ -524,7 +535,6 @@ def forward( attn_metadata: AttentionMetadata, ) -> torch.Tensor: # TODO(simon): support append/chunked prefill by two kernels, or using the decode kernel somehow. 
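# A self-contained sketch (not part of the patch) of the "weight absorption"
# identity described in the class docstring above, on the K side. Shapes
# follow that notation (B batch, N heads, P nope dim, L kv_lora_rank); the
# tensors are random stand-ins, not real DeepSeek weights.
import torch

B, N, P, L = 2, 16, 128, 512
q_nope = torch.randn(B, N, P, dtype=torch.float64)  # up-projected query, nope part
w_uk = torch.randn(L, N, P, dtype=torch.float64)    # K slice of the kv_b_proj weight
kv_a = torch.randn(B, L, dtype=torch.float64)       # per-token latent (what the KV cache stores)

# Reconstructed path: materialize k_nope, then take the per-head dot product.
k_nope = torch.einsum("bl,lnp->bnp", kv_a, w_uk)
scores_ref = torch.einsum("bnp,bnp->bn", q_nope, k_nope)

# Absorbed path: fold W_UK into the query once, then dot against the latent.
q_absorbed = torch.einsum("bnp,lnp->bnl", q_nope, w_uk)
scores_mla = torch.einsum("bnl,bl->bn", q_absorbed, kv_a)

assert torch.allclose(scores_ref, scores_mla)
# The V side absorbs the same way by associativity: (attn @ W_UV) @ W_O
# equals attn @ (W_UV @ W_O), which is the W_UV_O of the docstring.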
- if attn_metadata.prefill_metadata: return self.forward_prefill(positions, hidden_states, kv_cache, attn_metadata) @@ -579,43 +589,14 @@ def forward_prefill( k[..., :self.qk_nope_head_dim] = k_nope k[..., self.qk_nope_head_dim:] = k_pe - # write the latent and rope to kv cache - to_cache_key = kv_a.unsqueeze(1) - to_cache_key_rope = torch.nn.functional.pad( - k_pe, [0, self.kv_lora_rank - self.qk_rope_head_dim], value=0) - if kv_cache.numel() > 0: - ops.reshape_and_cache_flash( - to_cache_key, - to_cache_key_rope, - kv_cache[:, 0], - kv_cache[:, 1], - attn_metadata.slot_mapping.flatten(), - kv_cache_dtype="auto", # TODO: remove hard code - k_scale=1.0, - v_scale=1.0, - ) + # HACK + attn_metadata.extras = { + "kv_a": + kv_a.unsqueeze(1), # restore the head dim to write to kv cache + "k_pe": k_pe, + } - # run the prefill kernels - q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0) - k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim], value=0) - v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0) - - prefill_meta = attn_metadata.prefill_metadata - q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = _get_query_key_seq_metadata( - prefill_meta, True, AttentionType.DECODER) - attn_output = flash_attn_varlen_func( - q=q, - k=k, - v=v, - cu_seqlens_q=q_seq_start_loc, - cu_seqlens_k=k_seq_start_loc, - max_seqlen_q=q_seq_len, - max_seqlen_k=k_seq_len, - causal=True, - ) - attn_output = attn_output.view( - -1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape( - -1, self.num_local_heads * self.v_head_dim) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) # B(N'V) -> BH output, _ = self.o_proj(attn_output) @@ -628,8 +609,6 @@ def forward_decode( kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: - # Let's implement the matrix absorption dataflow. - # We will start with applying the projection instead of fusing them. B = hidden_states.shape[0] # Apply UQ and QR. @@ -648,7 +627,6 @@ def forward_decode( kv_a, k_pe = latent_cache.split( [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv_a = self.kv_a_layernorm(kv_a.contiguous()) - # print(f"{q.shape=}, {q_nope.shape=}, {q_pe.shape=}, {k_pe.shape=}, {kv_a.shape=}, {latent_cache.shape=}") k_pe = k_pe.unsqueeze(1) q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) # Apply UK, q_nope (B, N, P) @ W_UK (L, N, P) -> (B, N, L) @@ -663,51 +641,14 @@ def forward_decode( q[..., self.kv_lora_rank:] = q_pe # q = q.view(B, self.num_local_heads * (self.kv_lora_rank + self.qk_rope_head_dim)) - k = kv_a + k = kv_a.unsqueeze(1) # The padding is only used for kv storage. 
v = torch.nn.functional.pad( - k_pe, [0, self.kv_lora_rank - self.qk_rope_head_dim], - value=0).squeeze(1) + k_pe, [0, self.kv_lora_rank - self.qk_rope_head_dim], value=0) assert k.numel() == v.numel(), f"{k.numel()=} != {v.numel()=}" attn_output = self.attn(q, k, v, kv_cache, attn_metadata) - # # debug: i just want to manually verify MLA is doing the right thing - # # let's get all the previous kv cache and copy them here, run the MLA manually - # paged_kv_indptr = attn_metadata.decode_metadata.paged_kv_indptr - # paged_kv_indices = attn_metadata.decode_metadata.paged_kv_indices - # paged_kv_last_page_len = attn_metadata.decode_metadata.paged_kv_last_page_len - - # # debug: we always have batch size 1 and one page - # assert paged_kv_indptr.cpu().tolist() == [ - # 0, 1 - # ], f"{paged_kv_indptr.cpu().tolist()=}" - # paged_idx = paged_kv_indices[0] - # full_latent_cache = kv_cache[paged_idx, 0] - # full_rope_cache = kv_cache[paged_idx, 1] - # # let's write k and v into the full cache at paged_kv_last_page_len-1 - # full_latent_cache[paged_kv_last_page_len - 1, :, :] = k - # full_rope_cache[paged_kv_last_page_len - 1, :, :] = v - # full_latent_cache = full_latent_cache[:paged_kv_last_page_len, :, :] - # full_rope_cache = full_rope_cache[:paged_kv_last_page_len, :, :self. - # qk_rope_head_dim] - # full_kv_cache = torch.cat([full_latent_cache, full_rope_cache], dim=-1) - - # # now let's run the MLA manually - # q_B_N_LR = q - # k_S_1_LR = full_kv_cache - # v_S_1_L = full_latent_cache - # import math - # scale = 1.0 / math.sqrt(self.kv_lora_rank + self.qk_rope_head_dim) - # attn_scores = torch.einsum("bnl,snl->nbs", q_B_N_LR, k_S_1_LR) * scale - # attn_probs = torch.nn.functional.softmax(attn_scores, dim=-1) - # attn_output_ref = torch.einsum("nbs,snl->bnl", attn_probs, v_S_1_L) - - # # assert torch.allclose(attn_output.sum(), attn_output_ref.sum()), f"{attn_output.sum()=}\n{attn_output_ref.sum()=}" - - assert attn_output.shape == ( - B, self.num_local_heads, self.kv_lora_rank - ), f"{attn_output.shape=}!={B=}, {self.num_local_heads=}, {self.kv_lora_rank=}" # idk why but the attn_output is fp32 attn_output = attn_output.to(q.dtype) # Apply UV, (B, N, L) @ W_UV (L, N, V) -> (B, N, V) From 7b92036903f97c6f2ef2adf5cfb8cdc437ea7a89 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Tue, 31 Dec 2024 19:08:17 +0000 Subject: [PATCH 08/14] working TP? Signed-off-by: simon-mo --- examples/offline_inference.py | 7 ++-- vllm/attention/backends/flashinfer_mla.py | 4 +- vllm/config.py | 4 +- vllm/model_executor/models/deepseek_v2.py | 48 ++++++++++++----------- 4 files changed, 32 insertions(+), 31 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 7c00321224360..5eb7a49609a15 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -12,10 +12,11 @@ # Create an LLM. 
llm = LLM( - model="deepseek-ai/DeepSeek-V2-Lite-Chat", - # model="deepseek-ai/DeepSeek-V2.5", tensor_parallel_size=8, + # model="deepseek-ai/DeepSeek-V2-Lite-Chat", + model="deepseek-ai/DeepSeek-V2.5", + tensor_parallel_size=8, trust_remote_code=True, - max_model_len=16384, + max_model_len=6128, # dtype="float16", enforce_eager=True, # max_num_seqs=1, diff --git a/vllm/attention/backends/flashinfer_mla.py b/vllm/attention/backends/flashinfer_mla.py index cf960cf56a772..74cba6d67804d 100644 --- a/vllm/attention/backends/flashinfer_mla.py +++ b/vllm/attention/backends/flashinfer_mla.py @@ -1,6 +1,6 @@ from collections import defaultdict from contextlib import contextmanager -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type import math from functools import cached_property @@ -205,7 +205,7 @@ class FlashInferMLAMetadata(AttentionMetadata): is_profile_run: bool = False sm_scale: float = 0.0 - extras: Dict[str, torch.Tensor] = {} + extras: Dict[str, torch.Tensor] = field(default_factory=dict) def __post_init__(self): supported_head_sizes = FlashInferMLABackend.get_supported_head_sizes() diff --git a/vllm/config.py b/vllm/config.py index 20d864869ef8d..a7e83e1f70fda 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -598,11 +598,9 @@ def get_hidden_size(self) -> int: @property def is_deepseek_v2(self) -> bool: - result = hasattr( + return hasattr( self.hf_text_config, "model_type") and self.hf_text_config.model_type == 'deepseek_v2' - assert result - return result def get_head_size(self) -> int: # TODO remove hard code diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 2667fd1fde2ca..efaa5ef723289 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -396,7 +396,7 @@ class DeepseekV2MLAAttention(nn.Module): From @tsu-bin's calculation, we only want to use the absorption technique for decode. The prefill algorithm should still use the up-projected MHA - for less flops and momory usage. + for less flops and memory usage. 
""" def __init__( @@ -436,9 +436,6 @@ def __init__( self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - # TODO(simon): implement matrix absorption for this, needed for deepseek v2.5 - assert q_lora_rank is None, "Currently not supported" - if self.q_lora_rank is not None: self.q_a_proj = ReplicatedLinear(self.hidden_size, self.q_lora_rank, @@ -447,20 +444,22 @@ def __init__( prefix=f"{prefix}.q_a_proj") self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear(q_lora_rank, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_b_proj") + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + # self.q_b_proj = ReplicatedLinear(q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") else: # (H -> N(P+R)) - self.q_proj = ColumnParallelLinear(self.hidden_size, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_proj") + self.q_proj = ColumnParallelLinear( + self.hidden_size, + # self.q_proj = ReplicatedLinear(self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") # (H -> (L+R)) self.kv_a_proj_with_mqa = ReplicatedLinear( @@ -473,17 +472,20 @@ def __init__( eps=config.rms_norm_eps) # ((L -> (N(P+V))) self.kv_b_proj = ColumnParallelLinear( + # self.kv_b_proj = ReplicatedLinear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, quant_config=quant_config, prefix=f"{prefix}.kv_b_proj") # (NV -> H) - self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + # self.o_proj = ReplicatedLinear(self.num_local_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") rope_scaling["rope_type"] = 'deepseek_yarn' self.rotary_emb = get_rope(qk_rope_head_dim, @@ -515,8 +517,8 @@ def __init__( kv_b_proj_weight = self.kv_b_proj.weight.T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim) - ), f"{kv_b_proj_weight.shape} != {(self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim))}" + self.num_local_heads * (self.qk_nope_head_dim + self.v_head_dim) + ), f"{kv_b_proj_weight.shape} != {(self.kv_lora_rank, self.num_local_heads * (self.qk_nope_head_dim + self.v_head_dim))}" kv_b_proj_weight = kv_b_proj_weight.view( self.kv_lora_rank, self.num_local_heads, From b8f63d7411ea53b59bd04a1cbaac5ca6358efc52 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Tue, 31 Dec 2024 19:55:05 +0000 Subject: [PATCH 09/14] feature flag working Signed-off-by: simon-mo --- examples/offline_inference.py | 3 +- vllm/config.py | 110 +++++++++++++++------- vllm/engine/arg_utils.py | 6 ++ vllm/model_executor/models/deepseek_v2.py | 12 ++- vllm/worker/cache_engine.py | 6 +- vllm/worker/model_runner.py | 2 +- 6 files changed, 94 insertions(+), 45 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 5eb7a49609a15..6f85f210b3279 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -16,11 +16,12 @@ model="deepseek-ai/DeepSeek-V2.5", tensor_parallel_size=8, trust_remote_code=True, - max_model_len=6128, + 
max_model_len=1024, # dtype="float16", enforce_eager=True, # max_num_seqs=1, # block_size=128, + # disable_mla=True, ) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. diff --git a/vllm/config.py b/vllm/config.py index a7e83e1f70fda..d2a32c0206092 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -55,6 +55,19 @@ PretrainedConfig]] +def _is_flashinfer_available() -> bool: + """Check if FlashInfer is available. + + Returns: + bool: True if FlashInfer is installed and available, False otherwise. + """ + try: + from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper + return True + except ImportError: + return False + + class ModelConfig: """Configuration for the model. @@ -132,40 +145,43 @@ class ModelConfig: can not be gathered from the vllm arguments. override_pooling_config: Initialize non default pooling config or override default pooling config for the embedding model. + disable_mla: Whether to disable MLA for DeepSeek models. """ def __init__( - self, - model: str, - task: Union[TaskOption, _Task], - tokenizer: str, - tokenizer_mode: str, - trust_remote_code: bool, - dtype: Union[str, torch.dtype], - seed: int, - allowed_local_media_path: str = "", - revision: Optional[str] = None, - code_revision: Optional[str] = None, - rope_scaling: Optional[Dict[str, Any]] = None, - rope_theta: Optional[float] = None, - tokenizer_revision: Optional[str] = None, - max_model_len: Optional[int] = None, - spec_target_max_model_len: Optional[int] = None, - quantization: Optional[str] = None, - quantization_param_path: Optional[str] = None, - enforce_eager: Optional[bool] = None, - max_seq_len_to_capture: Optional[int] = None, - max_logprobs: int = 20, - disable_sliding_window: bool = False, - skip_tokenizer_init: bool = False, - served_model_name: Optional[Union[str, List[str]]] = None, - limit_mm_per_prompt: Optional[Mapping[str, int]] = None, - use_async_output_proc: bool = True, - config_format: ConfigFormat = ConfigFormat.AUTO, - hf_overrides: Optional[HfOverrides] = None, - mm_processor_kwargs: Optional[Dict[str, Any]] = None, - override_neuron_config: Optional[Dict[str, Any]] = None, - override_pooler_config: Optional["PoolerConfig"] = None) -> None: + self, + model: str, + task: Union[TaskOption, _Task], + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + dtype: Union[str, torch.dtype], + seed: int, + allowed_local_media_path: str = "", + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[Dict[str, Any]] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + spec_target_max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, + enforce_eager: Optional[bool] = None, + max_seq_len_to_capture: Optional[int] = None, + max_logprobs: int = 20, + disable_sliding_window: bool = False, + skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, List[str]]] = None, + limit_mm_per_prompt: Optional[Mapping[str, int]] = None, + use_async_output_proc: bool = True, + config_format: ConfigFormat = ConfigFormat.AUTO, + hf_overrides: Optional[HfOverrides] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + override_neuron_config: Optional[Dict[str, Any]] = None, + override_pooler_config: Optional["PoolerConfig"] = None, + disable_mla: bool = False, + ) -> None: self.model = 
model self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode @@ -210,6 +226,7 @@ def __init__( self.max_logprobs = max_logprobs self.disable_sliding_window = disable_sliding_window self.skip_tokenizer_init = skip_tokenizer_init + self.disable_mla = disable_mla hf_config = get_config(self.model, trust_remote_code, revision, code_revision, config_format) @@ -607,9 +624,10 @@ def get_head_size(self) -> int: if self.is_deepseek_v2: # FlashAttention supports only head_size 32, 64, 128, 256, # we need to pad head_size 192 to 256 - # return 256 - # TODO(simon): feature flag MLA - return self.hf_text_config.kv_lora_rank # + self.hf_text_config.qk_rope_head_dim + if self.should_use_mla: + return self.hf_text_config.kv_lora_rank + else: + return 256 if self.is_attention_free: return 0 @@ -668,7 +686,7 @@ def get_total_num_kv_heads(self) -> int: def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: """Returns the number of KV heads per GPU.""" - if self.is_deepseek_v2: + if self.should_use_mla: # TODO(simon): feature flag MLA return 1 @@ -736,6 +754,28 @@ def is_cross_encoder(self) -> bool: architectures = getattr(self.hf_config, "architectures", []) return ModelRegistry.is_cross_encoder_model(architectures) + @property + def should_use_mla(self) -> bool: + """Whether MLA should be used for this model. + + Returns True if: + 1. The model is DeepSeek V2 + 2. MLA is not explicitly disabled + 3. FlashInfer is available + + If conditions 1 and 2 are met but FlashInfer is not available, + logs a warning and returns False. + """ + use_mla = self.is_deepseek_v2 and not self.disable_mla + if use_mla and not _is_flashinfer_available(): + logger.warning( + "Please install or update FlashInfer for better performance on " + "DeepSeek model via enabling MLA. See " + "https://github.com/flashinfer-ai/flashinfer for installation." + ) + return False + return use_mla + class CacheConfig: """Configuration for the KV cache. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ca68c1d57151c..86c6b1dcaf32c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -102,6 +102,7 @@ class EngineArgs: seed: int = 0 max_model_len: Optional[int] = None worker_use_ray: bool = False + disable_mla: bool = False # Note: Specifying a custom executor backend by passing a class # is intended for expert use only. The API may change without # notice. 
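# Hypothetical usage of the new flag (assuming it is forwarded from the LLM
# constructor's kwargs through EngineArgs into ModelConfig, as wired up in
# this patch); it forces the original DeepseekV2Attention path.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite-Chat",
    trust_remote_code=True,
    disable_mla=True,
)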
@@ -901,6 +902,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default="auto", help='The worker class to use for distributed execution.') + parser.add_argument('--disable-mla', + action='store_true', + help='Disable MLA for DeepSeek models.') + return parser @classmethod @@ -943,6 +948,7 @@ def create_model_config(self) -> ModelConfig: mm_processor_kwargs=self.mm_processor_kwargs, override_neuron_config=self.override_neuron_config, override_pooler_config=self.override_pooler_config, + disable_mla=self.disable_mla, ) def create_load_config(self) -> LoadConfig: diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index efaa5ef723289..1c917a563f3f3 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -28,7 +28,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig, ModelConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -668,6 +668,7 @@ def __init__( self, config: PretrainedConfig, prefix: str, + model_config: ModelConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: @@ -680,8 +681,11 @@ def __init__( # DecoderLayers are created with `make_layers` which passes the prefix # with the layer's index. layer_idx = int(prefix.split(sep='.')[-1]) - # self.self_attn = DeepseekV2Attention( - self.self_attn = DeepseekV2MLAAttention( + if model_config.should_use_mla: + attn_cls = DeepseekV2MLAAttention + else: + attn_cls = DeepseekV2Attention + self.self_attn = attn_cls( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -757,6 +761,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config @@ -776,6 +781,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lambda prefix: DeepseekV2DecoderLayer( config, prefix, + model_config=model_config, cache_config=cache_config, quant_config=quant_config, ), diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index f6e1725aa194c..e18be091c2f59 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -58,7 +58,7 @@ def __init__( cache_config.cache_dtype, self.block_size, model_config.is_attention_free, - use_mla=model_config.is_deepseek_v2) + use_mla=model_config.should_use_mla) # Initialize the cache. self.gpu_cache = self._allocate_kv_cache( @@ -111,10 +111,6 @@ def get_cache_block_size( parallel_config) key_cache_block = cache_config.block_size * num_heads * head_size - # if model_config.is_deepseek_v2: # MLA share the K and V cache in one latent vector. - # value_cache_block = 0 - # else: - # TODO(simon): for MLA, this is repurpose for rope cache (64) but it is smaller than key cache (512). 
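# Back-of-the-envelope sizing for the MLA cache block computed here (assumed
# numbers: DeepSeek-V2's kv_lora_rank=512 and qk_rope_head_dim=64, the
# default block_size=16, and a 2-byte cache dtype):
block_size, num_kv_heads, head_size = 16, 1, 512
key_cache_block = block_size * num_kv_heads * head_size   # 8192 latent elements
value_cache_block = key_cache_block                        # repurposed for rope; only 64 of 512 used
bytes_per_block_per_layer = 2 * (key_cache_block + value_cache_block)  # 32768 bytes = 32 KiB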
value_cache_block = key_cache_block total = num_attention_layers * (key_cache_block + value_cache_block) if cache_config.cache_dtype == "auto": diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9419d90aea9ae..c8071f450df95 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1060,7 +1060,7 @@ def __init__( self.kv_cache_dtype, self.block_size, self.model_config.is_attention_free, - use_mla=self.model_config.is_deepseek_v2, + use_mla=self.model_config.should_use_mla, ) if needs_attn_backend else None if self.attn_backend: self.attn_state = self.attn_backend.get_state_cls()( From 043085d90d97f4149dac7aaf20efa9cb9ce11c3b Mon Sep 17 00:00:00 2001 From: simon-mo Date: Tue, 31 Dec 2024 20:32:30 +0000 Subject: [PATCH 10/14] add env flag Signed-off-by: simon-mo --- vllm/config.py | 3 ++- vllm/envs.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index d2a32c0206092..728e1e2409ee5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -766,7 +766,8 @@ def should_use_mla(self) -> bool: If conditions 1 and 2 are met but FlashInfer is not available, logs a warning and returns False. """ - use_mla = self.is_deepseek_v2 and not self.disable_mla + use_mla = (self.is_deepseek_v2 and not self.disable_mla + and not envs.VLLM_DISABLE_MLA) if use_mla and not _is_flashinfer_available(): logger.warning( "Please install or update FlashInfer for better performance on " diff --git a/vllm/envs.py b/vllm/envs.py index c896770e5f6bc..ee15bd446056b 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -286,6 +286,10 @@ def get_default_config_root(): "VLLM_FLASHINFER_FORCE_TENSOR_CORES": lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))), + # If set, vLLM will disable the MLA attention optimizations. + "VLLM_DISABLE_MLA": + lambda: bool(int(os.getenv("VLLM_DISABLE_MLA", "0"))), + # Pipeline stage partition strategy "VLLM_PP_LAYER_PARTITION": lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), From 6ea2c065a3666a2864e0fd8b21fce72fad046c7e Mon Sep 17 00:00:00 2001 From: simon-mo Date: Thu, 2 Jan 2025 18:07:21 +0000 Subject: [PATCH 11/14] lint Signed-off-by: simon-mo --- vllm/attention/backends/flashinfer_mla.py | 5 ++ vllm/config.py | 84 ++++++++++++----------- vllm/model_executor/models/deepseek_v2.py | 26 +++---- 3 files changed, 62 insertions(+), 53 deletions(-) diff --git a/vllm/attention/backends/flashinfer_mla.py b/vllm/attention/backends/flashinfer_mla.py index 74cba6d67804d..6a2ff736d5ee7 100644 --- a/vllm/attention/backends/flashinfer_mla.py +++ b/vllm/attention/backends/flashinfer_mla.py @@ -597,6 +597,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " @@ -604,6 +605,10 @@ def forward( "are not implemented for " "FlashInferMLAImpl") + if output is not None: + raise NotImplementedError( + "output is not yet supported for FlashInferMLAImpl") + if attn_metadata.prefill_metadata is not None: return self._forward_prefill(query, key, value, kv_cache, attn_metadata, k_scale, v_scale) diff --git a/vllm/config.py b/vllm/config.py index 22def624a9b2d..df2a2d3843ecb 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -82,11 +82,12 @@ def _is_flashinfer_available() -> bool: bool: True if FlashInfer is installed and available, False otherwise. 
""" try: - from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper + from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper # noqa:F401 return True except ImportError: return False + class SupportsHash(Protocol): def compute_hash(self) -> str: @@ -184,6 +185,7 @@ class ModelConfig: generation_config: Configuration parameter file for generation. disable_mla: Whether to disable MLA for DeepSeek models. """ + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -208,42 +210,43 @@ def compute_hash(self) -> str: factors.append(self.rope_theta) return hashlib.sha256(str(factors).encode()).hexdigest() - def __init__(self, - model: str, - task: Union[TaskOption, Literal["draft"]], - tokenizer: str, - tokenizer_mode: str, - trust_remote_code: bool, - dtype: Union[str, torch.dtype], - seed: int, - allowed_local_media_path: str = "", - revision: Optional[str] = None, - code_revision: Optional[str] = None, - rope_scaling: Optional[Dict[str, Any]] = None, - rope_theta: Optional[float] = None, - tokenizer_revision: Optional[str] = None, - max_model_len: Optional[int] = None, - spec_target_max_model_len: Optional[int] = None, - quantization: Optional[str] = None, - quantization_param_path: Optional[str] = None, - enforce_eager: Optional[bool] = None, - max_seq_len_to_capture: Optional[int] = None, - max_logprobs: int = 20, - disable_sliding_window: bool = False, - skip_tokenizer_init: bool = False, - served_model_name: Optional[Union[str, List[str]]] = None, - limit_mm_per_prompt: Optional[Mapping[str, int]] = None, - use_async_output_proc: bool = True, - config_format: ConfigFormat = ConfigFormat.AUTO, - hf_overrides: Optional[HfOverrides] = None, - mm_processor_kwargs: Optional[Dict[str, Any]] = None, - disable_mm_preprocessor_cache: bool = False, - override_neuron_config: Optional[Dict[str, Any]] = None, - override_pooler_config: Optional["PoolerConfig"] = None, - logits_processor_pattern: Optional[str] = None, - generation_config: Optional[str] = None, - disable_mla: bool = False, - ) -> None: + def __init__( + self, + model: str, + task: Union[TaskOption, Literal["draft"]], + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + dtype: Union[str, torch.dtype], + seed: int, + allowed_local_media_path: str = "", + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[Dict[str, Any]] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + spec_target_max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, + enforce_eager: Optional[bool] = None, + max_seq_len_to_capture: Optional[int] = None, + max_logprobs: int = 20, + disable_sliding_window: bool = False, + skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, List[str]]] = None, + limit_mm_per_prompt: Optional[Mapping[str, int]] = None, + use_async_output_proc: bool = True, + config_format: ConfigFormat = ConfigFormat.AUTO, + hf_overrides: Optional[HfOverrides] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + disable_mm_preprocessor_cache: bool = False, + override_neuron_config: Optional[Dict[str, Any]] = None, + override_pooler_config: Optional["PoolerConfig"] = None, + logits_processor_pattern: Optional[str] = None, + generation_config: Optional[str] = None, + disable_mla: bool = False, + ) -> None: self.model = model self.tokenizer = tokenizer self.tokenizer_mode 
= tokenizer_mode @@ -742,10 +745,9 @@ def get_hidden_size(self) -> int: @property def is_deepseek_v2(self) -> bool: - return hasattr( - self.hf_text_config, - "model_type") and (self.hf_text_config.model_type - in ('deepseek_v2', 'deepseek_v3')) + return hasattr(self.hf_text_config, + "model_type") and (self.hf_text_config.model_type + in ('deepseek_v2', 'deepseek_v3')) def get_head_size(self) -> int: # TODO remove hard code diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 1c917a563f3f3..09afd50d5a1e8 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -479,13 +479,11 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.kv_b_proj") # (NV -> H) - self.o_proj = RowParallelLinear( - self.num_heads * self.v_head_dim, - # self.o_proj = ReplicatedLinear(self.num_local_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") rope_scaling["rope_type"] = 'deepseek_yarn' self.rotary_emb = get_rope(qk_rope_head_dim, @@ -516,9 +514,13 @@ def __init__( kv_b_proj_weight = self.kv_b_proj.weight.T assert kv_b_proj_weight.shape == ( - self.kv_lora_rank, - self.num_local_heads * (self.qk_nope_head_dim + self.v_head_dim) - ), f"{kv_b_proj_weight.shape} != {(self.kv_lora_rank, self.num_local_heads * (self.qk_nope_head_dim + self.v_head_dim))}" + self.kv_lora_rank, self.num_local_heads * + (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_local_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") kv_b_proj_weight = kv_b_proj_weight.view( self.kv_lora_rank, self.num_local_heads, @@ -536,7 +538,8 @@ def forward( kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: - # TODO(simon): support append/chunked prefill by two kernels, or using the decode kernel somehow. + # TODO(simon): support append/chunked prefill by two kernels, + # or using the decode kernel somehow. if attn_metadata.prefill_metadata: return self.forward_prefill(positions, hidden_states, kv_cache, attn_metadata) @@ -641,7 +644,6 @@ def forward_decode( device=q.device) q[..., :self.kv_lora_rank] = q_nope q[..., self.kv_lora_rank:] = q_pe - # q = q.view(B, self.num_local_heads * (self.kv_lora_rank + self.qk_rope_head_dim)) k = kv_a.unsqueeze(1) # The padding is only used for kv storage. From f668cb986c60beab3f45a62cf158ecd3b1854b06 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Mon, 6 Jan 2025 17:35:04 +0000 Subject: [PATCH 12/14] more debug lines Signed-off-by: simon-mo --- examples/offline_inference.py | 20 ++- vllm/attention/backends/flashinfer_mla.py | 159 +++++++++++++++++++++- vllm/model_executor/models/deepseek_v2.py | 78 +++++------ 3 files changed, 202 insertions(+), 55 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 6f85f210b3279..db04d50a857ae 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -6,17 +6,22 @@ "The president of the United States is", "The capital of France is", "The future of AI is", + # "Milly needs to return a book she decided was really boring. The book weighs 4 pounds, cost $32, and needs to be returned to a distribution center 20 miles away. 
If the shipping company charges $0.35 per pound plus $0.08 per mile, and Amazon will only refund 75% of the book's purchase price, how much money will Milly lose?" ] # Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) +sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + max_tokens=512, +) # Create an LLM. llm = LLM( - # model="deepseek-ai/DeepSeek-V2-Lite-Chat", - model="deepseek-ai/DeepSeek-V2.5", - tensor_parallel_size=8, + model="deepseek-ai/DeepSeek-V2-Lite-Chat", + # model="deepseek-ai/DeepSeek-V2.5", + tensor_parallel_size=1, trust_remote_code=True, - max_model_len=1024, + max_model_len=4096, # dtype="float16", enforce_eager=True, # max_num_seqs=1, @@ -31,3 +36,8 @@ prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + # print the prompt token ids and output token ids + tokens = list(output.prompt_token_ids) + list(output.outputs[0].token_ids) + # print in block of 8 + for i in range(0, len(tokens), 8): + print(tokens[i:i + 8]) diff --git a/vllm/attention/backends/flashinfer_mla.py b/vllm/attention/backends/flashinfer_mla.py index 6a2ff736d5ee7..e69bb49afbca6 100644 --- a/vllm/attention/backends/flashinfer_mla.py +++ b/vllm/attention/backends/flashinfer_mla.py @@ -222,7 +222,7 @@ def __post_init__(self): def begin_forward(self): if self.num_prefill_tokens > 0: - pass + return if self.num_decode_tokens > 0: assert self.paged_kv_indices is not None @@ -634,6 +634,7 @@ def _forward_prefill( # write the latent and rope to kv cache # TODO(simon): remove the hard code, k_pe is assumed to be 1/8 of the # latent size. + assert k_pe.shape[-1] == self.head_size // 8 to_cache_key_rope = torch.nn.functional.pad( k_pe, [0, self.head_size - self.head_size // 8], value=0) if kv_cache.numel() > 0: @@ -667,7 +668,6 @@ def _forward_prefill( 256)[..., :value.shape[-1]].reshape( -1, self.num_heads * value.shape[-1]) - return attn_output def _forward_decode( @@ -683,8 +683,8 @@ def _forward_decode( assert kv_cache.numel() > 0 # Use the same reshape and cache kernel as flash attention. ops.reshape_and_cache_flash( - key, - rope, + key.contiguous(), + rope.contiguous(), kv_cache[:, 0], kv_cache[:, 1], attn_metadata.slot_mapping.flatten(), @@ -708,13 +708,160 @@ def _forward_decode( assert decode_meta is not None assert decode_meta.decode_wrapper is not None + paged_kpe_cache = kv_cache[:, 1] + paged_kpe_cache = paged_kpe_cache[..., :64].contiguous() + # NOTE(simon): FI assumes head_dim_kpe == head_dim_ckv//8, # and it ignores our padding for the kpe cache. 
+ + # print( + # f"{decode_query_nope.shape=}, {decode_query_pe.shape=}, {kv_cache[:, 0].shape=}, {paged_kpe_cache.shape=}" + # ) + decode_output = decode_meta.decode_wrapper.run( q_nope=decode_query_nope, q_pe=decode_query_pe, - paged_ckv_cache=kv_cache[:, 0], - paged_kpe_cache=kv_cache[:, 1], + paged_ckv_cache=kv_cache[:, 0].squeeze(), + # paged_kpe_cache=kv_cache[:, 1], + paged_kpe_cache=paged_kpe_cache.squeeze(), ) + # load cache + paged_kv_indptr = decode_meta.paged_kv_indptr + paged_kv_indices = decode_meta.paged_kv_indices + paged_kv_last_page_len = decode_meta.paged_kv_last_page_len + + def gather_paged_kv( + kv_cache: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + ): + """ + kv_cache: shape (num_blocks, 2, block_size, num_heads, head_dim) + paged_kv_indices: shape [total_blocks_across_batch] + paged_kv_indptr: shape [batch_size + 1] + paged_kv_last_page_len: shape [batch_size] + + Returns: + K_out, V_out with shape (batch_size, max_kv_len, num_heads, head_dim) + """ + num_blocks, two_, block_size, num_heads, head_dim = kv_cache.shape + assert two_ == 2, "kv_cache shape must be (num_blocks, 2, block_size, num_heads, head_dim)" + + batch_size = paged_kv_indptr.shape[0] - 1 + device = kv_cache.device + dtype = kv_cache.dtype + + # ------------------------------------------------------------------------- + # 1. Compute the maximum number of tokens (max_kv_len) across all requests + # ------------------------------------------------------------------------- + max_kv_len = 0 + for b in range(batch_size): + # The block indices for request b + start = paged_kv_indptr[b] + end = paged_kv_indptr[b + 1] + num_full_blocks = (end - start) - 1 # all but the last block + total_tokens = num_full_blocks * block_size + paged_kv_last_page_len[ + b] + max_kv_len = max(max_kv_len, total_tokens) + + # ------------------------------------------------------------------------- + # 2. Allocate the output buffers for K and V + # Shape: (batch_size, max_kv_len, num_heads, head_dim) + # ------------------------------------------------------------------------- + K_out = torch.zeros( + (batch_size, max_kv_len, num_heads, head_dim), + device=device, + dtype=dtype, + ) + V_out = torch.zeros_like(K_out) # same shape & dtype as K_out + + # ------------------------------------------------------------------------- + # 3. Copy each request’s blocks from kv_cache into [K_out, V_out] + # ------------------------------------------------------------------------- + for b in range(batch_size): + start = paged_kv_indptr[b] + end = paged_kv_indptr[b + 1] + block_indices_for_b = paged_kv_indices[start:end] + + # We'll copy blocks sequentially into K_out[b, ...], V_out[b, ...] + copy_pos = 0 + num_blocks_b = len(block_indices_for_b) + + # Go through each block index + for i, block_idx in enumerate(block_indices_for_b): + # For all but the last block, copy the entire block_size. 
+ # For the last block, only copy 'paged_kv_last_page_len[b]' entries + if i < (num_blocks_b - 1): + # Copy entire block + K_block = kv_cache[ + block_idx, + 0] # shape (block_size, num_heads, head_dim) + V_block = kv_cache[block_idx, 1] + K_out[b, copy_pos:copy_pos + block_size] = K_block + V_out[b, copy_pos:copy_pos + block_size] = V_block + copy_pos += block_size + else: + # Last block for this request + last_len = paged_kv_last_page_len[b].item() + if last_len > 0: + K_block = kv_cache[ + block_idx, + 0][: + last_len] # shape (last_len, num_heads, head_dim) + V_block = kv_cache[block_idx, 1][:last_len] + K_out[b, copy_pos:copy_pos + last_len] = K_block + V_out[b, copy_pos:copy_pos + last_len] = V_block + # If last_len == 0, we simply skip copying + copy_pos += last_len + + return K_out, V_out + + debug = False + if debug: + K_out, V_out = gather_paged_kv(kv_cache, paged_kv_indices, + paged_kv_indptr, + paged_kv_last_page_len) + + # debug: hand implemented MLA, this not correct yet, please fix it + q_pe = decode_query_pe # [bsz, num_heads, qk_rope_head_dim] + k_pe_cache = V_out[:, :, 0, :self.head_size // + 8] # [bsz, kv_len, rope_head_dim] + + attn_weights_pe = torch.matmul( + q_pe, # [bsz, num_heads, qk_rope_head_dim] + k_pe_cache.transpose( + 1, 2 + ) # [bsz, kv_len, 64] view(bsz, kv_len, self.qk_rope_head_dim) + ) + + q_nope = decode_query_nope # [bsz, num_heads, latent_dim] + compressed_kv_normed_cache = K_out.squeeze( + 2) # [bsz, kv_len, latent_dim] + + # attn_weights_nope ~ [bsz, num_heads, kv_len] + attn_weights_nope = torch.matmul( + q_nope, # [bsz, 128, 512] + compressed_kv_normed_cache.transpose( + 1, 2) # view(bsz, kv_len, 512) + ) + + attn_weights = (attn_weights_pe + attn_weights_nope) * self.scale + + attn_weights = torch.nn.functional.softmax(attn_weights, + dim=-1, + dtype=torch.float32).to( + q_nope.dtype) + + # attn_output ~ {attn_output.shape}") # [bsz, 128, 512] + attn_output = torch.matmul( + attn_weights, # [bsz, 128, kv_len] + compressed_kv_normed_cache # [bsz, kv_len, 512] + ) + + return attn_output + + # diff = attn_output - decode_output + # print(f"diff: {diff.abs().sum()}") return decode_output diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 09afd50d5a1e8..c2c024e398dce 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -417,20 +417,19 @@ def __init__( prefix: str = "", ) -> None: super().__init__() - # Note(simon): Added some symbols for shapes, hoping to help clarity. 
- self.hidden_size = hidden_size # H - self.qk_nope_head_dim = qk_nope_head_dim # P - self.qk_rope_head_dim = qk_rope_head_dim # R - self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim # P + R - self.v_head_dim = v_head_dim # V + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank # L + self.kv_lora_rank = kv_lora_rank - self.num_heads = num_heads # N + self.num_heads = num_heads tp_size = get_tensor_model_parallel_world_size() assert num_heads % tp_size == 0 - self.num_local_heads = num_heads // tp_size # N' + self.num_local_heads = num_heads // tp_size self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta @@ -444,24 +443,20 @@ def __init__( prefix=f"{prefix}.q_a_proj") self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear( - q_lora_rank, - # self.q_b_proj = ReplicatedLinear(q_lora_rank, - self.num_heads * self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_b_proj") + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") else: - # (H -> N(P+R)) - self.q_proj = ColumnParallelLinear( - self.hidden_size, - # self.q_proj = ReplicatedLinear(self.hidden_size, - self.num_heads * self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_proj") + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") - # (H -> (L+R)) self.kv_a_proj_with_mqa = ReplicatedLinear( self.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, @@ -470,15 +465,12 @@ def __init__( prefix=f"{prefix}.kv_a_proj_with_mqa") self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) - # ((L -> (N(P+V))) self.kv_b_proj = ColumnParallelLinear( - # self.kv_b_proj = ReplicatedLinear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, quant_config=quant_config, prefix=f"{prefix}.kv_b_proj") - # (NV -> H) self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, self.hidden_size, bias=False, @@ -531,6 +523,9 @@ def __init__( # self.W_UK = self.W_UK.view(self.kv_lora_rank, self.num_local_heads * self.qk_nope_head_dim) # self.W_UV = self.W_UV.view(self.kv_lora_rank, self.num_local_heads * self.v_head_dim) + self.prefix = prefix + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + def forward( self, positions: torch.Tensor, @@ -540,6 +535,9 @@ def forward( ) -> torch.Tensor: # TODO(simon): support append/chunked prefill by two kernels, # or using the decode kernel somehow. 
+ if attn_metadata.prefill_metadata and attn_metadata.decode_metadata: + raise ValueError( + "Chunked prefill is not supported when MLA is enabled.") if attn_metadata.prefill_metadata: return self.forward_prefill(positions, hidden_states, kv_cache, attn_metadata) @@ -554,7 +552,6 @@ def forward_prefill( kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: - # BH -> B(N(P+R)) -> BN(P+R) if self.q_lora_rank is not None: q = self.q_a_proj(hidden_states)[0] q = self.q_a_layernorm(q) @@ -564,37 +561,27 @@ def forward_prefill( q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, self.qk_head_dim) - # BN(P+R) -> BNP, BNR - q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], - dim=-1) - # BH -> B(L+R) + _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] - # B(L+R) -> BL, BR kv_a, _ = latent_cache.split( [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - # B(L+R) -> B1(L+R) latent_cache = latent_cache.unsqueeze(1) - # BL -> BL kv_a = self.kv_a_layernorm(kv_a.contiguous()) - # BL -> B(N'(P+V)) kv = self.kv_b_proj(kv_a)[0] - # B(N'(P+V)) -> BN'(P+V) kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) - # BN'(P+V) -> BN'P, BN'V k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - # B1(L+R) -> B1R k_pe = latent_cache[:, :, self.kv_lora_rank:] - # BNR, B1R -> BNR, B1R q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - # BN(P+R) q[..., self.qk_nope_head_dim:] = q_pe - # BN(P+R) k = torch.empty_like(q) k[..., :self.qk_nope_head_dim] = k_nope k[..., self.qk_nope_head_dim:] = k_pe - # HACK + # HACK(simon): these need to be passed into the attention backend + # to write to the kv cache. + # TODO(simon): do we need to free these? 
attn_metadata.extras = { "kv_a": kv_a.unsqueeze(1), # restore the head dim to write to kv cache @@ -628,11 +615,13 @@ def forward_decode( q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] kv_a, k_pe = latent_cache.split( [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv_a = self.kv_a_layernorm(kv_a.contiguous()) k_pe = k_pe.unsqueeze(1) + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) # Apply UK, q_nope (B, N, P) @ W_UK (L, N, P) -> (B, N, L) q_nope = torch.einsum("bnp,lnp->bnl", q_nope, self.W_UK) @@ -651,6 +640,7 @@ def forward_decode( k_pe, [0, self.kv_lora_rank - self.qk_rope_head_dim], value=0) assert k.numel() == v.numel(), f"{k.numel()=} != {v.numel()=}" + attn_metadata.debug_layer_idx = self.debug_layer_idx attn_output = self.attn(q, k, v, kv_cache, attn_metadata) # idk why but the attn_output is fp32 From 4094e2d1c32ca2fa319529bd87643def509b15b4 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Thu, 16 Jan 2025 21:17:42 +0000 Subject: [PATCH 13/14] Use yarn scales, fixes some accuracy issue Co-authored-by: cennn <2523403608@qq.com> Signed-off-by: simon-mo --- .../run-lm-eval-gsm-vllm-baseline.sh | 12 +- examples/offline_inference.py | 6 - vllm/attention/backends/flashinfer_mla.py | 160 +----------------- 3 files changed, 14 insertions(+), 164 deletions(-) diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh index 65be3c5d93b20..380ad0cd1903b 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh @@ -23,22 +23,22 @@ usage() { while getopts "m:b:l:f:t:" OPT; do case ${OPT} in - m ) + m ) MODEL="$OPTARG" ;; - b ) + b ) BATCH_SIZE="$OPTARG" ;; - l ) + l ) LIMIT="$OPTARG" ;; - f ) + f ) FEWSHOT="$OPTARG" ;; t ) TP_SIZE="$OPTARG" ;; - \? ) + \? ) usage exit 1 ;; @@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do done lm_eval --model vllm \ - --model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend=ray,trust_remote_code=true,max_model_len=4096" \ + --model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend=ray,trust_remote_code=true,max_model_len=4096,enforce_eager=true" \ --tasks gsm8k --num_fewshot "$FEWSHOT" --limit "$LIMIT" \ --batch_size "$BATCH_SIZE" diff --git a/examples/offline_inference.py b/examples/offline_inference.py index db04d50a857ae..cb5ee8d72c836 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -6,7 +6,6 @@ "The president of the United States is", "The capital of France is", "The future of AI is", - # "Milly needs to return a book she decided was really boring. The book weighs 4 pounds, cost $32, and needs to be returned to a distribution center 20 miles away. If the shipping company charges $0.35 per pound plus $0.08 per mile, and Amazon will only refund 75% of the book's purchase price, how much money will Milly lose?" ] # Create a sampling params object. 
sampling_params = SamplingParams( @@ -36,8 +35,3 @@ prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - # print the prompt token ids and output token ids - tokens = list(output.prompt_token_ids) + list(output.outputs[0].token_ids) - # print in block of 8 - for i in range(0, len(tokens), 8): - print(tokens[i:i + 8]) diff --git a/vllm/attention/backends/flashinfer_mla.py b/vllm/attention/backends/flashinfer_mla.py index e69bb49afbca6..316944b8e5663 100644 --- a/vllm/attention/backends/flashinfer_mla.py +++ b/vllm/attention/backends/flashinfer_mla.py @@ -662,6 +662,7 @@ def _forward_prefill( cu_seqlens_k=attn_metadata.seq_start_loc, max_seqlen_q=attn_metadata.max_prefill_seq_len, max_seqlen_k=attn_metadata.max_prefill_seq_len, + softmax_scale=self.scale, causal=True, ) attn_output = attn_output.view(-1, self.num_heads, @@ -681,6 +682,7 @@ def _forward_decode( v_scale: float, ) -> torch.Tensor: assert kv_cache.numel() > 0 + # Use the same reshape and cache kernel as flash attention. ops.reshape_and_cache_flash( key.contiguous(), @@ -692,6 +694,7 @@ def _forward_decode( k_scale, v_scale, ) + # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 # to process the cache when the kv_cache_dtype is fp8 if self.kv_cache_dtype.startswith("fp8"): @@ -702,166 +705,19 @@ def _forward_decode( decode_query_nope = query[:, :, :self.head_size].contiguous() decode_query_pe = query[:, :, self.head_size:].contiguous() - decode_output: Optional[torch.Tensor] = None - decode_meta = attn_metadata.decode_metadata assert decode_meta is not None assert decode_meta.decode_wrapper is not None paged_kpe_cache = kv_cache[:, 1] - paged_kpe_cache = paged_kpe_cache[..., :64].contiguous() - - # NOTE(simon): FI assumes head_dim_kpe == head_dim_ckv//8, - # and it ignores our padding for the kpe cache. - - # print( - # f"{decode_query_nope.shape=}, {decode_query_pe.shape=}, {kv_cache[:, 0].shape=}, {paged_kpe_cache.shape=}" - # ) + paged_kpe_cache = paged_kpe_cache[..., :64] + decode_meta.decode_wrapper._sm_scale = self.scale decode_output = decode_meta.decode_wrapper.run( q_nope=decode_query_nope, q_pe=decode_query_pe, - paged_ckv_cache=kv_cache[:, 0].squeeze(), - # paged_kpe_cache=kv_cache[:, 1], - paged_kpe_cache=paged_kpe_cache.squeeze(), + paged_ckv_cache=kv_cache[:, 0], + paged_kpe_cache=kv_cache[:, 1], ) - - # load cache - paged_kv_indptr = decode_meta.paged_kv_indptr - paged_kv_indices = decode_meta.paged_kv_indices - paged_kv_last_page_len = decode_meta.paged_kv_last_page_len - - def gather_paged_kv( - kv_cache: torch.Tensor, - paged_kv_indices: torch.Tensor, - paged_kv_indptr: torch.Tensor, - paged_kv_last_page_len: torch.Tensor, - ): - """ - kv_cache: shape (num_blocks, 2, block_size, num_heads, head_dim) - paged_kv_indices: shape [total_blocks_across_batch] - paged_kv_indptr: shape [batch_size + 1] - paged_kv_last_page_len: shape [batch_size] - - Returns: - K_out, V_out with shape (batch_size, max_kv_len, num_heads, head_dim) - """ - num_blocks, two_, block_size, num_heads, head_dim = kv_cache.shape - assert two_ == 2, "kv_cache shape must be (num_blocks, 2, block_size, num_heads, head_dim)" - - batch_size = paged_kv_indptr.shape[0] - 1 - device = kv_cache.device - dtype = kv_cache.dtype - - # ------------------------------------------------------------------------- - # 1. 
Compute the maximum number of tokens (max_kv_len) across all requests - # ------------------------------------------------------------------------- - max_kv_len = 0 - for b in range(batch_size): - # The block indices for request b - start = paged_kv_indptr[b] - end = paged_kv_indptr[b + 1] - num_full_blocks = (end - start) - 1 # all but the last block - total_tokens = num_full_blocks * block_size + paged_kv_last_page_len[ - b] - max_kv_len = max(max_kv_len, total_tokens) - - # ------------------------------------------------------------------------- - # 2. Allocate the output buffers for K and V - # Shape: (batch_size, max_kv_len, num_heads, head_dim) - # ------------------------------------------------------------------------- - K_out = torch.zeros( - (batch_size, max_kv_len, num_heads, head_dim), - device=device, - dtype=dtype, - ) - V_out = torch.zeros_like(K_out) # same shape & dtype as K_out - - # ------------------------------------------------------------------------- - # 3. Copy each request’s blocks from kv_cache into [K_out, V_out] - # ------------------------------------------------------------------------- - for b in range(batch_size): - start = paged_kv_indptr[b] - end = paged_kv_indptr[b + 1] - block_indices_for_b = paged_kv_indices[start:end] - - # We'll copy blocks sequentially into K_out[b, ...], V_out[b, ...] - copy_pos = 0 - num_blocks_b = len(block_indices_for_b) - - # Go through each block index - for i, block_idx in enumerate(block_indices_for_b): - # For all but the last block, copy the entire block_size. - # For the last block, only copy 'paged_kv_last_page_len[b]' entries - if i < (num_blocks_b - 1): - # Copy entire block - K_block = kv_cache[ - block_idx, - 0] # shape (block_size, num_heads, head_dim) - V_block = kv_cache[block_idx, 1] - K_out[b, copy_pos:copy_pos + block_size] = K_block - V_out[b, copy_pos:copy_pos + block_size] = V_block - copy_pos += block_size - else: - # Last block for this request - last_len = paged_kv_last_page_len[b].item() - if last_len > 0: - K_block = kv_cache[ - block_idx, - 0][: - last_len] # shape (last_len, num_heads, head_dim) - V_block = kv_cache[block_idx, 1][:last_len] - K_out[b, copy_pos:copy_pos + last_len] = K_block - V_out[b, copy_pos:copy_pos + last_len] = V_block - # If last_len == 0, we simply skip copying - copy_pos += last_len - - return K_out, V_out - - debug = False - if debug: - K_out, V_out = gather_paged_kv(kv_cache, paged_kv_indices, - paged_kv_indptr, - paged_kv_last_page_len) - - # debug: hand implemented MLA, this not correct yet, please fix it - q_pe = decode_query_pe # [bsz, num_heads, qk_rope_head_dim] - k_pe_cache = V_out[:, :, 0, :self.head_size // - 8] # [bsz, kv_len, rope_head_dim] - - attn_weights_pe = torch.matmul( - q_pe, # [bsz, num_heads, qk_rope_head_dim] - k_pe_cache.transpose( - 1, 2 - ) # [bsz, kv_len, 64] view(bsz, kv_len, self.qk_rope_head_dim) - ) - - q_nope = decode_query_nope # [bsz, num_heads, latent_dim] - compressed_kv_normed_cache = K_out.squeeze( - 2) # [bsz, kv_len, latent_dim] - - # attn_weights_nope ~ [bsz, num_heads, kv_len] - attn_weights_nope = torch.matmul( - q_nope, # [bsz, 128, 512] - compressed_kv_normed_cache.transpose( - 1, 2) # view(bsz, kv_len, 512) - ) - - attn_weights = (attn_weights_pe + attn_weights_nope) * self.scale - - attn_weights = torch.nn.functional.softmax(attn_weights, - dim=-1, - dtype=torch.float32).to( - q_nope.dtype) - - # attn_output ~ {attn_output.shape}") # [bsz, 128, 512] - attn_output = torch.matmul( - attn_weights, # [bsz, 128, kv_len] - 
compressed_kv_normed_cache # [bsz, kv_len, 512] - ) - - return attn_output - - # diff = attn_output - decode_output - # print(f"diff: {diff.abs().sum()}") return decode_output + From a04777276ced77e1bc392e7a6e3cba61af42f65b Mon Sep 17 00:00:00 2001 From: simon-mo Date: Thu, 16 Jan 2025 21:20:51 +0000 Subject: [PATCH 14/14] format Signed-off-by: simon-mo --- vllm/attention/backends/flashinfer_mla.py | 1 - vllm/model_executor/models/deepseek_v2.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/vllm/attention/backends/flashinfer_mla.py b/vllm/attention/backends/flashinfer_mla.py index 316944b8e5663..4bc60f6d4636f 100644 --- a/vllm/attention/backends/flashinfer_mla.py +++ b/vllm/attention/backends/flashinfer_mla.py @@ -720,4 +720,3 @@ def _forward_decode( paged_kpe_cache=kv_cache[:, 1], ) return decode_output - diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index c2c024e398dce..407507485e9d4 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -520,8 +520,6 @@ def __init__( ) self.W_UK, self.W_UV = kv_b_proj_weight.split( [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - # self.W_UK = self.W_UK.view(self.kv_lora_rank, self.num_local_heads * self.qk_nope_head_dim) - # self.W_UV = self.W_UV.view(self.kv_lora_rank, self.num_local_heads * self.v_head_dim) self.prefix = prefix self.debug_layer_idx = int(self.prefix.split(".")[-2])
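
Note on the decode path these patches converge on: `forward_decode` splits `kv_b_proj.weight.T` into `W_UK` and `W_UV` and folds (`absorbs`) `W_UK` into the query via `torch.einsum("bnp,lnp->bnl", q_nope, self.W_UK)`, so attention runs directly over the compressed latent cache (plus the shared rope key) instead of decompressed per-head K/V; `W_UV` is then applied to the attention output. The sketch below is a standalone sanity check of that identity, not vLLM code: all tensor names and the toy sizes B, T, N, L, P, R, V are illustrative (DeepSeek-V2 itself uses kv_lora_rank=512, qk_rope_head_dim=64, qk_nope_head_dim=v_head_dim=128), and it uses the plain (P+R)**-0.5 softmax scale rather than the YaRN-adjusted scale the patches install.

    import torch

    torch.manual_seed(0)

    # Toy sizes: batch of decode queries, cached tokens, heads, then the MLA dims.
    B, T, N = 2, 5, 4          # batch, cached kv length, attention heads
    L, P, R, V = 16, 8, 4, 8   # kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
    scale = (P + R) ** -0.5

    # Stand-ins for the two halves of kv_b_proj, reshaped to (L, N, P) and (L, N, V).
    W_UK = torch.randn(L, N, P)
    W_UV = torch.randn(L, N, V)

    # Per-token compressed latents and the rope key shared across heads (MQA-style),
    # assumed already layer-normed / rotated.
    c_kv = torch.randn(B, T, L)
    k_pe = torch.randn(B, T, R)

    # Current-step query, split into no-rope and rope parts (rope already applied).
    q_nope = torch.randn(B, N, P)
    q_pe = torch.randn(B, N, R)

    # Naive path: decompress the latent cache into per-head K/V, then attend.
    k_nope = torch.einsum("btl,lnp->btnp", c_kv, W_UK)            # (B, T, N, P)
    v = torch.einsum("btl,lnv->btnv", c_kv, W_UV)                 # (B, T, N, V)
    scores = (torch.einsum("bnp,btnp->bnt", q_nope, k_nope) +
              torch.einsum("bnr,btr->bnt", q_pe, k_pe)) * scale   # (B, N, T)
    probs = scores.softmax(dim=-1)
    out_naive = torch.einsum("bnt,btnv->bnv", probs, v)           # (B, N, V)

    # Absorbed path: fold W_UK into the query and W_UV into the output, so the
    # kernel only ever touches the (B, T, L) latent cache and (B, T, R) rope keys.
    q_latent = torch.einsum("bnp,lnp->bnl", q_nope, W_UK)         # (B, N, L)
    scores_abs = (torch.einsum("bnl,btl->bnt", q_latent, c_kv) +
                  torch.einsum("bnr,btr->bnt", q_pe, k_pe)) * scale
    probs_abs = scores_abs.softmax(dim=-1)
    ctx_latent = torch.einsum("bnt,btl->bnl", probs_abs, c_kv)    # (B, N, L)
    out_absorbed = torch.einsum("bnl,lnv->bnv", ctx_latent, W_UV)

    assert torch.allclose(out_naive, out_absorbed, atol=1e-4, rtol=1e-4)
    print("naive and absorbed MLA decode agree:", tuple(out_absorbed.shape))

Because the two formulations agree, the decode-time KV cache only has to hold the kv_lora_rank-dim latent plus the rope key per token, which is why the cache engine's "value" block is repurposed for the (padded) 64-dim rope cache and why the FlashInfer MLA decode wrapper is fed the latent cache and the first 64 dims of that block directly.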