From 3df95a3e945fa8a98234e5b12cceeb6299419a35 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 16 Dec 2024 15:18:46 -0800 Subject: [PATCH 1/6] Add integration with gemlite weight only quant MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: gemlite Only available with nightly torchao right now (or install from source) Test Plan: ``` python3 -m sglang.bench_one_batch --model meta-llama/Llama-3.1-8B-Instruct --batch-size 1 --input 1024 --output 512 --json-model-override-args '{"architectures": ["TorchNativeLlamaForCausalLM"]}' --enable-torch-compile --torchao-config gemlite-4-64 --tp-size 1 ``` Reviewers: Subscribers: Tasks: Tags: --- python/sglang/bench_offline_throughput.py | 10 +++++++++ python/sglang/bench_one_batch.py | 8 +++++++ python/sglang/srt/layers/torchao_utils.py | 26 +++++++++++++++++++++++ 3 files changed, 44 insertions(+) diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py index f840ee878a8..a1453b81d2a 100644 --- a/python/sglang/bench_offline_throughput.py +++ b/python/sglang/bench_offline_throughput.py @@ -309,6 +309,9 @@ def throughput_test( dataset_path=bench_args.dataset_path, ) + import os, pwd + print(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json") + # Warm up if not bench_args.skip_warmup: logging.info("\nWarmup...") @@ -322,6 +325,13 @@ def throughput_test( ) time.sleep(0.5) + try: + from gemlite.core import GemLiteLinearTriton + import os, pwd + GemLiteLinearTriton.cache_config(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json") + except ImportError: + pass + logging.info("\nBenchmark...") result = throughput_test_once( backend_name=bench_args.backend, diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index e7a83139954..6b3620f1217 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -385,6 +385,14 @@ def latency_test( 8, # shorter decoding to speed up the 
warmup server_args.device, ) + + try: + from gemlite.core import GemLiteLinearTriton + import os, pwd + GemLiteLinearTriton.cache_config(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json") + except ImportError: + pass + rank_print("Benchmark ...") # Run the sweep diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py index 910309da973..c798ecac92c 100644 --- a/python/sglang/srt/layers/torchao_utils.py +++ b/python/sglang/srt/layers/torchao_utils.py @@ -47,6 +47,32 @@ def filter_fn(module, fqn): 256, ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}" quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn) + elif "gemlite" in torchao_config: + # gemlite-<packing_bitwidth>-<bit_width>-<group_size> or + # gemlite-<bit_width>-<group_size> (packing_bitwidth defaults to 32) + import os, pwd + import gemlite + from gemlite.core import GemLiteLinearTriton, set_autotune + from torchao.quantization import gemlite_uintx_weight_only + + _quant_args = torchao_config.split("-") + bit_width = int(_quant_args[-2]) + group_size = None if _quant_args[-1] == 'None' else int(_quant_args[-1]) + try: + packing_bitwidth = int(_quant_args[-3]) + except: + # if only 2 inputs found, use default value + packing_bitwidth = 32 + + quantize_(model, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth)) + + # try to load gemlite kernel config + try: + GemLiteLinearTriton.load_config(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json") + print(f"loaded gemlite kernel cache /tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json") + except: + print(f"unable to load gemlite kernel cache /tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json") + elif "fp8wo" in torchao_config: # this requires newer hardware # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89 From 6aabf04a8af08097ad80701d80cdd1ca3dbd8fc9 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 19 Dec 2024 10:46:14 -0800 Subject: [PATCH 2/6] formatting 
--- python/sglang/bench_offline_throughput.py | 12 +++++++----- python/sglang/bench_one_batch.py | 9 +++++++-- python/sglang/srt/layers/torchao_utils.py | 22 ++++++++++++++++------ 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py index a1453b81d2a..fbb5a3fb304 100644 --- a/python/sglang/bench_offline_throughput.py +++ b/python/sglang/bench_offline_throughput.py @@ -309,9 +309,6 @@ def throughput_test( dataset_path=bench_args.dataset_path, ) - import os, pwd - print(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json") - # Warm up if not bench_args.skip_warmup: logging.info("\nWarmup...") @@ -326,9 +323,14 @@ def throughput_test( time.sleep(0.5) try: + import os + import pwd + from gemlite.core import GemLiteLinearTriton - import os, pwd - GemLiteLinearTriton.cache_config(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json") + + GemLiteLinearTriton.cache_config( + f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json" + ) except ImportError: pass diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index 6b3620f1217..1a4dab7400f 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -387,9 +387,14 @@ def latency_test( ) try: + import os + import pwd + from gemlite.core import GemLiteLinearTriton - import os, pwd - GemLiteLinearTriton.cache_config(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json") + + GemLiteLinearTriton.cache_config( + f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json" + ) except ImportError: pass diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py index c798ecac92c..136681026f1 100644 --- a/python/sglang/srt/layers/torchao_utils.py +++ b/python/sglang/srt/layers/torchao_utils.py @@ -50,28 +50,38 @@ def filter_fn(module, fqn): elif "gemlite" in torchao_config: # gemlite-<packing_bitwidth>-<bit_width>-<group_size> or # gemlite-<bit_width>-<group_size> (packing_bitwidth defaults to 
32) - import os, pwd + import os + import pwd + import gemlite from gemlite.core import GemLiteLinearTriton, set_autotune from torchao.quantization import gemlite_uintx_weight_only _quant_args = torchao_config.split("-") bit_width = int(_quant_args[-2]) - group_size = None if _quant_args[-1] == 'None' else int(_quant_args[-1]) + group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1]) try: packing_bitwidth = int(_quant_args[-3]) except: # if only 2 inputs found, use default value packing_bitwidth = 32 - quantize_(model, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth)) + quantize_( + model, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth) + ) # try to load gemlite kernel config try: - GemLiteLinearTriton.load_config(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json") - print(f"loaded gemlite kernel cache /tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json") + GemLiteLinearTriton.load_config( + f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json" + ) + print( + f"loaded gemlite kernel cache /tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json" + ) except: - print(f"unable to load gemlite kernel cache /tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json") + print( + f"unable to load gemlite kernel cache /tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json" + ) elif "fp8wo" in torchao_config: # this requires newer hardware From c448c778833a179de127c679dc535968cf1b10e5 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 19 Dec 2024 10:53:59 -0800 Subject: [PATCH 3/6] format --- python/sglang/srt/layers/torchao_utils.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py index 136681026f1..5e30ccf9bda 100644 --- a/python/sglang/srt/layers/torchao_utils.py +++ b/python/sglang/srt/layers/torchao_utils.py @@ -71,17 +71,9 @@ def filter_fn(module, fqn): ) # try to load gemlite kernel 
config - try: - GemLiteLinearTriton.load_config( - f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json" - ) - print( - f"loaded gemlite kernel cache /tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json" - ) - except: - print( - f"unable to load gemlite kernel cache /tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json" - ) + GemLiteLinearTriton.load_config( + f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json" + ) elif "fp8wo" in torchao_config: # this requires newer hardware From 21cd4598d48a41f77de69051e61b84e28810aff9 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 19 Dec 2024 14:57:35 -0800 Subject: [PATCH 4/6] add gemlite dep --- python/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 6267bc523cf..d28569f8d38 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -21,7 +21,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "orjson", "outlines>=0.0.44,<0.1.0", "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart", - "pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop", + "pyzmq>=25.1.2", "torchao>=0.7.0", "gemlite", "uvicorn", "uvloop", "xgrammar>=0.1.6"] srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer>=0.1.6"] From bb9f13348baf235b471db5ae0500338f57adba6f Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 19 Dec 2024 15:01:35 -0800 Subject: [PATCH 5/6] add error checks --- python/sglang/srt/layers/torchao_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py index 5e30ccf9bda..7fc5f2eb9d1 100644 --- a/python/sglang/srt/layers/torchao_utils.py +++ b/python/sglang/srt/layers/torchao_utils.py @@ -55,7 +55,11 @@ def filter_fn(module, fqn): import gemlite from gemlite.core import GemLiteLinearTriton, set_autotune - from torchao.quantization 
import gemlite_uintx_weight_only + try: + from torchao.quantization import gemlite_uintx_weight_only + except: + print(f"import `gemlite_uintx_weight_only` failed, please use torchao nightly to use gemlite quantization") + return model _quant_args = torchao_config.split("-") bit_width = int(_quant_args[-2]) From 88f4ecead4141993acc154a6a762ef634e244265 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 19 Dec 2024 15:07:54 -0800 Subject: [PATCH 6/6] format --- python/sglang/srt/layers/torchao_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py index 7fc5f2eb9d1..f911fe0a7c7 100644 --- a/python/sglang/srt/layers/torchao_utils.py +++ b/python/sglang/srt/layers/torchao_utils.py @@ -55,10 +55,13 @@ def filter_fn(module, fqn): import gemlite from gemlite.core import GemLiteLinearTriton, set_autotune + try: from torchao.quantization import gemlite_uintx_weight_only except: - print(f"import `gemlite_uintx_weight_only` failed, please use torchao nightly to use gemlite quantization") + print( + f"import `gemlite_uintx_weight_only` failed, please use torchao nightly to use gemlite quantization" + ) return model _quant_args = torchao_config.split("-")