From feb2b768ba43577594c30f9dac55355954721e01 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 20 Dec 2024 08:25:25 -0800
Subject: [PATCH] Add integration with gemlite weight only quant (#2528)

---
 python/pyproject.toml                     |  2 +-
 python/sglang/bench_offline_throughput.py | 12 ++++++++
 python/sglang/bench_one_batch.py          | 13 +++++++++
 python/sglang/srt/layers/torchao_utils.py | 35 +++++++++++++++++++++++
 4 files changed, 61 insertions(+), 1 deletion(-)

diff --git a/python/pyproject.toml b/python/pyproject.toml
index 5ff1ecb1bde..13ffdcccbc3 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -21,7 +21,7 @@ runtime_common = ["aiohttp", "decord", "fastapi",
     "orjson", "outlines>=0.0.44,<0.1.0",
     "packaging", "pillow", "prometheus-client>=0.20.0",
     "psutil", "pydantic", "python-multipart",
-    "pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop",
+    "pyzmq>=25.1.2", "torchao>=0.7.0", "gemlite", "uvicorn", "uvloop",
     "xgrammar>=0.1.6"]
 srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
     "cuda-python", "flashinfer==0.1.6"]
diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py
index f840ee878a8..fbb5a3fb304 100644
--- a/python/sglang/bench_offline_throughput.py
+++ b/python/sglang/bench_offline_throughput.py
@@ -322,6 +322,18 @@ def throughput_test(
         )
         time.sleep(0.5)
 
+    try:
+        import os
+        import pwd
+
+        from gemlite.core import GemLiteLinearTriton
+
+        GemLiteLinearTriton.cache_config(
+            f"/tmp/{pwd.getpwuid(os.getuid()).pw_name}_gemlite.json"
+        )
+    except ImportError:
+        pass
+
     logging.info("\nBenchmark...")
     result = throughput_test_once(
         backend_name=bench_args.backend,
diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index e7a83139954..1a4dab7400f 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -385,6 +385,19 @@ def latency_test(
         8,  # shorter decoding to speed up the warmup
         server_args.device,
     )
+
+    try:
+        import os
+        import pwd
+
+        from gemlite.core import GemLiteLinearTriton
+
+        GemLiteLinearTriton.cache_config(
+            f"/tmp/{pwd.getpwuid(os.getuid()).pw_name}_gemlite.json"
+        )
+    except ImportError:
+        pass
+
     rank_print("Benchmark ...")

     # Run the sweep
diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py
index 910309da973..f911fe0a7c7 100644
--- a/python/sglang/srt/layers/torchao_utils.py
+++ b/python/sglang/srt/layers/torchao_utils.py
@@ -47,6 +47,41 @@ def filter_fn(module, fqn):
             256,
         ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
         quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
+    elif "gemlite" in torchao_config:
+        # gemlite-<packing_bitwidth>-<bit_width>-<group_size> or
+        # gemlite-<bit_width>-<group_size> (packing_bitwidth defaults to 32)
+        import os
+        import pwd
+
+        import gemlite
+        from gemlite.core import GemLiteLinearTriton, set_autotune
+
+        try:
+            from torchao.quantization import gemlite_uintx_weight_only
+        except ImportError:
+            print(
+                "import `gemlite_uintx_weight_only` failed, please use torchao nightly to use gemlite quantization"
+            )
+            return model
+
+        _quant_args = torchao_config.split("-")
+        bit_width = int(_quant_args[-2])
+        group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
+        try:
+            packing_bitwidth = int(_quant_args[-3])
+        except (IndexError, ValueError):
+            # if only 2 inputs found, use default value
+            packing_bitwidth = 32
+
+        quantize_(
+            model, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth)
+        )
+
+        # try to load gemlite kernel config
+        GemLiteLinearTriton.load_config(
+            f"/tmp/{pwd.getpwuid(os.getuid()).pw_name}_gemlite.json"
+        )
+
 elif "fp8wo" in torchao_config:
     # this requires newer hardware
     # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89