
[Performance]: decoding speed on long context #11286

Open
155394551lzk opened this issue Dec 18, 2024 · 43 comments

Labels
performance Performance-related issues

Comments

155394551lzk commented Dec 18, 2024

Proposal to improve performance

In our experiments, we found that vLLM's decoding speed drops dramatically as the prompt gets longer.
With the batch size fixed at 90, the decoding speed is 5364 tokens/s when the prompt length is within 100 tokens and 5500 tokens/s for 100 to 200 tokens, but it drops to 782 tokens/s for 4000 to 8000 tokens and to 273 tokens/s for prompts longer than 8000 tokens.

| prompt length | 0-100 | 100-200 | 200-500 | 500-1000 | 1000-2000 | 2000-4000 | 4000-8000 | 8000+ |
|---|---|---|---|---|---|---|---|---|
| tokens/s | 5364 | 5500 | 4722 | 2815 | 2484 | 1627 | 782 | 273 |

The GPU is a single A800 (80 GB); vLLM settings are block_size=16, max_num_seqs=512, max_model_len=8192, max_tokens=200. Is this because paged attention has to access the KV cache more often?
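For reference, a minimal sketch of the reported setup using vLLM's offline API; the model name and the batch of 90 identical prompts are placeholders for illustration only, since the issue does not state which model was used.

```python
# Minimal sketch of the configuration described above (offline vLLM API).
# The model name is a placeholder; the issue does not say which model was used.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder
    block_size=16,
    max_num_seqs=512,
    max_model_len=8192,
)
sampling = SamplingParams(max_tokens=200)
outputs = llm.generate(["<long prompt here>"] * 90, sampling)  # batch size 90
```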

noooop (Contributor) commented Dec 18, 2024

Most of the time the GPU hits a memory-bandwidth bottleneck rather than a compute bottleneck, so decoding speed depends on memory bandwidth.

As the prompt length increases, the KV cache that must be read at every step can even exceed the size of the model weights. That is why inference speed decreases.

Going further: inference latency increases linearly with context size, primarily due to the time needed to access cached tokens. You can even fit a linear function with x = KV-cache size to read and y = time per decoding step.

See FlashDecoding for more.
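As a rough illustration of that linear relationship, here is a hedged back-of-the-envelope estimate; the weight size and A800 bandwidth below are assumed round numbers, not measurements.

```python
# Hedged estimate: per-step decode latency ~ (weight bytes + KV-cache bytes read) / bandwidth.
# The weight size and bandwidth are illustrative assumptions, not measurements.
def step_latency_ms(weight_bytes: float, kv_bytes: float, bandwidth_bytes_per_s: float) -> float:
    return (weight_bytes + kv_bytes) / bandwidth_bytes_per_s * 1e3

weights = 18e9       # e.g. a 4-bit ~32B model, ~18 GB
bandwidth = 2.0e12   # A800 HBM2e, roughly 2 TB/s
for kv in (1e9, 8e9, 16e9):  # KV cache read per step grows with context length
    print(f"kv={kv/1e9:.0f} GB -> ~{step_latency_ms(weights, kv, bandwidth):.1f} ms/step")
```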

Flynn-Zh commented:

GPU: L40, model: Qwen2.5-32B-GPTQ-Int4.
Same question: with a 9k-word prompt, vLLM takes 12 seconds while SGLang takes 8 seconds.
Is there any configuration that can improve performance?

noooop (Contributor) commented Dec 19, 2024

@Flynn-Zh

By default:

  • vLLM uses flash_attn for decoding
  • SGLang uses FlashInfer for decoding

FlashInfer used to be slightly faster than flash_attn, but I'm not sure that is still the case.

You can try vLLM + FlashInfer and see if that improves performance.

Looking forward to your benchmark

Flynn-Zh commented Dec 19, 2024

@noooop

I can't find the configuration for FlashInfer. How do I use FlashInfer in vLLM?

noooop (Contributor) commented Dec 19, 2024

  1. Install FlashInfer.

     I'm not sure vLLM supports the newly released FlashInfer v0.2.0 (#11314); it is safer to use FlashInfer v0.1.6.

  2. Set the environment variable VLLM_ATTENTION_BACKEND=FLASHINFER to enable FlashInfer.
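For example, a minimal sketch of step 2 in Python, using the model discussed in this thread; setting the variable in the shell before launching the server works just as well.

```python
# Select the FlashInfer attention backend; the variable must be set before vLLM starts.
import os
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM

llm = LLM(model="Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4")  # model discussed in this thread
```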

Flynn-Zh commented:

With FlashInfer it still takes 12 seconds.

noooop (Contributor) commented Dec 19, 2024

> With FlashInfer it still takes 12 seconds.

interesting

jeejeelee (Collaborator) commented:

Maybe it is a similar issue to #11317 (comment).

Flynn-Zh commented:

@jeejeelee I tried increasing max-seq-len-to-capture, but it didn't help.

noooop (Contributor) commented Dec 19, 2024

@Flynn-Zh

vLLM v0 uses the default scheduler; chunked prefill performs better for long inputs.

Please try the configuration below (a sketch follows):

enable_chunked_prefill = True
max_num_seqs = 32
max_num_batched_tokens = 2048  <- 2048 tokens is generally enough to saturate the GPU
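A minimal sketch of that configuration with the offline LLM API, using the model discussed in this thread; the equivalent flags should also be accepted by the server CLI.

```python
# Sketch of the suggested chunked-prefill configuration (offline API).
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4",  # model from this thread
    enable_chunked_prefill=True,
    max_num_seqs=32,
    max_num_batched_tokens=2048,  # ~2048 tokens per batch usually saturates the GPU
)
```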

Flynn-Zh commented:

@noooop I've also tried that; it didn't help either.

jeejeelee (Collaborator) commented Dec 19, 2024

> @jeejeelee I tried increasing max-seq-len-to-capture, but it didn't help.

Could you please provide more details, such as the model, running script, etc.? I can try to reproduce your issue if I have bandwidth this weekend.

Flynn-Zh commented:

@jeejeelee

> GPU: L40, model: Qwen2.5-32B-GPTQ-Int4. Same question: with a 9k-word prompt, vLLM takes 12 seconds while SGLang takes 8 seconds. Is there any configuration that can improve performance?

I only use VS Code with the REST Client plugin to test v1/chat/completions. The prompt is the long content of a document for the LLM to summarize, with a maximum output length of 500.

Because there were some issues with SGLang 0.4.0, I just tried SGLang 0.3.2 again; it takes 6 s.

JaheimLee commented Dec 20, 2024

> @jeejeelee
>
> GPU: L40, model: Qwen2.5-32B-GPTQ-Int4. Same question: with a 9k-word prompt, vLLM takes 12 seconds while SGLang takes 8 seconds. Is there any configuration that can improve performance?
>
> I only use VS Code with the REST Client plugin to test v1/chat/completions. The prompt is the long content of a document for the LLM to summarize, with a maximum output length of 500.
>
> Because there were some issues with SGLang 0.4.0, I just tried SGLang 0.3.2 again; it takes 6 s.

Have you tried the "gptq" kernel? In my case the "gptq" kernel is faster than the "marlin" kernel; I'm not sure whether that is a bug. My GPU is a 3090.

noooop (Contributor) commented Dec 20, 2024

https://github.com/noooop/vllm/blob/f13a07b1f8c11ddbdc53b40f1fbb24bf3166b900/vllm/model_executor/layers/quantization/gptq.py#L242C1-L245C62

  1. The "gptq" kernel uses a plain GEMM; this may help at large batch sizes.

  2. The Neural Magic blog shows Marlin beating GPTQ, AWQ, and FP16 at all batch sizes.

     [image: Neural Magic Marlin benchmark]

  3. Kernel performance may also depend on the device.

  4. I'm actually looking at how to use dequantize + float16 matmul in gptq, like awq does:
     https://github.com/noooop/vllm/blob/f13a07b1f8c11ddbdc53b40f1fbb24bf3166b900/vllm/model_executor/layers/quantization/awq.py#L164C1-L166C48

  5. I tested this myself: FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256 is useful (see the sketch below).
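For context, a hedged illustration of that heuristic; the dequantized-weight and quantized-GEMM callables here are placeholders, not actual vLLM APIs.

```python
# Illustration of the FP16_MATMUL_HEURISTIC_CONDITION idea from awq.py:
# small token batches go through the quantized GEMM kernel, large batches
# dequantize to fp16 and use a regular matmul. The callables are placeholders.
import torch

def quantized_linear(x: torch.Tensor, dequant_weight_fp16, quant_gemm):
    # x has shape (..., hidden); x.shape[:-1].numel() is the number of tokens.
    FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
    if FP16_MATMUL_HEURISTIC_CONDITION:
        return torch.matmul(x, dequant_weight_fp16())  # fp16 path for big batches
    return quant_gemm(x)                               # quantized kernel for small batches
```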

Flynn-Zh commented:

@jeejeelee @noooop I just tried gptq again, and it's basically the same as gptq_marlin.

noooop (Contributor) commented Dec 25, 2024

code

setting

Offline inference

prefills

  • input_len = 8000
  • output_len = 16
  • num_prompts = 11
Time in seconds:

| max_num_batched_tokens | vllm 0.6.4 + gptq_marlin | vllm 0.6.4 + gptq | sglang 0.4.0.post2 |
|---|---|---|---|
| 1024 | 4.15 | 3.67 | 4.16 |
| 512 | 4.19 | 4.58 | 4.27 |
| 256 | 4.34 | 6.58 | 4.26 |
| 128 | 4.53 | 11.52 | 4.50 |
| 64 | 5.53 | 21.95 | 5.21 |
| 32 | 8.59 | 18.12 | 7.72 |

decoding

  • input_len = 8000
  • output_len = 512
  • num_prompts = 11

Time in seconds:

| configuration | decoding |
|---|---|
| vllm 0.6.4 + flash attention | 16.74334423 |
| vllm 0.6.4 + flashinfer | 16.76823786 |
| sglang 0.4.0.post2 | 16.09388748 |

conclusion

  1. For offline inference with chunked prefill, vLLM and SGLang are almost the same speed.
  2. marlin (MarlinLinearKernel) works well at almost all max_num_batched_tokens values.
  3. gptq (ExllamaLinearKernel) probably works well above 1024, but it is not much better.
  4. There is almost no difference in speed between FlashInfer and flash attention.
  5. There is almost no difference in speed between vLLM 0.6.4 and vLLM 0.6.5.

Situations not tested

  1. The vLLM default scheduler (without chunked prefill) OOMs on my 4090.
  2. MacheteLinearKernel requires compute capability 90; the current GPU (4090) has compute capability 89.
  3. Maybe the vLLM and SGLang web servers have different speeds.
  4. Maybe vLLM and SGLang produce different output lengths.
  5. Maybe some kind of cache was hit.

vLLM and SGLang use almost the same MLP and attention implementations, and this code has been optimized for years. At least for offline testing, the speed can't differ that much.

I'm not very familiar with the web server side and need other experts to help.

Flynn-Zh commented Dec 26, 2024

Test results (stably reproducible; each measurement is the first call):

[image: vllm]
[image: sglang]

Server launch commands:

[image: vllm-cmd]
[image: sglang-cmd]

Hardware: a single L40; vLLM 0.6.5 and SGLang 0.4.0.post2 use the same L40.

noooop (Contributor) commented Dec 26, 2024

> Test results (stably reproducible; each measurement is the first call):

vLLM outputs 283 tokens in 117804 ms; SGLang outputs 328 tokens in 6226 ms.

How could that happen?

@Flynn-Zh
Can you run an offline test?

https://github.com/noooop/snippet/blob/main/benchmarks/test_gptq/main.py

Flynn-Zh commented:

@noooop There are some errors when executing main.py (result.txt attached).

Flynn-Zh commented:

@noooop I modified main.py and ran the offline test again; the result is in result.txt.

noooop (Contributor) commented Dec 26, 2024

> @noooop There are some errors when executing main.py (result.txt attached).

I'm very sorry; I passed the unsupported parameter enforce_eager to sgl.Engine but didn't test it.

Summary

@Flynn-Zh

> I modified main.py and ran the offline test again; the result is in result.txt.

Hardware: a single L40

Offline inference

prefills

  • input_len = 8000
  • output_len = 16
  • num_prompts = 11

using chunked prefill (time in seconds):

| batchsize | vllm + gptq_marlin | vllm + gptq | sglang 0.4.0.post2 |
|---|---|---|---|
| 1024 | 2.41 | 3.01 | 2.33 |
| 512 | 2.47 | 3.43 | 2.35 |
| 256 | 2.57 | 4.14 | 2.49 |
| 128 | 2.80 | 6.51 | 2.79 |
| 64 | 3.83 | 11.97 | 3.82 |
| 32 | 7.10 | 13.10 | 7.21 |

vllm default scheduler (time in seconds):

| method | prefill |
|---|---|
| gptq_marlin | 2.33 |
| gptq | 2.35 |
| gptq_marlin + enforce_eager | 2.49 |
| gptq + enforce_eager | 2.79 |

decoding

  • input_len = 8000
  • output_len = 512
  • num_prompts = 11

Time in seconds:

| configuration | decoding |
|---|---|
| vllm chunked prefill = 1024, flash attention | 15.50 |
| vllm chunked prefill = 1024, flashinfer | 15.86 |
| vllm default scheduler + gptq_marlin | 16.24 |
| vllm default scheduler + gptq | 15.88 |
| vllm default scheduler + gptq_marlin + enforce_eager | 16.07 |
| vllm default scheduler + gptq + enforce_eager | 16.07 |
| sglang 0.4.0.post2 | 10.41 |

conclusion

  1. For prefill, SGLang is similar to vLLM.

  2. For decoding, SGLang takes 10.41 s vs 15~16 s for vLLM under all configurations. Really faster.

  3. For vLLM:

     L40: 864 GB/s
     4090: 1008 GB/s

     So the 4090 prefill is slower than the L40 while decoding is almost the same; very reasonable.

  4. I don't know why, but in the decoding stage SGLang is indeed faster than vLLM.

noooop (Contributor) commented Dec 26, 2024

@jeejeelee
Come and take a look:

> 4. I don't know why, but in the decoding stage SGLang is indeed faster than vLLM.

noooop (Contributor) commented Dec 26, 2024

@Flynn-Zh

Referring to "LLM inference speed of light":

Qwen2.5-32B-GPTQ-Int4 weighs about 18 GB.

18 GB / 864 GB/s ≈ 20 ms per decoding step

prefill 2.33 s + decoding 20 ms × (512 - 16) ≈ 12.3 s

This does not take the KV cache into account.

  • vLLM at 15~16 s is very reasonable.
  • SGLang at 10.41 s can't happen.
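A small script to reproduce that back-of-the-envelope estimate (same inputs as stated above; KV-cache traffic is deliberately ignored):

```python
# Speed-of-light estimate for L40 decoding, using the numbers quoted above.
weights_gb = 18            # Qwen2.5-32B-GPTQ-Int4 weight size
bandwidth_gb_s = 864       # L40 memory bandwidth
step_s = weights_gb / bandwidth_gb_s          # ~0.021 s to stream the weights once
prefill_s = 2.33
decode_steps = 512 - 16
total_s = prefill_s + step_s * decode_steps   # KV-cache reads not included
print(f"{step_s*1e3:.1f} ms/step, total ~{total_s:.1f} s")  # ~20.8 ms/step, ~12.7 s
```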

Flynn-Zh commented:

> Referring to "LLM inference speed of light":
>
> Qwen2.5-32B-GPTQ-Int4 weighs about 18 GB.
>
> 18 GB / 864 GB/s ≈ 20 ms per decoding step
>
> prefill 2.33 s + decoding 20 ms × (512 - 16) ≈ 12.3 s
>
> This does not take the KV cache into account.
>
>   • vLLM at 15~16 s is very reasonable.
>   • SGLang at 10.41 s can't happen.

What black technology does SGLang have?

noooop (Contributor) commented Dec 26, 2024

> What black technology does SGLang have?

I've been thinking the same thing. What black technology does SGLang have?

  • Speculative sampling?
  • A sliding window?
  • The prefix cache can only help in the prefill stage, so that's out.

jeejeelee (Collaborator) commented:

> @jeejeelee Come and take a look:
>
> 4. I don't know why, but in the decoding stage SGLang is indeed faster than vLLM.

We can run a profiler to investigate it.

Flynn-Zh commented Dec 27, 2024

> Referring to "LLM inference speed of light"

@noooop My understanding is that the calculation in that article is suited to MHA models, but Qwen2.5 is a GQA model. Is my understanding correct?

noooop (Contributor) commented Dec 27, 2024

GQA only affects the KV-cache part of the latency estimate, and we are not considering the KV cache here at all.
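For intuition, a hedged sketch of why GQA only changes the KV-cache term: per-token KV bytes scale with the number of KV heads, not query heads. The layer/head numbers below are illustrative assumptions for a Qwen2.5-32B-like configuration, not values taken from the checkpoint.

```python
# Per-sequence KV-cache size under GQA; all config numbers are assumptions.
def kv_cache_bytes(seq_len: int, layers: int = 64, kv_heads: int = 8,
                   head_dim: int = 128, dtype_bytes: int = 2) -> int:
    return 2 * layers * kv_heads * head_dim * dtype_bytes * seq_len  # K and V

print(kv_cache_bytes(8000) / 1e9, "GB")  # ~2.1 GB for one 8k-token sequence
```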

noooop (Contributor) commented Dec 27, 2024

Let's try an NVIDIA Nsight profile.

4090 profile.zip

input_len = 8000
output_len = 16
num_prompts = 1
chunked prefill size = 1024

for vllm

overall

[image: vllm-1]

prefill × 8 and decoding × 16; straightforward.

Enlarging the decoding part:

[image: vllm-2]

One decoding step takes 25 ms, very reasonable.

Summary of vLLM kernels:

  • prefill Linear: void marlin::Marlin..... ~445.982 μs [1]
  • prefill attention: void flash_fwd_splitkv_kernel..... 114.336 μs ~ 512.510 μs [3]
  • decoding Linear: void marlin::Marlin..... ~164.544 μs [2]
  • decoding attention: void flash_fwd_splitkv_kernel..... ~43.168 μs [4]

for sglang

overall

[image: sgl-1]

One can only roughly see prefill × 8 and decoding × 16.

Enlarging the decoding part:

[image: sgl-2]

One decoding step takes 24 ms, very reasonable.

Summary of SGLang kernels:

Could not find CUDA kernel information, so there is no way to compare it with vLLM.

How to use the NVIDIA Nsight profiler

  1. Install

     https://developer.nvidia.com/nsight-systems/get-started

     Download for Linux on x86_64, Nsight Systems 2024.7.1 Full Version, .run installer.

     apt install nsight-systems didn't work for me.

  2. Profile

nsys profile -w true -o vllm -f true -x true python test_vllm.py
nsys profile -w true -o sgl -f true -x true python test_sgl.py

code

@Flynn-Zh

Looking forward to your profile

noooop (Contributor) commented Dec 27, 2024

@jeejeelee

I'm not familiar with SGLang. Is there a better way to profile SGLang?

Or do I need to add parameters to nsys profile?

Bryce1010 (Contributor) commented:

Keeping an eye on this; I'm curious why it's happening. Does it only occur with Qwen2.5-32B-Instruct-GPTQ-Int4, or does it affect other models too?

noooop (Contributor) commented Dec 27, 2024

> Keeping an eye on this; I'm curious why it's happening. Does it only occur with Qwen2.5-32B-Instruct-GPTQ-Int4, or does it affect other models too?

Yes, it's very strange.

I think the 4090 results are clearly reasonable, while the L40 results are very unreasonable.

I'm trying to determine how it was triggered.

SGLang's 10.41 s feels as if it is not reading any KV cache at all in the decoding stage.

@Flynn-Zh
Please help me measure how long one L40 SGLang decoding step takes.

jeejeelee (Collaborator) commented:

> I'm not familiar with SGLang. Is there a better way to profile SGLang?
>
> Or do I need to add parameters to nsys profile?

I usually use torch.profiler.
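For reference, a minimal torch.profiler sketch; the model name, prompt, and sampling parameters are placeholders for illustration.

```python
# Hedged sketch: profile one generation call with torch.profiler and print the
# top CUDA kernels. The model name and prompt are placeholders.
from torch.profiler import profile, ProfilerActivity
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4")  # model from this thread
params = SamplingParams(max_tokens=64)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             record_shapes=True) as prof:
    llm.generate(["<long prompt here>"], params)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```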

Flynn-Zh commented:

[image]

My driver is 535.154.

Flynn-Zh commented:

> SGLang's 10.41 s feels as if it is not reading any KV cache at all in the decoding stage.
>
> @Flynn-Zh Please help me measure how long one L40 SGLang decoding step takes.

@noooop sgl.zip attached.

noooop (Contributor) commented Dec 28, 2024

> sgl.zip

for sglang

overall

[image: L40-sgl-1]

  • last prefill: 218 ms

[image: L40-sgl-2]

  • decoding: 15.6 ms. Awesome!

for vllm

overall

[image: L40-vllm-1]

  • last prefill: 317 ms?

[image: L40-vllm-2]

  • decoding: 30 ms?

vLLM kernels:

  • prefill Linear: exllama??? Why not Marlin? [1]
  • prefill attention: void flash_fwd_splitkv_kernel..... OK [3]
  • decoding Linear: gptq:gmm??? Why not Marlin? [2]
  • decoding attention: void flash_fwd_splitkv_kernel..... OK [4]

conclusion

So I think the L40 is slower because vLLM does not use Marlin. But why?

  • Does the L40 not support Marlin?
  • A wrong configuration?
  • A bug?
  • Possibly a driver-version problem makes vLLM fall back from Marlin to exllama; but then why does SGLang work?

noooop (Contributor) commented Dec 28, 2024

[image]

Looking at the earlier log, MarlinLinearKernel appears to be supported.

What did we miss?

noooop (Contributor) commented Dec 28, 2024

Is it caused by setting the quantization parameter?

  • quantization = "gptq_marlin"

I thought it was the same as quantization = None, but maybe it's not.

https://github.com/noooop/snippet/blob/d3f69b532b18639b791218e74b7cfe9100816726/benchmarks/test_gptq/test_vllm.py#L117C1-L117C38

Forcing MarlinLinearKernel seems really fast!

args.environs = {
    "VLLM_DISABLED_KERNELS":
        "GPTQMarlinLinearMethod,MacheteLinearKernel"
}

| batchsize | vllm + gptq_marlin | vllm + gptq | sglang 0.4.0.post2 |
|---|---|---|---|
| 1024 | 2.41 | 3.01 | 2.33 |

But why does it also use the Marlin kernel when I set quantization = "gptq_marlin"?

@Flynn-Zh

Please try:

  1. setting args.quantization = "gptq_marlin", None, and "gptq"
  2. forcing MarlinLinearKernel

noooop (Contributor) commented Dec 28, 2024

For the 4090 (vllm.zip):

[image: vllm-gptq_marlin]

  • quantization = gptq_marlin uses Marlin

[image: vllm-None]

  • quantization = None uses Marlin

[image: vllm-gptq]

  • quantization = gptq uses exllama <- this one is a little slower

Flynn-Zh commented:

> • quantization = "gptq_marlin"

@noooop This is the situation I tested yesterday:

[image]

But there were no errors today; I just reinstalled a lower version of Nsight Systems.

[image]

vllm.zip

noooop (Contributor) commented Dec 28, 2024

[image: L40-vllm-3]

  • last prefill: 317 ms vs 254 ms, yes.
  • Compared with SGLang's 218 ms, the difference is not that big.

[image: L40-vllm-4]

  • decoding: 32 ms??? WTF
  • Even slower than the Linear gptq:gmm path at 30 ms.
  • This is probably why Marlin is not used by default.

Summary of vLLM kernels:

  • prefill Linear: void marlin::Marlin..... yes [1]
  • prefill attention: void flash_fwd_splitkv_kernel..... yes [3]
  • decoding Linear: void marlin::Marlin..... yes [2]
  • decoding attention: void flash_fwd_splitkv_kernel..... yes [4]

conclusion

  • vLLM's Marlin is slower on the L40
  • especially during the decoding stage

There may be various reasons, but then why is there no problem with SGLang?

I suggest you open a new issue named "Marlin slower on L40".

Let's double check.

Wait a minute: there may also be a problem with SGLang's Marlin implementation.

Specifications

GPU memory bandwidth | 864 GB/s

Referring to "LLM inference speed of light":

Qwen2.5-32B-GPTQ-Int4 weighs about 18 GB.

18 GB / 864 GB/s ≈ 20 ms

  • vLLM Marlin decoding: 32 ms
  • vLLM Linear gptq:gmm: 30 ms
  • SGLang decoding: 15.6 ms

  • It may be that vLLM's Marlin is relatively slow.
  • It may also be that there is a problem with SGLang's Marlin implementation, since 15.6 ms is below the ~20 ms needed just to read the weights.

noooop (Contributor) commented Dec 31, 2024

Let's wait for further testing after #11493 is merged.
