Support double sparsity #1459

Merged (6 commits into sgl-project:main, Oct 14, 2024)
Conversation

andy-yang-1 (Contributor) commented Sep 18, 2024

Motivation

  • Support double sparsity (post-training sparse attention) for long context inference in SGLang
  • See paper

Modifications

  • Add a Triton implementation in sglang/python/sglang/srt/layers/sparse_decode_attention.py (a rough sketch of the decode idea appears right after this list)
  • Add the serving-related parts (server arguments, memory pool, and attention-backend integration)
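
For intuition, here is a minimal PyTorch sketch of the decode-time idea: approximate the attention scores using only a few heavy (outlier) channels, then run exact attention over just the top heavy tokens. This is an illustration only, not the PR's Triton kernel; the function name, tensor layouts, and single-sequence setup are assumptions, and details such as batching, GQA, and the paged KV-cache layout are omitted.

import torch

def double_sparsity_decode(q, k_cache, v_cache, heavy_channels, heavy_token_num):
    """q: [num_heads, head_dim]; k_cache/v_cache: [seq_len, num_heads, head_dim];
    heavy_channels: int64 [num_heads, heavy_channel_num] indices of outlier channels."""
    num_heads, head_dim = q.shape
    seq_len = k_cache.shape[0]

    # 1) Approximate attention scores using only the heavy channels.
    q_label = torch.gather(q, 1, heavy_channels)                          # [H, C]
    k_label = torch.gather(
        k_cache, 2, heavy_channels.expand(seq_len, -1, -1)
    )                                                                      # [S, H, C]
    approx_scores = torch.einsum("hc,shc->hs", q_label, k_label)          # [H, S]

    # 2) Keep only the top heavy_token_num tokens per head.
    topk = min(heavy_token_num, seq_len)
    idx = approx_scores.topk(topk, dim=-1).indices                        # [H, topk]

    # 3) Exact attention over the selected tokens only.
    head_ids = torch.arange(num_heads)[:, None]
    k_sel = k_cache.permute(1, 0, 2)[head_ids, idx]                       # [H, topk, D]
    v_sel = v_cache.permute(1, 0, 2)[head_ids, idx]                       # [H, topk, D]
    scores = torch.einsum("hd,htd->ht", q, k_sel) / head_dim**0.5
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("ht,htd->hd", probs, v_sel)                       # [H, D]

In this sketch, heavy_channels.shape[-1] and heavy_token_num play the roles of the --ds-heavy-channel-num and --ds-heavy-token-num flags used in the commands below.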

Speedup Evaluation

Run double sparsity with:

python -m sglang.bench_latency --model-path lmsys/longchat-7b-v1.5-32k \
    --attention-backend triton --disable-cuda-graph \
    --ds-channel-config-path /path/to/lmsys/longchat-7b-v1.5-32k.json \
    --input-len 20000 --output-len 200 \
    --batch-size 3 \
    --enable-double-sparsity \
    --ds-heavy-channel-num 16 \
    --ds-heavy-token-num 1024 \
    --ds-sparse-decode-threshold 0 \
    --max-total-tokens 70000

Benchmark ...
Prefill. latency: 7.83636 s, throughput:   7656.62 token/s
Decode.  latency: 0.02351 s, throughput:    127.58 token/s
Decode.  latency: 0.02124 s, throughput:    141.22 token/s
Decode.  latency: 0.02037 s, throughput:    147.26 token/s
Decode.  latency: 0.01950 s, throughput:    153.81 token/s
Decode.  latency: 0.01935 s, throughput:    155.04 token/s
Decode.  median latency: 0.01923 s, median throughput:    156.04 token/s
Total. latency: 11.821 s, throughput:   5126.36 token/s

Original Triton implementation:

python -m sglang.bench_latency --model-path lmsys/longchat-7b-v1.5-32k \
    --attention-backend triton \
    --input-len 20000 --output-len 200 \
    --batch-size 3

Benchmark ...
Prefill. latency: 7.79627 s, throughput:   7695.98 token/s
Decode.  latency: 0.07196 s, throughput:     41.69 token/s
Decode.  latency: 0.06514 s, throughput:     46.05 token/s
Decode.  latency: 0.06475 s, throughput:     46.33 token/s
Decode.  latency: 0.06463 s, throughput:     46.41 token/s
Decode.  latency: 0.06457 s, throughput:     46.46 token/s
Decode.  median latency: 0.06487 s, median throughput:     46.25 token/s
Total. latency: 20.720 s, throughput:   2924.74 token/s

Original FlashInfer implementation:

python -m sglang.bench_latency --model-path lmsys/longchat-7b-v1.5-32k \
    --attention-backend flashinfer \
    --input-len 20000 --output-len 200 \
    --batch-size 3

Benchmark ...
Prefill. latency: 5.68892 s, throughput:  10546.83 token/s
Decode.  latency: 0.03240 s, throughput:     92.60 token/s
Decode.  latency: 0.02993 s, throughput:    100.23 token/s
Decode.  latency: 0.02970 s, throughput:    101.01 token/s
Decode.  latency: 0.02959 s, throughput:    101.39 token/s
Decode.  latency: 0.02959 s, throughput:    101.38 token/s
Decode.  median latency: 0.02961 s, median throughput:    101.32 token/s
Total. latency: 11.585 s, throughput:   5231.00 token/s

With Llama-3.1-8B:

# Double Sparsity
python -m sglang.bench_latency --model-path meta-llama/Llama-3.1-8B-Instruct \
    --attention-backend triton \
    --ds-channel-config-path /path/to/meta-llama/Llama-3.1-8B-Instruct.json \
    --input-len 60000 --output-len 200 \
    --batch-size 3 \
    --enable-double-sparsity \
    --ds-heavy-channel-num 32 \
    --ds-heavy-channel-type k \
    --ds-heavy-token-num 3000 \
    --ds-sparse-decode-threshold 0 \
    --max-total-tokens 200000

Benchmark ...
Prefill. latency: 42.96801 s, throughput:   4189.16 token/s
Decode.  latency: 0.02843 s, throughput:    105.50 token/s
Decode.  latency: 0.02518 s, throughput:    119.16 token/s
Decode.  latency: 0.02465 s, throughput:    121.72 token/s
Decode.  latency: 0.02442 s, throughput:    122.84 token/s
Decode.  latency: 0.02434 s, throughput:    123.24 token/s
Decode.  median latency: 0.02421 s, median throughput:    123.90 token/s
Total. latency: 47.793 s, throughput:   3778.77 token/s

# Triton
python -m sglang.bench_latency --model-path meta-llama/Llama-3.1-8B-Instruct \
    --attention-backend triton \
    --input-len 60000 --output-len 200 \
    --batch-size 3 \
    --max-total-tokens 200000

Benchmark ...
Prefill. latency: 43.17160 s, throughput:   4169.41 token/s
Decode.  latency: 0.06359 s, throughput:     47.18 token/s
Decode.  latency: 0.05965 s, throughput:     50.30 token/s
Decode.  latency: 0.05927 s, throughput:     50.62 token/s
Decode.  latency: 0.05906 s, throughput:     50.80 token/s
Decode.  latency: 0.05906 s, throughput:     50.80 token/s
Decode.  median latency: 0.05913 s, median throughput:     50.73 token/s
Total. latency: 54.950 s, throughput:   3286.63 token/s

# Flashinfer
python -m sglang.bench_latency --model-path meta-llama/Llama-3.1-8B-Instruct \
    --attention-backend flashinfer \
    --input-len 60000 --output-len 200 \
    --batch-size 3 \
    --max-total-tokens 200000

Benchmark ...
Prefill. latency: 27.50800 s, throughput:   6543.55 token/s
Decode.  latency: 0.03014 s, throughput:     99.54 token/s
Decode.  latency: 0.02834 s, throughput:    105.86 token/s
Decode.  latency: 0.02821 s, throughput:    106.36 token/s
Decode.  latency: 0.02819 s, throughput:    106.41 token/s
Decode.  latency: 0.02823 s, throughput:    106.28 token/s
Decode.  median latency: 0.02821 s, median throughput:    106.34 token/s
Total. latency: 33.125 s, throughput:   5452.12 token/s

Checklist

  • Format your code according to the Contributor Guide.
  • Add unit tests as outlined in the Contributor Guide.
  • Update documentation as needed, including docstrings or example tutorials.

merrymercy (Contributor) commented Sep 19, 2024

Great work. Some tips for rebasing:

Review threads (outdated, resolved) on: python/sglang/srt/layers/radix_attention.py, python/sglang/srt/layers/test_ds_kernel.py, python/sglang/srt/mem_cache/memory_pool.py
Ying1123 mentioned this pull request on Sep 22, 2024
merrymercy mentioned this pull request on Sep 22, 2024
ghost commented Sep 24, 2024

Quick question @andy-yang-1 - Does this PR support just Double Sparsity or DS-Offload as well?

andy-yang-1 (Contributor, Author) commented:

@vnkc1 Hi, this PR doesn't support DS-Offload for now. DS-Offload may be integrated in another PR if needed.

fengyang95 commented:
Is there a plan to merge this PR?

merrymercy (Contributor) commented Oct 11, 2024

Yes. It should be merged within one week.
@andy-yang-1 please:

  1. Resolve the conflicts.
  2. Add an end-to-end accuracy unit test.

merrymercy (Contributor) commented:

Please fix the lint error and add an end-to-end accuracy test

Review threads on: python/sglang/srt/model_executor/forward_batch_info.py (outdated, resolved), python/sglang/test/Llama-3.1-8B-Instruct.jsonconfig (outdated, resolved), test/srt/test_double_sparsity.py (resolved)
merrymercy changed the title from "[WIP] Support double sparsity" to "Support double sparsity" on Oct 14, 2024
merrymercy (Contributor) commented Oct 14, 2024

Give two example commands and paste their results in the description of this PR. This is for tracking progress. It should be something like this:

# baseline
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3.1-8B-Instruct --batch-size 1 --input 1024 --output 8

# double sparsity
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3.1-8B-Instruct --batch-size 1 --input 1024 --output 8 --enable-double-sparsity ...

merrymercy (Contributor) commented:

@andy-yang-1 Can you also paste the latency results?

Review threads (outdated, resolved) on: test/srt/test_double_sparsity.py, test/srt/run_suite.py
merrymercy enabled auto-merge (squash) on October 14, 2024 08:32
merrymercy disabled auto-merge on October 14, 2024 09:00
merrymercy merged commit 061e546 into sgl-project:main on Oct 14, 2024 (10 of 11 checks passed)
merrymercy (Contributor) commented:

@andy-yang-1 Thanks for the contribution. It is merged.

max99x (Contributor) commented Oct 14, 2024

How does one generate the ds-channel-config to be able to use this?

fengyang95 commented:
I noticed that CUDA graph is not currently supported. Are there any plans to support it? @andy-yang-1

andy-yang-1 (Contributor, Author) commented:

@max99x You can use this link to generate the channel config file.

@fengyang95 We may support it in the next PR.
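
As a quick sanity check once a config has been generated, it can help to load it and look at the keys before passing it via --ds-channel-config-path. A minimal sketch; the file path is hypothetical, and the key naming follows the discussion later in this thread:

import json

# Hypothetical path to a channel config produced by the DoubleSparse calibration scripts.
config_path = "/path/to/meta-llama/Llama-3.1-8B-Instruct.json"

with open(config_path) as f:
    channel_config = json.load(f)

# Expected keys look like "model.layers.0.self_attn.q_proj", "...k_proj", "...qk_proj".
print(len(channel_config), "entries")
print(sorted(channel_config.keys())[:3])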

fengyang95 commented Oct 18, 2024

Hi @andy-yang-1, does this support the deepseek-v2 architecture? How can I obtain the config for this architecture? I see that the example at https://github.com/andy-yang-1/DoubleSparse/blob/main/evaluation/group_channel_config.py only supports the llama/mixtral architectures.

fengyang95 commented Oct 19, 2024

@andy-yang-1 I tried running the deepseek-v2 model, but encountered the following issue:

File "/opt/tiger/custome_sglang/python/sglang/srt/layers/attention/double_sparsity_backend.py", line 162, in forward_extend
    k_label = torch.gather(
              ^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:0! (when checking argument for argument index in method wrapper_CUDA_gather)
  File "/opt/tiger/custome_sglang/python/sglang/srt/layers/attention/__init__.py", line 49, in forward
    return self.forward_extend(q, k, v, layer, forward_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/tiger/custome_sglang/python/sglang/srt/layers/attention/double_sparsity_backend.py", line 162, in forward_extend
    k_label = torch.gather(
              ^^^^^^^^^^^^^
RuntimeError: Size does not match at dimension 1 expected index [7, 128, 16] to be smaller than self [7, 1, 576] apart from dimension 2

andy-yang-1 (Contributor, Author) commented:

@fengyang95 I haven't added support for the deepseek-v2 model. I may add support for it later.

fengyang95 commented:
> @fengyang95 I haven't added support for deepseek-v2 model. I may add support for this later

@andy-yang-1 Thank you very much! Looking forward to support for deepseek-v2 and CUDA graphs.

shreyansh26 commented:
@andy-yang-1 - Loved the paper! I was trying this out and I am facing a few issues generating the config file using the mentioned script.

  1. The line cos, sin = m.rotary_emb(v, seq_len=kv_seq_len) in stat_qk_max_hook of get_calib_qk_feat gives an error:
TypeError: LlamaRotaryEmbedding got an unexpected keyword argument 'seq_len'

I replaced it with cos, sin = m.rotary_emb(v, position_ids=position_ids), which works. I'm not sure whether that is correct, but LlamaRotaryEmbedding indeed doesn't have the seq_len parameter.

  2. In the config file that gets generated, I only get keys of the form model.layers.{layer_num}.self_attn, but the config file present in the test folder has keys of the form model.layers.{layer_num}.self_attn.q_proj, model.layers.{layer_num}.self_attn.k_proj, and model.layers.{layer_num}.self_attn.qk_proj. How were these generated?
    On using my generated config with sglang, I get an error of the type: Key model.layers.0.self_attn.k_proj was not found.

Any help on how to run this would be appreciated.

andy-yang-1 (Contributor, Author) commented:

@shreyansh26 The first problem is caused by the transformers version (the calibration code targets the older rotary-embedding API), and I will update the base repo to fix it this week.
The q_outlier_config/k_outlier_config is generated with the get_calib_feat function, and the qk_outlier_config is generated with the get_qk_calib_feat function. You can merge these two configs together to get all the keys. I will also update it this week.
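
A minimal sketch of that merge step, assuming each calibration function dumps a flat JSON dict; the file names below are placeholders, not the repo's actual outputs:

import json

# Placeholder file names for the two calibration outputs.
with open("qk_separate_outlier_config.json") as f:   # from get_calib_feat (q_proj / k_proj keys)
    separate_cfg = json.load(f)
with open("qk_joint_outlier_config.json") as f:      # from get_qk_calib_feat (qk_proj keys)
    joint_cfg = json.load(f)

# The key sets should be disjoint, so a plain dict merge suffices.
merged = {**separate_cfg, **joint_cfg}

with open("channel_config.json", "w") as f:
    json.dump(merged, f)   # pass this file via --ds-channel-config-path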

shreyansh26 commented Nov 7, 2024

Thank you.
There may be another discrepancy: in get_calib_feat, the following condition filters out k_proj because of GQA.

if y.shape[-1] != model.config.hidden_size:
    return

But in the Llama-3.1-8B-Instruct config file, k_proj keys are also present.
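
For context on why that check drops k_proj under GQA: with grouped-query attention, k_proj and v_proj project to num_key_value_heads * head_dim rather than hidden_size (for Llama-3.1-8B that is 8 * 128 = 1024 vs. 4096), so the y.shape[-1] != hidden_size test skips them. A GQA-aware check might look like the sketch below; this is illustrative, not the repo's fix:

def keep_projection_output(y, model):
    # Accept full-width projections (q_proj/o_proj) as well as the narrower
    # GQA projections (k_proj/v_proj), whose width is num_kv_heads * head_dim.
    cfg = model.config
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    kv_width = getattr(cfg, "num_key_value_heads", cfg.num_attention_heads) * head_dim
    return y.shape[-1] in (cfg.hidden_size, kv_width)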

andy-yang-1 (Contributor, Author) commented:

@shreyansh26 Hi, I have updated the main repo. Can you try it with the new code?

shreyansh26 commented:
Thank you @andy-yang-1!! This is working perfectly now.

yuguo-Jack commented:
> @vnkc1 Hi, this PR doesn't support DS-Offload for now. DS-Offload may be integrated in other PR if needed.

Is there a plan to support DS-Offload in SGLang?

hcyz33 commented Jan 13, 2025

(The comment quotes the full PR description and benchmark results from above.)

I found that prefill throughput is lower when DS attention is enabled (from 6543.55 to 4189.16 token/s). The likely reason is that the Triton attention backend is used. Is it possible to use FlashInfer attention for the prefill to increase prefill throughput?
