Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move apply_torchao_config_ to model_runner #2342

Merged
merged 6 commits into from
Dec 5, 2024

Conversation

jerryzh168
Copy link
Contributor

@jerryzh168 jerryzh168 commented Dec 4, 2024

Summary:
Previously we need to apply_torchao_config_ to each model manually, this PR changes it to run on the entire model, we can also add autoquant in the future

Test Plan:

llama:

python3 -m sglang.bench_one_batch --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --json-model-override-arg s '{"architectures": ["TorchNativeLlamaForCausalLM"]}' --enable-torch-compile

Benchmark ...
Prefill. latency: 0.03361 s, throughput: 3808.05 token/s
Decode. latency: 0.01227 s, throughput: 81.50 token/s
Decode. latency: 0.01195 s, throughput: 83.70 token/s
Decode. latency: 0.01181 s, throughput: 84.65 token/s
Decode. latency: 0.01176 s, throughput: 85.05 token/s
Decode. latency: 0.01133 s, throughput: 88.25 token/s
Decode. median latency: 0.01176 s, median throughput: 85.05 token/s
Total. latency: 0.115 s, throughput: 1179.56 token/s

python3 -m sglang.bench_one_batch --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --json-model-override-arg s '{"architectures": ["TorchNativeLlamaForCausalLM"]}' --enable-torch-compile —torchao-config int4wo-128 Benchmark ...
Prefill. latency: 0.11769 s, throughput: 1087.60 token/s
Decode. latency: 0.00687 s, throughput: 145.47 token/s
Decode. latency: 0.00648 s, throughput: 154.25 token/s
Decode. latency: 0.00641 s, throughput: 156.01 token/s
Decode. latency: 0.00635 s, throughput: 157.53 token/s
Decode. latency: 0.00634 s, throughput: 157.74 token/s
Decode. median latency: 0.00644 s, median throughput: 155.28 token/s
Total. latency: 0.163 s, throughput: 834.21 token/s

qwen:

python3 -m sglang.bench_one_batch --model Qwen/Qwen1.5-MoE-A2.7B --batch-size 1 --input 128 --output 8 --enable-torch-compile --torchao-config int4wo-128 
original:
Benchmark ...
Prefill. latency: 0.06101 s, throughput: 2097.86 token/s
Decode. latency: 0.00532 s, throughput: 187.93 token/s
Decode. latency: 0.00524 s, throughput: 190.88 token/s
Decode. latency: 0.00520 s, throughput: 192.43 token/s
Decode. latency: 0.00513 s, throughput: 194.97 token/s
Decode. latency: 0.00507 s, throughput: 197.26 token/s
Decode. median latency: 0.00513 s, median throughput: 194.97 token/s
Total. latency: 0.097 s, throughput: 1400.16 token/s

after change:

Benchmark ...
Prefill. latency: 0.05830 s, throughput: 2195.38 token/s
Decode. latency: 0.00517 s, throughput: 193.50 token/s
Decode. latency: 0.00508 s, throughput: 196.71 token/s
Decode. latency: 0.00512 s, throughput: 195.36 token/s
Decode. latency: 0.00508 s, throughput: 196.97 token/s
Decode. latency: 0.00504 s, throughput: 198.44 token/s
Decode. median latency: 0.00508 s, median throughput: 196.97 token/s
Total. latency: 0.094 s, throughput: 1449.19 token/s
Reviewers:

Subscribers:

Tasks:

Tags:

Summary:
Previously we need to apply_torchao_config_ to each model manually, this PR
changes it to run on the entire model, we can also add autoquant in the future

Test Plan:

llama:

python3 -m sglang.bench_one_batch --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --json-model-override-arg
s '{"architectures": ["TorchNativeLlamaForCausalLM"]}' --enable-torch-compile

Benchmark ...
Prefill. latency: 0.03361 s, throughput:   3808.05 token/s
Decode.  latency: 0.01227 s, throughput:     81.50 token/s
Decode.  latency: 0.01195 s, throughput:     83.70 token/s
Decode.  latency: 0.01181 s, throughput:     84.65 token/s
Decode.  latency: 0.01176 s, throughput:     85.05 token/s
Decode.  latency: 0.01133 s, throughput:     88.25 token/s
Decode.  median latency: 0.01176 s, median throughput:     85.05 token/s
Total. latency:  0.115 s, throughput:   1179.56 token/s

python3 -m sglang.bench_one_batch --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --json-model-override-arg
s '{"architectures": ["TorchNativeLlamaForCausalLM"]}' --enable-torch-compile —torchao-config int4wo-128
Benchmark ...
Prefill. latency: 0.11769 s, throughput:   1087.60 token/s
Decode.  latency: 0.00687 s, throughput:    145.47 token/s
Decode.  latency: 0.00648 s, throughput:    154.25 token/s
Decode.  latency: 0.00641 s, throughput:    156.01 token/s
Decode.  latency: 0.00635 s, throughput:    157.53 token/s
Decode.  latency: 0.00634 s, throughput:    157.74 token/s
Decode.  median latency: 0.00644 s, median throughput:    155.28 token/s
Total. latency:  0.163 s, throughput:    834.21 token/s

qwen:

python3 -m sglang.bench_one_batch --model Qwen/Qwen1.5-MoE-A2.7B --batch-size 1 --input 128 --output 8 --enable-torch-compile --torchao-config int4wo-128

original:
Benchmark ...
Prefill. latency: 0.06101 s, throughput:   2097.86 token/s
Decode.  latency: 0.00532 s, throughput:    187.93 token/s
Decode.  latency: 0.00524 s, throughput:    190.88 token/s
Decode.  latency: 0.00520 s, throughput:    192.43 token/s
Decode.  latency: 0.00513 s, throughput:    194.97 token/s
Decode.  latency: 0.00507 s, throughput:    197.26 token/s
Decode.  median latency: 0.00513 s, median throughput:    194.97 token/s
Total. latency:  0.097 s, throughput:   1400.16 token/s

after change:

Benchmark ...
Prefill. latency: 0.05830 s, throughput:   2195.38 token/s
Decode.  latency: 0.00517 s, throughput:    193.50 token/s
Decode.  latency: 0.00508 s, throughput:    196.71 token/s
Decode.  latency: 0.00512 s, throughput:    195.36 token/s
Decode.  latency: 0.00508 s, throughput:    196.97 token/s
Decode.  latency: 0.00504 s, throughput:    198.44 token/s
Decode.  median latency: 0.00508 s, median throughput:    196.97 token/s
Total. latency:  0.094 s, throughput:   1449.19 token/s
Reviewers:

Subscribers:

Tasks:

Tags:
@merrymercy
Copy link
Contributor

Could you fix the CI error?

@merrymercy merrymercy merged commit 9cc733b into sgl-project:main Dec 5, 2024
14 of 15 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants