Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support SGLang as Potential Backend for Evaluation #2703

Open
wants to merge 14 commits into
base: main
Choose a base branch
from

Conversation

Monstertail
Copy link

@Monstertail Monstertail commented Feb 15, 2025

  1. We add the sglang as a potential backend for evaluation, which can preserve the accuracy but faster than vLLM in many workloads. See here.
  2. We support the data parallel and tensor parallel in 1 as well, thanks to simple yet great SGLang engine.
  3. We add the tests for sglang here.
  4. We initialize the dependencies of SGlang here, but we need future support as the dependencies of SGLang also includes Flashinfer and a simple pyproject.toml cannot include. See here. A simple solution is to install the SGLang in advance.

As some arguments of SGlang engine are different from other backends, we are thinking about to provide a simple doc to tell the users how to run in the future. Here is a simple example:

 CUDA_VISIBLE_DEVICES=2,3,4,5 lm_eval --model sglang \
    --model_args pretrained=Qwen/Qwen2-1.5B-Instruct,dtype=auto,tp_size=2,dp_size=2\
    --tasks gsm8k_cot \
    --device "cuda" \
	   --apply_chat_template \
	   --fewshot_as_multiturn \
	   --num_fewshot 8 \
	   --gen_kwargs temperature=0 \
    --batch_size auto \
    --seed 123

Please tell us what we need to provide or modify! Thanks to our SGlang team manager Chenyang @zhaochenyang20 , and co-contributor Xiaotong @XiaotongJiang

@CLAassistant
Copy link

CLAassistant commented Feb 15, 2025

CLA assistant check
All committers have signed the CLA.

@Monstertail
Copy link
Author

Monstertail commented Feb 15, 2025

[Speed compared with vLLM] A6000-48GB as our test bed.

[vLLM+disable cuda graph] speed=20.90it/s .

CUDA_VISIBLE_DEVICES=0 lm_eval --model vllm     --model_args pretrained=Qwen/Qwen2-1.5B-Instruct,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1,enforce_eager=True,max_model_len=16384,enable_chunked_prefill=False     --tasks gsm8k_cot     --device cuda:0        --apply_chat_template           --fewshot_as_multiturn      --num_fewshot 8         --gen_kwargs temperature=0     --batch_size auto:4     --seed 123
image

[vLLM+default setting]speed=22.28it/s @mgoin Thanks for you comment:) I reran with your command in A6000 again.

CUDA_VISIBLE_DEVICES=0 lm_eval --model vllm --model_args pretrained=Qwen/Qwen2-1.5B-Instruct,dtype=auto --tasks gsm8k_cot --device cuda:0 --apply_chat_template --fewshot_as_multiturn  --num_fewshot 8 --gen_kwargs temperature=0 --batch_size auto --seed 123
image

[SGlang+default setting] speed=62.92it/s

CUDA_VISIBLE_DEVICES=0 lm_eval --model sglang     --model_args pr
etrained=Qwen/Qwen2-1.5B-Instruct,dtype=auto    --tasks gsm8k_cot     --device "cuda"      --apply_chat_template           --fewshot
_as_multiturn      --num_fewshot 8         --gen_kwargs temperature=0     --batch_size auto     --seed 123

image

@Qubitium
Copy link
Contributor

@Monstertail SGLang is a beast! One quick note: right now the test uses SGLangLM directly. Can you add a sub-test so the test uses lm_eval.simple_evaluate or lm_eval.evaluate api? As most users would start the SGLang backend via the above two api

@zhaochenyang20
Copy link

@Qubitium Thanks! We have the world's most talented and diligent team. Enjoy the collaboration!

@Monstertail
Copy link
Author

@Monstertail SGLang is a beast! One quick note: right now the test uses SGLangLM directly. Can you add a sub-test so the test uses lm_eval.simple_evaluate or lm_eval.evaluate api? As most users would start the SGLang backend via the above two api

Thanks for your comment! I will take a look when I am more available:)

@mgoin
Copy link
Contributor

mgoin commented Feb 16, 2025

Hey @Monstertail congrats on the PR. If you want to compare the speed of vLLM vs SGLang, please use the default arguments in both cases equally and do not disable CUDA Graphs + lower the batch size for vLLM to artificially reduce the performance.

These parameters set in only in your vLLM command are not the default and lower performance:

  • enforce_eager=True disables CUDA Graphs which is crucial for a decode-heavy workload like gsm8k
  • --batch_size auto:4 spends more time tuning the batch size which takes time away from processing prompts.
  • gpu_memory_utilization=0.8 lowers the amount of kv cache space (although likely not important in this case)

For your SGLang command you stick to the defaults, so I think it is only fair to compare the defaults to avoid misleading comparison.

Here are my results on an H100 with commands that are normalized to just use the default in both cases, as in --model_args pretrained=Qwen/Qwen2-1.5B-Instruct,dtype=auto and --batch_size auto.

vLLM speed=74.87it/s

CUDA_VISIBLE_DEVICES=0 lm_eval --model vllm --model_args pretrained=Qwen/Qwen2-1.5B-Instruct,dtype=auto --tasks gsm8k_cot --device cuda:0 --apply_chat_template --fewshot_as_multiturn  --num_fewshot 8 --gen_kwargs temperature=0 --batch_size auto --seed 123
...
2025-02-16:15:17:20,630 INFO     [evaluator.py:496] Running generate_until requests
Processed prompts: 100%|████████████████| 1319/1319 [00:17<00:00, 75.56it/s, est. speed input: 65360.29 toks/s, output: 9170.96 toks/s]
Running generate_until requests: 100%|███████████████████| 1319/1319 [00:17<00:00, 74.87it/s]
2025-02-16:15:17:43,540 INFO     [evaluation_tracker.py:269] Output path not provided, skipping saving results aggregated
vllm (pretrained=Qwen/Qwen2-1.5B-Instruct,dtype=auto), gen_kwargs: (temperature=0), limit: None, num_fewshot: 8, batch_size: auto
|  Tasks  |Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|---------|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k_cot|      3|flexible-extract|     8|exact_match|↑  |0.5620|±  |0.0137|
|         |       |strict-match    |     8|exact_match|↑  |0.5124|±  |0.0138|

SGLang speed=69.01it/s

CUDA_VISIBLE_DEVICES=0 lm_eval --model sglang --model_args pretrained=Qwen/Qwen2-1.5B-Instruct,dtype=auto --tasks gsm8k_cot --device "cuda" --apply_chat_template --fewshot_as_multiturn  --num_fewshot 8 --gen_kwargs temperature=0 --batch_size auto --seed 123
...
[2025-02-16 15:18:54] Running generate_until requests
Running generate_until requests: 100%|███████████████| 1319/1319 [00:19<00:00, 69.01it/s]
[2025-02-16 15:19:20] Output path not provided, skipping saving results aggregated
sglang (pretrained=Qwen/Qwen2-1.5B-Instruct,dtype=auto), gen_kwargs: (temperature=0), limit: None, num_fewshot: 8, batch_size: auto
|  Tasks  |Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|---------|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k_cot|      3|flexible-extract|     8|exact_match|↑  |0.5648|±  |0.0137|
|         |       |strict-match    |     8|exact_match|↑  |0.5110|±  |0.0138|

@zhaochenyang20
Copy link

zhaochenyang20 commented Feb 16, 2025

Hey @mgoin . Thanks for pointing this out. We use your command on our H100, and get these results:

  • VLLM 51.17it/s
  • SGLang 108.58it/s
CUDA_VISIBLE_DEVICES=0 lm_eval --model vllm --model_args pretrained=Qwen/Qwen2-1.5B-Instruct,dtype=auto --tasks gsm8k_cot --device cuda:0 --apply_chat_template --fewshot_as_multiturn  --num_fewshot 8 --gen_kwargs temperature=0 --batch_size auto --seed 123

Running generate_until requests: 100%|███████████████████| 1319/1319 [00:25<00:00, 51.17it/s]
2025-02-16:18:48:53,326 INFO     [evaluation_tracker.py:269] Output path not provided, skipping saving results aggregated
vllm (pretrained=Qwen/Qwen2-1.5B-Instruct,dtype=auto), gen_kwargs: (temperature=0), limit: None, num_fewshot: 8, batch_size: auto
|  Tasks  |Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|---------|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k_cot|      3|flexible-extract|     8|exact_match||0.5580|±  |0.0137|
|         |       |strict-match    |     8|exact_match||0.5064|±  |0.0138|
CUDA_VISIBLE_DEVICES=0 lm_eval --model sglang --model_args pretrained=Qwen/Qwen2-1.5B-Instruct,dtype=auto --tasks gsm8k_cot --device "cuda" --apply_chat_template --fewshot_as_multiturn  --num_fewshot 8 --gen_kwargs temperature=0 --batch_size auto --seed 123

Running generate_until requests: 100%|██████████████████| 1319/1319 [00:12<00:00, 108.58it/s]
[2025-02-16 18:51:55] Output path not provided, skipping saving results aggregated
sglang (pretrained=Qwen/Qwen2-1.5B-Instruct,dtype=auto), gen_kwargs: (temperature=0), limit: None, num_fewshot: 8, batch_size: auto
|  Tasks  |Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|---------|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k_cot|      3|flexible-extract|     8|exact_match||0.5527|±  |0.0137|
|         |       |strict-match    |     8|exact_match||0.5011|±  |0.0138|

We can also let community users to evaluate this and we will test on our A100 today.

@Monstertail
Copy link
Author

Monstertail commented Feb 16, 2025

Hi @mgoin, thanks for your comment. We also use your command on our A100, and get these results:
[vllm] 34.64it/s
image
[sglang] 88.77 it/s
image

We ran 5 times and compare the performance of these two backends here(all use your comment):
vllm(it/s):[34.64 ,35.38, 35.37, 35.26, 35.38] avg=35.21 it/s
sglang(it/s): [88.77, 101.60,102.43,101.44, 100.86] avg=99.02 it/s

We wonder whether other programs influenced your test, or could you please tell me your SGLang version?

@mgoin
Copy link
Contributor

mgoin commented Feb 17, 2025

I updated to the latest sglang==0.4.3 installed from pypi according to the docs here and vllm==0.7.2 from pypi, each in separate python 3.12 venvs. For my benchmarks here I am using the same commands as in my message before on a single H100.

vLLM 0.7.2 (it/s): [71.87, 73.51, 71.03, 71.77] avg=72.045 it/s
SGLang 0.4.3 (it/s): [69.22, 71.32, 72.04, 71.91] avg=71.1225 it/s

Note: If you really want to go for performance here, I would recommend enabling vLLM V1 on the same release - it can get over 200 iterations per second on this workload!

vLLM 0.7.2 with VLLM_USE_V1=1 enabled (it/s): [204.04, 208.03, 203.28, 205.14] avg=205.1225 it/s

VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=0 lm_eval --model vllm --model_args pretrained=Qwen/Qwen2-1.5B-Instruct,dtype=auto --tasks gsm8k_cot --device cuda:0 --apply_chat_template --fewshot_as_multiturn  --num_fewshot 8 --gen_kwargs temperature=0 --batch_size auto --seed 123
...
2025-02-17:00:03:29,375 INFO     [evaluator.py:465] Running generate_until requests
Processed prompts: 100%|█████████████████| 1319/1319 [00:06<00:00, 208.46it/s, est. speed input: 180336.61 toks/s, output: 25490.72 toks/s]
Running generate_until requests: 100%|█████████████████| 1319/1319 [00:06<00:00, 204.04it/s]
vllm (pretrained=Qwen/Qwen2-1.5B-Instruct,dtype=auto), gen_kwargs: (temperature=0), limit: None, num_fewshot: 8, batch_size: auto
|  Tasks  |Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|---------|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k_cot|      3|flexible-extract|     8|exact_match|↑  |0.5588|±  |0.0137|
|         |       |strict-match    |     8|exact_match|↑  |0.5027|±  |0.0138|

Here is my full system configuration for additional clarity:

The output of `python collect_env.py`
Collecting environment information...
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.12.4 (main, Jul 25 2024, 22:42:01) [Clang 18.1.8 ] (64-bit runtime)
Python platform: Linux-6.5.0-35-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.5.82
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

Nvidia driver version: 555.42.02
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.2.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 57 bits virtual
Byte Order:                         Little Endian
CPU(s):                             128
On-line CPU(s) list:                0-127
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) Platinum 8462Y+
CPU family:                         6
Model:                              143
Thread(s) per core:                 2
Core(s) per socket:                 32
Socket(s):                          2
Stepping:                           8
CPU max MHz:                        4100.0000
CPU min MHz:                        800.0000
BogoMIPS:                           5600.00
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hfi vnmi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr ibt amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization:                     VT-x
L1d cache:                          3 MiB (64 instances)
L1i cache:                          2 MiB (64 instances)
L2 cache:                           128 MiB (64 instances)
L3 cache:                           120 MiB (2 instances)
NUMA node(s):                       2
NUMA node0 CPU(s):                  0-31,64-95
NUMA node1 CPU(s):                  32-63,96-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-ml-py==12.570.86
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pyzmq==26.2.1
[pip3] torch==2.5.1
[pip3] torchaudio==2.5.1
[pip3] torchvision==0.20.1
[pip3] transformers==4.48.3
[pip3] triton==3.1.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.7.3.dev57+g2ae889052.d20250210
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    NIC0    NIC1    NIC2    NIC3    NIC4    NIC5    NIC6    NIC7    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV18    NV18    NV18    NV18    NV18    NV18    NV18    PIX     NODE    NODE    NODE    SYS     SYS     SYS     SYS     0-31,64-95      0               N/A
GPU1    NV18     X      NV18    NV18    NV18    NV18    NV18    NV18    NODE    PIX     NODE    NODE    SYS     SYS     SYS     SYS     0-31,64-95      0               N/A
GPU2    NV18    NV18     X      NV18    NV18    NV18    NV18    NV18    NODE    NODE    PIX     NODE    SYS     SYS     SYS     SYS     0-31,64-95      0               N/A
GPU3    NV18    NV18    NV18     X      NV18    NV18    NV18    NV18    NODE    NODE    NODE    PIX     SYS     SYS     SYS     SYS     0-31,64-95      0               N/A
GPU4    NV18    NV18    NV18    NV18     X      NV18    NV18    NV18    SYS     SYS     SYS     SYS     PIX     NODE    NODE    NODE    32-63,96-127    1               N/A
GPU5    NV18    NV18    NV18    NV18    NV18     X      NV18    NV18    SYS     SYS     SYS     SYS     NODE    PIX     NODE    NODE    32-63,96-127    1               N/A
GPU6    NV18    NV18    NV18    NV18    NV18    NV18     X      NV18    SYS     SYS     SYS     SYS     NODE    NODE    PIX     NODE    32-63,96-127    1               N/A
GPU7    NV18    NV18    NV18    NV18    NV18    NV18    NV18     X      SYS     SYS     SYS     SYS     NODE    NODE    NODE    PIX     32-63,96-127    1               N/A
NIC0    PIX     NODE    NODE    NODE    SYS     SYS     SYS     SYS      X      NODE    NODE    NODE    SYS     SYS     SYS     SYS
NIC1    NODE    PIX     NODE    NODE    SYS     SYS     SYS     SYS     NODE     X      NODE    NODE    SYS     SYS     SYS     SYS
NIC2    NODE    NODE    PIX     NODE    SYS     SYS     SYS     SYS     NODE    NODE     X      NODE    SYS     SYS     SYS     SYS
NIC3    NODE    NODE    NODE    PIX     SYS     SYS     SYS     SYS     NODE    NODE    NODE     X      SYS     SYS     SYS     SYS
NIC4    SYS     SYS     SYS     SYS     PIX     NODE    NODE    NODE    SYS     SYS     SYS     SYS      X      NODE    NODE    NODE
NIC5    SYS     SYS     SYS     SYS     NODE    PIX     NODE    NODE    SYS     SYS     SYS     SYS     NODE     X      NODE    NODE
NIC6    SYS     SYS     SYS     SYS     NODE    NODE    PIX     NODE    SYS     SYS     SYS     SYS     NODE    NODE     X      NODE
NIC7    SYS     SYS     SYS     SYS     NODE    NODE    NODE    PIX     SYS     SYS     SYS     SYS     NODE    NODE    NODE     X 

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1
  NIC2: mlx5_2
  NIC3: mlx5_3
  NIC4: mlx5_4
  NIC5: mlx5_5
  NIC6: mlx5_6
  NIC7: mlx5_7

LD_LIBRARY_PATH=/usr/local/cuda-12.5/lib64:/usr/local/cuda-12.5/lib64:
NCCL_CUMEM_ENABLE=0
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY

@Monstertail
Copy link
Author

Monstertail commented Feb 17, 2025

Hi, I confirmed that what I tested before and now was based on sglang==0.4.3 and vllm==0.7.2.
First, I would like to thank @mgoin for your discussion. I would say that both sglang and vllm are good backends. We become better through mutual learning and discussion to serve the entire community. Respect to vLLM team:)

Second, I tested vllm_v1 on A100 locally, and it's much faster than without v1 : (with v1) vllm speedup=130 it/s. Great! I bet users will like these improvements!

VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=0 lm_eval --model vllm --model_args pretrained=Qwen/Qwen2-1.5B-Instruct,dtype=auto --tasks gsm8k_cot --device cuda:0 --apply_chat_template --fewshot_as_multiturn  --num_fewshot 8 --gen_kwargs temperature=0 --batch_size auto --seed 123

Third, I noticed the sglang speed has still a gap in your side. It may be caused by Flashinfer is not correctly installed or other issues. Anyway, I do think, providing a new option is always not a bad thing. We can make the community better together:)

Last , I tested with a larger model (Qwen2-7B-Instruct) and noticed that SGlang is faster than vllm(with VLLM_USE_V1=1) , which is opposite when we have a 1.5B model. I post here because I think it's an interesting result can benefit both of us, not for competing :)

Testbed: a single A100 with Qwen2-7B-Instruct.
[vLLM] speed = 12.16 it/s

CUDA_VISIBLE_DEVICES=0 lm_eval --model sglang --model_args pretrained=Qwen/Qwen2-7B-Instruct,dtype=auto --tasks gsm8k_cot --device "cuda" --apply_chat_template --fewshot_as_multiturn  --num_fewshot 8 --gen_kwargs temperature=0 --batch_size auto --seed 123

[vLLM with V1] speed =45.93 it/s(awesome 4x improvement with v1)

VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=0 lm_eval --model vllm --model_args pretrained=Qwen/Qwen2-7B-Instruct,dtype=auto --tasks gsm8k_cot --device cuda:0 --apply_chat_template --fewshot_as_multiturn  --num_fewshot 8 --gen_kwargs temperature=0 --batch_size auto --seed 123

[SGlang] speed=55.04 it/s

CUDA_VISIBLE_DEVICES=0 lm_eval --model sglang --model_args pretrained=Qwen/Qwen2-7B-Instruct,dtype=auto --tasks gsm8k_cot --device "cuda" --apply_chat_template --fewshot_as_multiturn  --num_fewshot 8 --gen_kwargs temperature=0 --batch_size auto --seed 123

@Monstertail
Copy link
Author

We also noticed that there is some accuracy gap between HF model card and lm-eval-harness due to parser. See here for details, and we provide a simple solution to bridge the gap.

@zhaochenyang20
Copy link

zhaochenyang20 commented Feb 17, 2025

@baberabb @lintangsutawika Our PR is near ready. Thanks so much for help

@Monstertail
Copy link
Author

Monstertail commented Feb 17, 2025

Todos to make the PR better(I will do the 1,2,4 late today, but 3 needs discussion):

  • 1. add a sub-test so the test uses lm_eval.simple_evaluate or lm_eval.evaluate API , thanks for @Qubitium 's comment;
  • 2. make the max_model_length shorter for test cases here to overcome the OOM issue and enable the test within a single A100;
  • 3. decide the installation dependency of SGlang (especially flashinfer) with lm-eval-harness team. Currently I initialized the partial dependencies of SGlang here, but the sglang cannot be installed because flashinfer dependency cannot be included within a pyproject.toml. One solution is to install sglang in advance.
  • 4.provide the command in the readme. Thanks @baberabb for comments.

)
if self.data_parallel_size <= 1:
self.model = sgl.Engine(**self.model_args)
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we skip data parallel?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, I think this can be simplified by merging two self.model=self.model = sgl.Engine(**self.model_args) to only one, then throw a warning if self.data_parallel_size > 1 because dp might be replaced by sglang_router in the future. See here. Shall we keep the warning?

"Either context_length or max_model_len may be provided, but not both"
)
# Initialize your sglang model here
self._max_length = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking we could also pass max_length kwarg as that's used thorough the library.

@baberabb
Copy link
Contributor

Hi! Thanks v. much for the PR. This is excellent! just a couple of nits in the comments. Could you also add a section in the main readme, like all the other backends, mainly providing the command and identifying any footguns. :)

With respect to flashinfer, is it a required dependency? If not we could add a warning to the init with install instructions, or add an assertion if it is required?

@Monstertail
Copy link
Author

Monstertail commented Feb 18, 2025

Hi @baberabb Thanks for your quick reply! Appreciate your comments:)

  1. as for readme of sglang, I update the Todos mentioned above;
  2. Flashinfer is a very fast attention backend adopted by SGLang and vLLM(recently), we highly recommend installing that. Therefore, I would like to add notes on readme to remind users to install it following the sglang doc here in advance if they want to adopt it as evaluation backend.

What do you think? I will list how do I plan to modify the dependencies below.

@@ -78,6 +78,7 @@ zeno = ["pandas", "zeno-client"]
wandb = ["wandb>=0.16.3", "pandas", "numpy"]
gptqmodel = ["gptqmodel>=1.0.9"]
japanese_leaderboard = ["emoji==2.14.0", "neologdn==0.5.3", "fugashi[unidic-lite]", "rouge_score>=0.1.2"]
sglang =["sglang>=0.4.2.post2"]
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am thinking about deleting these partial dependencies. Install, I will remind the users in the readme to install sglang in advance.

):
super().__init__()

if not find_spec("sglang"):
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, modify the error message here to remind the user for second time(apart from the readme):

raise ModuleNotFoundError(
                "attempted to use 'sglang' LM type, but package `sglang` is not installed. "
                "Please install sglang via official document here: https://docs.sglang.ai/start/install.html#method-1-with-pip"
            )

@zhaochenyang20
Copy link

Yeah. SGLang should be installed separately, but it's easy.

https://docs.sglang.ai/start/install.html

@baberabb
Copy link
Contributor

Hi @baberabb Thanks for your quick reply! Appreciate your comments:)

  1. as for readme of sglang, I update the Todos mentioned above;
  2. Flashinfer is a very fast attention backend adopted by SGLang and vLLM(recently), we highly recommend installing that. Therefore, I would like to add notes on readme to remind users to install it following the sglang doc here in advance if they want to adopt it as evaluation backend.

What do you think? I will list how do I plan to modify the dependencies below.

Hi! These seem reasonable. And yeah I agree will be best to add the installation instructions in init, in case not installed

@fxmarty-amd
Copy link

fxmarty-amd commented Feb 19, 2025

Slightly off-topic - what is the benefit of lm_eval --model sglang or lm_eval --model vllm compared to using OpenAI-like API, e.g. lm_eval --model local-completions --model_args model=meta-llama/Llama-2-7b-chat-hf,base_url=http://0.0.0.0:8000/v1/completions,num_concurrent=9999999,timeout=99999,tokenized_requests=False?

Are there certain evals that can not be done with local-completions? Otherwise, I think I could already use sglang as my runtime framework through local-completions, right?

@Monstertail
Copy link
Author

@fxmarty-amd Hi, it's true both vLLM and SGLang support OpenAI-like APIs, but it would be slower than offline batch inference(based on my previous tests). As in vLLM/SGLang docs, different server implementations would be provided to satisfy different requirements. We believe offline batch inference is needed as an option here as well.

As for the completeness of the functions, I think the API Server is basically complete for vLLM and SGLang.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants