[WIP] Deepseek V2 MLA #10927

Draft
wants to merge 16 commits into
base: main

simon-mo
Collaborator

@simon-mo simon-mo commented Dec 5, 2024

Status (12/05/2024):

Currently, I have implemented MLA in the KV cache format and used FlashInfer's MLA decode kernel, with correct output. Throughput for a sample case already goes from 10.47 rps to 18.5 rps. The PR is still very messy and lacks proper design, but it demonstrates both the space savings and the speedup.

Before

$ VLLM_ATTENTION_BACKEND=FLASHINFER CUDA_VISIBLE_DEVICES=2 python benchmark_throughput.py --model deepseek-ai/DeepSeek-V2-Lite-Chat --trust-remote-code --enforce-eager --max-model-len 8192 --input-len 1000 --output-len 100 --num-prompts 200 --dtype float16
...
INFO 12-06 07:43:15 model_runner.py:1105] Loading model weights took 29.3010 GB
WARNING 12-06 07:43:15 fused_moe.py:326] Using default MoE config. Performance might be sub-optimal! Config file not found at /home/simonmo/vllm/vllm/model_executor/layers/fused_moe/configs/E=64,N=1408,device_name=NVIDIA_H100_80GB_HBM3.json
INFO 12-06 07:43:16 worker.py:235] Memory profiling results: duration=1.09 seconds, total_gpu_memory=79.10GiB, initial_memory_usage=34.44GiB, peak_torch_memory=30.56GiB, memory_usage_post_profile=34.81GiB, non_torch_memory=5.21GiB, kv_cache_size=35.42GiB, gpu_memory_utilization=0.90.
INFO 12-06 07:43:16 gpu_executor.py:76] # GPU blocks: 5373, # CPU blocks: 606
INFO 12-06 07:43:16 gpu_executor.py:80] Maximum concurrency for 8192 tokens per request: 10.49x
Processed prompts:   0%|                                                                                                               | 0/200 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]WARNING 12-06 07:43:21 scheduler.py:1536] Sequence group 83 is preempted by PreemptionMode.RECOMPUTE mode because there is not enough KV cache space. This can affect the end-to-end performance. Increase gpu_memory_utilization or tensor_parallel_size to provide more KV cache memory. total_num_cumulative_preemption=1
INFO 12-06 07:43:23 metrics.py:460] Avg prompt throughput: 16771.0 tokens/s, Avg generation throughput: 823.6 tokens/s, Running: 81 reqs, Swapped: 0 reqs, Pending: 119 reqs, GPU KV cache usage: 99.5%, CPU KV cache usage: 0.0%.
Processed prompts:  39%|█████████████████████████████████████▊                                                           | 78/200 [00:09<00:11, 11.00it/s, est. speed input: 8544.12 toks/s, output: 854.41 toks/s]INFO 12-06 07:43:28 metrics.py:460] Avg prompt throughput: 16982.3 tokens/s, Avg generation throughput: 1133.8 tokens/s, Running: 82 reqs, Swapped: 0 reqs, Pending: 39 reqs, GPU KV cache usage: 98.0%, CPU KV cache usage: 0.0%.
Processed prompts:  80%|███████████████████████████████████████████████████████████████████████████▋                  | 161/200 [00:14<00:01, 24.54it/s, est. speed input: 11268.57 toks/s, output: 1126.86 toks/s]INFO 12-06 07:43:33 metrics.py:460] Avg prompt throughput: 7962.9 tokens/s, Avg generation throughput: 1323.7 tokens/s, Running: 39 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 46.5%, CPU KV cache usage: 0.0%.
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:18<00:00, 10.70it/s, est. speed input: 10697.81 toks/s, output: 1069.78 toks/s]
Throughput: 10.47 requests/s, 11519.10 total tokens/s, 1047.19 output tokens/s

After

$ VLLM_ATTENTION_BACKEND=FLASHINFER CUDA_VISIBLE_DEVICES=2 python benchmark_throughput.py --model deepseek-ai/DeepSeek-V2-Lite-Chat --trust-remote-code --enforce-eager --max-model-len 8192 --input-len 1000 --output-len 100 --num-prompts 200 --dtype float16
INFO 12-06 07:38:35 model_runner.py:1105] Loading model weights took 29.3010 GB
WARNING 12-06 07:38:36 fused_moe.py:326] Using default MoE config. Performance might be sub-optimal! Config file not found at /home/simonmo/vllm/vllm/model_executor/layers/fused_moe/configs/E=64,N=1408,device_name=NVIDIA_H100_80GB_HBM3.json
INFO 12-06 07:38:36 worker.py:235] Memory profiling results: duration=0.93 seconds, total_gpu_memory=79.10GiB, initial_memory_usage=34.44GiB, peak_torch_memory=30.56GiB, memory_usage_post_profile=34.81GiB, non_torch_memory=5.21GiB, kv_cache_size=35.42GiB, gpu_memory_utilization=0.90.
INFO 12-06 07:38:36 gpu_executor.py:76] # GPU blocks: 42987, # CPU blocks: 4854
INFO 12-06 07:38:36 gpu_executor.py:80] Maximum concurrency for 8192 tokens per request: 83.96x
Processed prompts:   0%|                                                                                                               | 0/200 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]2024-12-06 07:38:40,331 - INFO - flashinfer.jit: Loading JIT ops: batch_decode_mla_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_512_use_swa_False_use_logits_cap_False
INFO 12-06 07:38:44 metrics.py:460] Avg prompt throughput: 31760.2 tokens/s, Avg generation throughput: 31.8 tokens/s, Running: 160 reqs, Swapped: 0 reqs, Pending: 40 reqs, GPU KV cache usage: 23.4%, CPU KV cache usage: 0.0%.
INFO 12-06 07:38:49 metrics.py:460] Avg prompt throughput: 7974.3 tokens/s, Avg generation throughput: 3317.3 tokens/s, Running: 200 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 31.6%, CPU KV cache usage: 0.0%.
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:10<00:00, 19.21it/s, est. speed input: 19206.17 toks/s, output: 1920.62 toks/s]
Throughput: 18.50 requests/s, 20347.06 total tokens/s, 1849.73 output tokens/s
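
The "Maximum concurrency" lines in the two logs can be reproduced from the reported GPU block counts. The sketch below is only a sanity check of that arithmetic: it assumes a KV cache block size of 16 tokens, which is not printed in the logs but is consistent with the reported 10.49x and 83.96x figures. The function name `max_concurrency` is made up for this illustration.

```python
# Figures copied from the two runs above. BLOCK_SIZE is an assumption:
# it does not appear in the logs, but 16 reproduces the logged values.
BLOCK_SIZE = 16        # assumed tokens per KV cache block
MAX_MODEL_LEN = 8192   # from --max-model-len


def max_concurrency(num_gpu_blocks: int) -> float:
    """Number of full-length requests whose KV cache fits at once."""
    return num_gpu_blocks * BLOCK_SIZE / MAX_MODEL_LEN


before_blocks, after_blocks = 5373, 42987
before_rps, after_rps = 10.47, 18.50

print(f"before: {max_concurrency(before_blocks):.2f}x")            # 10.49x
print(f"after:  {max_concurrency(after_blocks):.2f}x")             # 83.96x
print(f"block ratio:      {after_blocks / before_blocks:.2f}x")    # 8.00x
print(f"throughput ratio: {after_rps / before_rps:.2f}x")          # 1.77x
```

With the same 35.42 GiB KV cache budget, the second run fits roughly 8x as many blocks, which is why the preemption warning from the first run does not appear in the second.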

Some todos:

  • Document the strategy we are using and the follow-up work. There is still some KV cache waste, but I think this is the best we can do until there is a hybrid cache allocator.
  • Design a FLASHINFER_MLA backend and feature-flag MLA (hopefully on by default).
  • Misc: implement q_lora, cache the matrix-absorption matrices.
  • Figure out the CUDA graph issue. Will just opt out of it for now (and also turn off chunked prefill).
  • Figure out the TP story.
  • Feature flag: --disable-mla and DISABLE_MLA.
  • Benchmark.
  • Support DeepSeek V3.
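
One possible shape for the `--disable-mla` / `DISABLE_MLA` item is sketched below. Only the flag name and the variable name come from the list above. The helper `mla_enabled`, the set of accepted truthy strings, and the rule that either switch alone turns MLA off are assumptions for illustration, not vLLM's actual argument handling.

```python
import argparse
import os


def mla_enabled(argv=None, environ=None) -> bool:
    """Hypothetical helper: MLA stays on unless the flag or env var disables it."""
    environ = os.environ if environ is None else environ

    parser = argparse.ArgumentParser()
    parser.add_argument("--disable-mla", action="store_true")
    # Ignore unrelated arguments so the sketch can sit next to other flags.
    args, _ = parser.parse_known_args(argv)

    # Assumed convention: "1", "true" or "yes" (any case) means disabled.
    env_off = environ.get("DISABLE_MLA", "0").lower() in ("1", "true", "yes")
    return not (args.disable_mla or env_off)
```

For example, `mla_enabled(["--disable-mla"], {})` and `mla_enabled([], {"DISABLE_MLA": "1"})` both return `False`, while `mla_enabled([], {})` returns `True`, matching the "on by default" goal.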

Some out of scope:

  • No prefill/decode selectors in the model files
  • Test piecewise CUDA graph in V1
  • Support chunked prefill


github-actions bot commented Dec 5, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

        return self.forward_decode(positions, hidden_states, kv_cache,
                                   attn_metadata)

    def forward_prefill(
Member

Does FlashInfer have a prefill kernel?

@simon-mo simon-mo mentioned this pull request Dec 27, 2024
10 tasks
@liangzelang

Nice job! I wonder how you handle MLA prefill, because the FlashInfer library has no available MLA prefill kernel, only a decode kernel.

@simon-mo
Collaborator Author

@liangzelang For prefill, this PR performs the regular up-projection to turn MLA into MHA.
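
A minimal pure-Python sketch of that idea follows; it is not the PR's code. Each token's cached latent `c_kv` is expanded with per-head up-projection matrices into ordinary keys and values, the shared RoPE key part `k_pe` is appended to every head's key, and plain causal multi-head attention runs on the result. All names (`up_project`, `mha_prefill`, `w_uk`, `w_uv`) are hypothetical. RoPE is assumed to be already applied to `k_pe` and to the matching slice of each query, and the `1/sqrt(d)` softmax scale is the usual default rather than anything confirmed by this PR.

```python
import math


def matvec(m, v):
    """Multiply matrix m (list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def up_project(c_kv, k_pe, w_uk, w_uv):
    """Expand one token's latent into per-head keys and values.

    c_kv: latent vector [d_c]; k_pe: RoPE key part shared by all heads [d_r].
    w_uk[h]: [d_nope x d_c] and w_uv[h]: [d_v x d_c] for each head h.
    """
    keys = [matvec(w, c_kv) + k_pe for w in w_uk]   # [heads][d_nope + d_r]
    values = [matvec(w, c_kv) for w in w_uv]        # [heads][d_v]
    return keys, values


def mha_prefill(queries, latents, k_pes, w_uk, w_uv):
    """Causal multi-head attention over up-projected K/V.

    queries[t][h] has length d_nope + d_r. Returns out[t][h] of length d_v.
    """
    kv = [up_project(c, p, w_uk, w_uv) for c, p in zip(latents, k_pes)]
    out = []
    for t, q_heads in enumerate(queries):
        row = []
        for h, q in enumerate(q_heads):
            scale = 1.0 / math.sqrt(len(q))
            # Token t attends to tokens 0..t only (causal mask).
            scores = [scale * sum(a * b for a, b in zip(q, kv[s][0][h]))
                      for s in range(t + 1)]
            peak = max(scores)
            weights = [math.exp(x - peak) for x in scores]
            total = sum(weights)
            d_v = len(kv[0][1][h])
            row.append([sum(weights[s] * kv[s][1][h][i]
                            for s in range(t + 1)) / total
                        for i in range(d_v)])
        out.append(row)
    return out
```

The trade-off is that prefill works on full-size keys and values, so only the decode path reads the compact latent cache directly.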


mergify bot commented Dec 31, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @simon-mo.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 31, 2024
@simon-mo
Collaborator Author

simon-mo commented Jan 2, 2025

Update:

  • Debugging some accuracy issues today (gsm8k is worse with MLA; TP might be a factor as well)
  • Remaining steps once debugging finishes:
    • Implement matrix absorption
    • Final round of benchmarks
    • DeepSeek V3

Then it will be ready for review.
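
For readers unfamiliar with the term, "matrix absorption" is commonly understood (following the DeepSeek-V2 paper) as folding the key up-projection into the query side, so that decode can score directly against the cached latent. The sketch below checks that identity, q · (W_UK c) = (W_UKᵀ q) · c, with small integer values. The function names are hypothetical, and it covers only the non-RoPE part of the score for a single head.

```python
def matvec(m, v):
    """Multiply matrix m (list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def transpose(m):
    return [list(col) for col in zip(*m)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def score_up_projected(q_nope, w_uk, c_kv):
    """Baseline: up-project the latent to a full key, then take q . k."""
    return dot(q_nope, matvec(w_uk, c_kv))


def score_absorbed(q_nope, w_uk, c_kv):
    """Absorbed: project the query into latent space once, then q' . c."""
    q_latent = matvec(transpose(w_uk), q_nope)
    return dot(q_latent, c_kv)


w_uk = [[1, 2], [3, 4], [5, 6]]   # d_nope = 3, d_c = 2
c_kv = [1, -1]                    # cached latent for one token
q_nope = [2, 0, 1]                # non-RoPE part of one head's query

assert score_up_projected(q_nope, w_uk, c_kv) == score_absorbed(q_nope, w_uk, c_kv)
```

The absorbed form never materializes the per-head key, and the projected query can be reused across every cached token, which is the motivation for caching the absorbed matrices.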

Signed-off-by: simon-mo <[email protected]>
@mergify mergify bot removed the needs-rebase label Jan 6, 2025
@cennn
Contributor

cennn commented Jan 13, 2025

I think the accuracy issues of the FlashInfer kernel might be related to this:
simon-mo/flashinfer#1 (comment)

simon-mo and others added 2 commits January 16, 2025 21:17
Signed-off-by: simon-mo <[email protected]>

mergify bot commented Jan 16, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @simon-mo.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot removed the needs-rebase label Jan 16, 2025
@simon-mo
Collaborator Author

Thanks to @cennn, the accuracy issue has been partially identified. We are now at the point where the kernel generates coherent output. However, accuracy is still lower than that of the MHA implementation.

MLA
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     8|exact_match|↑  | 0.24|±  |0.0429|
|     |       |strict-match    |     8|exact_match|↑  | 0.23|±  |0.0423|


No MLA
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     8|exact_match|↑  | 0.32|±  |0.0469|
|     |       |strict-match    |     8|exact_match|↑  | 0.32|±  |0.0469|
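
As a rough way to read the two tables, the gap can be compared against the reported standard errors. The sketch below uses the flexible-extract rows and treats the two runs as independent samples. That is a simplifying assumption, since both runs presumably score the same gsm8k questions, so the result is a crude indication rather than a proper paired test.

```python
import math


def z_score(p_a: float, se_a: float, p_b: float, se_b: float) -> float:
    """Difference of two scores in units of their combined standard error."""
    return (p_a - p_b) / math.sqrt(se_a ** 2 + se_b ** 2)


# Values from the flexible-extract rows above: No MLA vs MLA.
z = z_score(0.32, 0.0469, 0.24, 0.0429)
print(f"gap = {0.32 - 0.24:.2f}, z = {z:.2f}")   # gap = 0.08, z = 1.26
```

Under these assumptions the 0.08 gap is about 1.26 combined standard errors, so a larger evaluation sample would help confirm how much accuracy is actually lost.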
