Skip to content

Commit

Permalink
Add integration with gemlite weight only quant (#2528)
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryzh168 authored Dec 20, 2024
1 parent d95a5f5 commit feb2b76
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ runtime_common = ["aiohttp", "decord", "fastapi",
"orjson", "outlines>=0.0.44,<0.1.0",
"packaging", "pillow", "prometheus-client>=0.20.0",
"psutil", "pydantic", "python-multipart",
"pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop",
"pyzmq>=25.1.2", "torchao>=0.7.0", "gemlite", "uvicorn", "uvloop",
"xgrammar>=0.1.6"]
srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer==0.1.6"]

Expand Down
12 changes: 12 additions & 0 deletions python/sglang/bench_offline_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,18 @@ def throughput_test(
)
time.sleep(0.5)

# Best-effort: if the optional gemlite dependency is installed, persist its
# autotuned Triton kernel configs to a per-user file in /tmp so later runs
# can reuse them instead of re-autotuning.
try:
    import os
    import pwd

    from gemlite.core import GemLiteLinearTriton

    # NOTE(review): pw_gecos is the GECOS "real name" field, which may be
    # empty or contain spaces/commas; pw_name (the login name) looks like the
    # intended per-user key — confirm, and keep in sync with the matching
    # load_config path in srt/layers/torchao_utils.py.
    GemLiteLinearTriton.cache_config(
        f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
    )
except ImportError:
    # gemlite is optional; silently skip caching when it is absent.
    pass

logging.info("\nBenchmark...")
result = throughput_test_once(
backend_name=bench_args.backend,
Expand Down
13 changes: 13 additions & 0 deletions python/sglang/bench_one_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,19 @@ def latency_test(
8, # shorter decoding to speed up the warmup
server_args.device,
)

# If the optional gemlite dependency is available, dump its autotuned kernel
# configuration to a per-user file under /tmp so subsequent runs can load it
# back instead of re-autotuning.
try:
    import os
    import pwd

    from gemlite.core import GemLiteLinearTriton
except ImportError:
    # gemlite not installed — nothing to cache.
    pass
else:
    config_path = f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
    GemLiteLinearTriton.cache_config(config_path)

rank_print("Benchmark ...")

# Run the sweep
Expand Down
35 changes: 35 additions & 0 deletions python/sglang/srt/layers/torchao_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,41 @@ def filter_fn(module, fqn):
256,
], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
elif "gemlite" in torchao_config:
# gemlite-<packing_bitwidth>-<bit_width>-<group_size> or
# gemlite-<bit_width>-<group_size> (packing_bitwidth defaults to 32)
import os
import pwd

import gemlite
from gemlite.core import GemLiteLinearTriton, set_autotune

try:
from torchao.quantization import gemlite_uintx_weight_only
except:
print(
f"import `gemlite_uintx_weight_only` failed, please use torchao nightly to use gemlite quantization"
)
return model

_quant_args = torchao_config.split("-")
bit_width = int(_quant_args[-2])
group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
try:
packing_bitwidth = int(_quant_args[-3])
except:
# if only 2 inputs found, use default value
packing_bitwidth = 32

quantize_(
model, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth)
)

# try to load gemlite kernel config
GemLiteLinearTriton.load_config(
f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
)

elif "fp8wo" in torchao_config:
# this requires newer hardware
# [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
Expand Down

0 comments on commit feb2b76

Please sign in to comment.