Add torchao quant (int4/int8/fp8) to llama models (#1341)
Co-authored-by: Lianmin Zheng <[email protected]>
jerryzh168 and merrymercy authored Sep 9, 2024
1 parent e4d68af commit a7c47e0
Showing 10 changed files with 151 additions and 12 deletions.
2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
[project.optional-dependencies]
srt = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
       "packaging", "pillow", "psutil", "pydantic", "python-multipart",
-      "torch", "uvicorn", "uvloop", "zmq",
+      "torch", "torchao", "uvicorn", "uvloop", "zmq",
       "vllm==0.5.5", "outlines>=0.0.44"]
openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
36 changes: 36 additions & 0 deletions python/sglang/srt/layers/torchao_utils.py
@@ -0,0 +1,36 @@
"""
Common utilities for torchao.
"""

import torch
from torchao.quantization import (
int4_weight_only,
int8_dynamic_activation_int8_weight,
int8_weight_only,
quantize_,
)


def torchao_quantize_param_data(param, torchao_config):
dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
dummy_linear.weight = param
if "int8wo" in torchao_config:
quantize_(dummy_linear, int8_weight_only())
elif "int8dq" in torchao_config:
quantize_(dummy_linear, int8_dynamic_activation_int8_weight())
elif "int4wo" in torchao_config:
group_size = int(torchao_config.split("-")[-1])
assert group_size in [
32,
64,
128,
256,
], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
quantize_(dummy_linear, int4_weight_only(group_size=group_size))
elif "fp8wo" in torchao_config:
from torchao.quantization import float8_weight_only

# this requires newer hardware
# [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
quantize_(dummy_linear, float8_weight_only())
return dummy_linear.weight
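
For context, a minimal usage sketch of the helper above (not part of this commit; it assumes a CUDA device, a bfloat16 weight, and a recent torchao build):

import torch

from sglang.srt.layers.torchao_utils import torchao_quantize_param_data

# A dummy 2-D weight shaped [out_features, in_features], as a linear layer stores it.
weight = torch.nn.Parameter(
    torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda"),
    requires_grad=False,
)

# "int4wo-128" selects int4 weight-only quantization with group size 128;
# "int8wo" and "int8dq" take no group-size suffix.
quantized_weight = torchao_quantize_param_data(weight, "int4wo-128")

# The returned parameter keeps the logical [out_features, in_features] shape
# while its storage is a torchao quantized tensor subclass.
print(quantized_weight.shape, type(quantized_weight.data))
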
1 change: 1 addition & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -97,6 +97,7 @@ def __init__(
"disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
"enable_mla": server_args.enable_mla,
"torchao_config": server_args.torchao_config,
}
)

22 changes: 22 additions & 0 deletions python/sglang/srt/models/llama.py
@@ -42,6 +42,8 @@
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
+from sglang.srt.layers.torchao_utils import torchao_quantize_param_data
+from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -299,6 +301,7 @@ def __init__(
        super().__init__()
        self.config = config
        self.quant_config = quant_config
+        self.torchao_config = global_server_args_dict["torchao_config"]
        self.model = LlamaModel(config, quant_config=quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config)
@@ -361,6 +364,25 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

+                if self.torchao_config:
+                    if name.endswith("proj.weight") and param.ndim == 2:
+                        params_dict[name] = torchao_quantize_param_data(
+                            param, self.torchao_config
+                        )
+
+        if self.torchao_config:
+            # quantizing the loaded, stacked params, e.g. "...qkv_proj"
+            stacked_params = set(entry[0] for entry in stacked_params_mapping)
+            for param_suffix in stacked_params:
+                for name in params_dict:
+                    if param_suffix in name:
+                        param = params_dict[name]
+                        params_dict[name] = torchao_quantize_param_data(
+                            param, self.torchao_config
+                        )
+
+        self.load_state_dict(params_dict, assign=True)


class Phi3ForCausalLM(LlamaForCausalLM):
    pass
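
As an illustrative check (not part of this commit), after load_weights runs with a non-empty torchao_config the projection weights are expected to be torchao tensor subclasses rather than plain tensors, which is why the final load_state_dict call uses assign=True instead of copying in place:

# `model` is assumed to be a LlamaForCausalLM whose load_weights() ran with
# torchao_config set, e.g. "int8wo".
for name, param in model.named_parameters():
    if name.endswith("proj.weight"):
        # The inner tensor is expected to be a torchao subclass such as
        # AffineQuantizedTensor instead of plain torch.Tensor storage.
        print(name, type(param.data))
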
9 changes: 8 additions & 1 deletion python/sglang/srt/server_args.py
@@ -95,6 +95,7 @@ class ServerArgs:
    disable_custom_all_reduce: bool = False
    enable_mixed_chunk: bool = False
    enable_torch_compile: bool = False
+    torchao_config: str = ""
    enable_p2p_check: bool = False
    enable_mla: bool = False
    triton_attention_reduce_in_fp32: bool = False
@@ -443,7 +444,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
-            help="Optimize the model with torch.compile, experimental feature.",
+            help="Optimize the model with torch.compile. Experimental feature.",
        )
+        parser.add_argument(
+            "--torchao-config",
+            type=str,
+            default=ServerArgs.torchao_config,
+            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo",
+        )
        parser.add_argument(
            "--enable-p2p-check",
4 changes: 2 additions & 2 deletions test/srt/test_eval_accuracy_mini.py
@@ -29,12 +29,12 @@ def test_mmlu(self):
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
-            num_examples=32,
+            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
-        assert metrics["score"] >= 0.6
+        assert metrics["score"] >= 0.65


if __name__ == "__main__":
6 changes: 3 additions & 3 deletions test/srt/test_moe_eval_accuracy_large.py
@@ -42,7 +42,7 @@ def test_mmlu(self):
        )

        metrics = run_eval(args)
-        assert metrics["score"] >= 0.62, f"{metrics}"
+        assert metrics["score"] >= 0.625, f"{metrics}"

    def test_human_eval(self):
        args = SimpleNamespace(
@@ -54,7 +54,7 @@ def test_human_eval(self):
        )

        metrics = run_eval(args)
-        assert metrics["score"] >= 0.42, f"{metrics}"
+        assert metrics["score"] >= 0.425, f"{metrics}"

    def test_mgsm_en(self):
        args = SimpleNamespace(
@@ -66,7 +66,7 @@ def test_mgsm_en(self):
        )

        metrics = run_eval(args)
-        assert metrics["score"] >= 0.62, f"{metrics}"
+        assert metrics["score"] >= 0.625, f"{metrics}"


if __name__ == "__main__":
6 changes: 3 additions & 3 deletions test/srt/test_torch_compile.py
@@ -22,7 +22,7 @@ def setUpClass(cls):
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--enable-torch-compile", "--disable-radix-cache"],
+            other_args=["--enable-torch-compile"],
        )

    @classmethod
@@ -34,12 +34,12 @@ def test_mmlu(self):
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
-            num_examples=32,
+            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
-        assert metrics["score"] >= 0.6
+        assert metrics["score"] >= 0.65

    def run_decode(self, max_new_tokens):
        response = requests.post(
73 changes: 73 additions & 0 deletions test/srt/test_torchao.py
@@ -0,0 +1,73 @@
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestTorchCompile(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--torchao-config", "int4wo-128"],
        )

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        assert metrics["score"] >= 0.65

    def run_decode(self, max_new_tokens):
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": max_new_tokens,
                },
                "ignore_eos": True,
            },
        )
        return response.json()

    def test_throughput(self):
        import time

        max_tokens = 256

        tic = time.time()
        res = self.run_decode(max_tokens)
        tok = time.time()
        print(res["text"])
        throughput = max_tokens / (tok - tic)
        print(f"Throughput: {throughput} tokens/s")
        assert throughput >= 210


if __name__ == "__main__":
    unittest.main()
4 changes: 2 additions & 2 deletions test/srt/test_triton_attn_backend.py
@@ -32,12 +32,12 @@ def test_mmlu(self):
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
-            num_examples=32,
+            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
-        assert metrics["score"] >= 0.6
+        assert metrics["score"] >= 0.65


if __name__ == "__main__":
