From f61b2123e6177c22f792304dd2fca6f313a37ae4 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 3 Dec 2024 18:56:37 -0800
Subject: [PATCH 1/6] move apply_torchao_config_ to model_runner
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:

Previously we needed to call apply_torchao_config_ in each model manually;
this PR changes it to run once on the entire model. We can also add autoquant
in the future.

Test Plan:

llama:

python3 -m sglang.bench_one_batch --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --json-model-override-args '{"architectures": ["TorchNativeLlamaForCausalLM"]}' --enable-torch-compile

Benchmark ...
Prefill. latency: 0.03361 s, throughput: 3808.05 token/s
Decode.  latency: 0.01227 s, throughput:   81.50 token/s
Decode.  latency: 0.01195 s, throughput:   83.70 token/s
Decode.  latency: 0.01181 s, throughput:   84.65 token/s
Decode.  latency: 0.01176 s, throughput:   85.05 token/s
Decode.  latency: 0.01133 s, throughput:   88.25 token/s
Decode.  median latency: 0.01176 s, median throughput: 85.05 token/s
Total. latency: 0.115 s, throughput: 1179.56 token/s

python3 -m sglang.bench_one_batch --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --json-model-override-args '{"architectures": ["TorchNativeLlamaForCausalLM"]}' --enable-torch-compile --torchao-config int4wo-128

Benchmark ...
Prefill. latency: 0.11769 s, throughput: 1087.60 token/s
Decode.  latency: 0.00687 s, throughput:  145.47 token/s
Decode.  latency: 0.00648 s, throughput:  154.25 token/s
Decode.  latency: 0.00641 s, throughput:  156.01 token/s
Decode.  latency: 0.00635 s, throughput:  157.53 token/s
Decode.  latency: 0.00634 s, throughput:  157.74 token/s
Decode.  median latency: 0.00644 s, median throughput: 155.28 token/s
Total. latency: 0.163 s, throughput: 834.21 token/s

qwen:

python3 -m sglang.bench_one_batch --model Qwen/Qwen1.5-MoE-A2.7B --batch-size 1 --input 128 --output 8 --enable-torch-compile --torchao-config int4wo-128

original:

Benchmark ...
Prefill. latency: 0.06101 s, throughput: 2097.86 token/s
Decode.  latency: 0.00532 s, throughput:  187.93 token/s
Decode.  latency: 0.00524 s, throughput:  190.88 token/s
Decode.  latency: 0.00520 s, throughput:  192.43 token/s
Decode.  latency: 0.00513 s, throughput:  194.97 token/s
Decode.  latency: 0.00507 s, throughput:  197.26 token/s
Decode.  median latency: 0.00513 s, median throughput: 194.97 token/s
Total. latency: 0.097 s, throughput: 1400.16 token/s

after change:

Benchmark ...
Prefill. latency: 0.05830 s, throughput: 2195.38 token/s
Decode.  latency: 0.00517 s, throughput:  193.50 token/s
Decode.  latency: 0.00508 s, throughput:  196.71 token/s
Decode.  latency: 0.00512 s, throughput:  195.36 token/s
Decode.  latency: 0.00508 s, throughput:  196.97 token/s
Decode.  latency: 0.00504 s, throughput:  198.44 token/s
Decode.  median latency: 0.00508 s, median throughput: 196.97 token/s
Total. latency: 0.094 s, throughput: 1449.19 token/s

Reviewers:

Subscribers:

Tasks:

Tags:

---
 python/sglang/srt/layers/torchao_utils.py      | 48 ++++++++++++++-----
 .../sglang/srt/model_executor/model_runner.py  |  7 +++
 python/sglang/srt/models/llama.py              |  4 --
 python/sglang/srt/models/qwen2_moe.py          |  4 --
 .../sglang/srt/models/torch_native_llama.py    |  4 --
 5 files changed, 42 insertions(+), 25 deletions(-)

diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py
index 9395cdf271b..a84f813eae4 100644
--- a/python/sglang/srt/layers/torchao_utils.py
+++ b/python/sglang/srt/layers/torchao_utils.py
@@ -7,13 +7,13 @@
 import torch
 
 
-def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
-    """Quantize a Tensor with torchao quantization specified by torchao_config
+def apply_torchao_config_to_model_(module: torch.nn.Module, torchao_config: str, filter_fn=None):
+    """Quantize a Module with torchao quantization specified by torchao_config
 
     Args:
-       `param`: weight parameter of the linear module
-       `torchao_config`: type of quantization and their arguments we want to use to
-        quantize the Tensor, e.g. int4wo-128 means int4 weight only quantization with group_size
+       `module`: a module to be quantized based on torchao_config
+       `torchao_config` (str): type of quantization and their arguments we want to use to
+        quantize the module, e.g. int4wo-128 means int4 weight only quantization with group_size
        128
    """
     # Lazy import to suppress some warnings
@@ -26,12 +26,12 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
     )
     from torchao.quantization.observer import PerRow, PerTensor
 
-    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
-    dummy_linear.weight = param
-    if "int8wo" in torchao_config:
-        quantize_(dummy_linear, int8_weight_only())
+    if torchao_config == "" or torchao_config is None:
+        return module
+    elif "int8wo" in torchao_config:
+        quantize_(module, int8_weight_only(), filter_fn=filter_fn)
     elif "int8dq" in torchao_config:
-        quantize_(dummy_linear, int8_dynamic_activation_int8_weight())
+        quantize_(module, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn)
     elif "int4wo" in torchao_config:
         group_size = int(torchao_config.split("-")[-1])
         assert group_size in [
@@ -40,13 +40,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
         128,
         256,
     ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
-        quantize_(dummy_linear, int4_weight_only(group_size=group_size))
+        quantize_(module, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
     elif "fp8wo" in torchao_config:
         from torchao.quantization import float8_weight_only
 
         # this requires newer hardware
         # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
-        quantize_(dummy_linear, float8_weight_only())
+        quantize_(module, float8_weight_only(), filter_fn=filter_fn)
     elif "fp8dq" in torchao_config:
         granularity = torchao_config.split("-")[-1]
         GRANULARITY_MAP = {
@@ -57,14 +57,31 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
         granularity in GRANULARITY_MAP
     ), f"Supported granularity are: {GRANULARITY_MAP.keys()}, got {granularity}"
     quantize_(
-        dummy_linear,
+        module,
         float8_dynamic_activation_float8_weight(
             granularity=GRANULARITY_MAP[granularity]
         ),
+        filter_fn=filter_fn,
     )
     else:
         raise ValueError(f"Unexpected config: {torchao_config}")
+    return module
+
+
+def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
+    """Quantize a Tensor with torchao quantization specified by torchao_config
+
+    Args:
+       `param`: weight parameter of the linear module
+       `torchao_config` (str): type of quantization and their arguments we want to use to
+        quantize the Tensor, e.g. int4wo-128 means int4 weight only quantization with group_size
+        128
+    """
+    with torch.device("meta"):
+        dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
+    dummy_linear.weight = torch.nn.Parameter(param)
+    apply_torchao_config_to_model_(dummy_linear, torchao_config)
     return dummy_linear.weight
@@ -93,3 +110,8 @@ def apply_torchao_config_(
                         param, self.torchao_config
                     )
         self.load_state_dict(params_dict, assign=True)
+
+# def apply_torchao_config_to_model_(
+#     self: torch.nn.Module
+# ):
+#     torchao_quantize_module_(self, self.torchao_config)
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index fafb8783e5a..da11cabb291 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -59,6 +59,8 @@
     monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
 )
+from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model_
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 
 logger = logging.getLogger(__name__)
@@ -159,6 +161,11 @@ def __init__(
         else:
             self.torch_tp_applied = False
 
+        def filter_fn(module, fqn):
+            return "proj" in fqn
+
+        apply_torchao_config_to_model_(self.model, global_server_args_dict["torchao_config"], filter_fn)
+
         # Init memory pool and attention backends
         if server_args.lora_paths is not None:
             self.init_lora_manager()
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index 61409a9eaeb..8749d1310df 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -36,7 +36,6 @@
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -304,7 +303,6 @@ def __init__(
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = LlamaModel(config, quant_config=quant_config)
         # Llama 3.2 1B Insturct set tie_word_embeddings to True
         # Llama 3.1 8B Insturct set tie_word_embeddings to False
@@ -424,8 +422,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
     def get_weights_by_name(
         self, name: str, truncate_size: int = 100, tp_size: int = 1
     ) -> Optional[torch.Tensor]:
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index 0094cb8c3e2..ca77abeff06 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -40,7 +40,6 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -352,7 +351,6 @@ def __init__(
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = Qwen2MoeModel(config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size, config.hidden_size, quant_config=quant_config
@@ -445,7 +443,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             )
             weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 EntryClass = Qwen2MoeForCausalLM
diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py
index 25e555484a7..01d7f57e0c3 100644
--- a/python/sglang/srt/models/torch_native_llama.py
+++ b/python/sglang/srt/models/torch_native_llama.py
@@ -58,7 +58,6 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -392,7 +391,6 @@ def __init__(
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.supports_torch_tp = True
         self.model = LlamaModel(config, quant_config=quant_config)
         if self.config.tie_word_embeddings:
@@ -503,8 +501,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
     pass
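The series moves the quantization call into ModelRunner, so the new behavior is easiest to see in isolation. Below is a minimal sketch (not part of the diff) of the entry point at this patch's state; TinyModel and its layer names are made up for illustration, and the filter predicate is the same "proj" substring match that model_runner installs above:

    import torch
    from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model_

    class TinyModel(torch.nn.Module):
        # hypothetical stand-in for a real sglang model
        def __init__(self):
            super().__init__()
            self.q_proj = torch.nn.Linear(64, 64, bias=False)  # matches "proj"
            self.o_proj = torch.nn.Linear(64, 64, bias=False)  # matches "proj"
            self.gate = torch.nn.Linear(64, 64, bias=False)    # skipped by the filter

        def forward(self, x):
            return self.gate(self.o_proj(self.q_proj(x)))

    def filter_fn(module, fqn):
        # same predicate ModelRunner installs: quantize only modules whose
        # fully qualified name contains "proj"
        return "proj" in fqn

    model = TinyModel()
    apply_torchao_config_to_model_(model, "int8wo", filter_fn)
    print(model.q_proj.weight)  # repr shows the int8 torchao tensor subclass
    print(model.gate.weight)    # still a plain fp32 parameter
    model(torch.randn(2, 64))   # quantized and unquantized layers compose as before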
From 9c80d4b33e9f09e767573560369895d82e736556 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 3 Dec 2024 19:08:41 -0800
Subject: [PATCH 2/6] remove unused

---
 python/sglang/srt/layers/torchao_utils.py | 54 ++---------------------
 1 file changed, 3 insertions(+), 51 deletions(-)

diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py
index a84f813eae4..6d4a419ea29 100644
--- a/python/sglang/srt/layers/torchao_utils.py
+++ b/python/sglang/srt/layers/torchao_utils.py
@@ -7,11 +7,11 @@
 import torch
 
 
-def apply_torchao_config_to_model_(module: torch.nn.Module, torchao_config: str, filter_fn=None):
-    """Quantize a Module with torchao quantization specified by torchao_config
+def apply_torchao_config_to_model_(model: torch.nn.Module, torchao_config: str, filter_fn=None):
+    """Quantize a model with torchao quantization specified by torchao_config
 
     Args:
-       `module`: a module to be quantized based on torchao_config
+       `model`: a model to be quantized based on torchao_config
        `torchao_config` (str): type of quantization and their arguments we want to use to
        quantize the module, e.g. int4wo-128 means int4 weight only quantization with group_size
        128
@@ -67,51 +67,3 @@ def apply_torchao_config_to_model_(model: torch.nn.Module, torchao_config: str,
         raise ValueError(f"Unexpected config: {torchao_config}")
 
     return module
-
-
-def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
-    """Quantize a Tensor with torchao quantization specified by torchao_config
-
-    Args:
-       `param`: weight parameter of the linear module
-       `torchao_config` (str): type of quantization and their arguments we want to use to
-        quantize the Tensor, e.g. int4wo-128 means int4 weight only quantization with group_size
-        128
-    """
-    with torch.device("meta"):
-        dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
-    dummy_linear.weight = torch.nn.Parameter(param)
-    apply_torchao_config_to_model_(dummy_linear, torchao_config)
-    return dummy_linear.weight
-
-
-def apply_torchao_config_(
-    self: torch.nn.Module,
-    params_dict: Dict[str, torch.Tensor],
-    param_suffixes: Set[str],
-) -> None:
-    """A util function used for quantizing the weight parameters after they are loaded if
-    self.torchao_config is specified
-
-    Args:
-      `self`: the model we want to quantize
-      `params_dict`: dictionary mapping from param_name to the parameter Tensor
-      `param_suffixes`: a set of suffixes, we'll quantize the Tensor matching these suffixes
-
-    Returns:
-      None, the `params_dict` is modified inplace and the weights of `self` model are quantized
-    """
-    if self.torchao_config:
-        for param_suffix in param_suffixes:
-            for name in params_dict:
-                param = params_dict[name]
-                if param_suffix in name and param.ndim == 2:
-                    params_dict[name] = torchao_quantize_param_data(
-                        param, self.torchao_config
-                    )
-        self.load_state_dict(params_dict, assign=True)
-
-# def apply_torchao_config_to_model_(
-#     self: torch.nn.Module
-# ):
-#     torchao_quantize_module_(self, self.torchao_config)
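The per-parameter helpers could be deleted because quantizing a bare weight through a throwaway Linear (the old torchao_quantize_param_data path) and quantizing the owning module directly (the new path) produce the same kind of quantized parameter. A rough sketch of that equivalence, assuming torchao is installed:

    import torch
    from torchao.quantization import quantize_, int8_weight_only

    linear = torch.nn.Linear(128, 256, bias=False)

    # old path: wrap the bare weight in a throwaway Linear, quantize it, then
    # take the quantized weight back out (what torchao_quantize_param_data did)
    dummy = torch.nn.Linear(128, 256, bias=False)
    dummy.weight = torch.nn.Parameter(linear.weight.detach().clone())
    quantize_(dummy, int8_weight_only())
    old_style_weight = dummy.weight

    # new path: quantize the owning module in place
    quantize_(linear, int8_weight_only())

    # both yield the same kind of quantized parameter
    print(type(old_style_weight) is type(linear.weight))  # True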
From 3af5dd369da55ee29db45815e1b61424b146ceb7 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 3 Dec 2024 19:10:58 -0800
Subject: [PATCH 3/6] remove old apply

---
 python/sglang/srt/models/grok.py               | 5 -----
 python/sglang/srt/models/llama.py              | 1 -
 python/sglang/srt/models/mixtral.py            | 5 -----
 python/sglang/srt/models/phi3_small.py         | 5 -----
 python/sglang/srt/models/qwen2_moe.py          | 1 -
 python/sglang/srt/models/torch_native_llama.py | 1 -
 6 files changed, 18 deletions(-)

diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py
index 956f73b1482..2b52e2b1bcc 100644
--- a/python/sglang/srt/models/grok.py
+++ b/python/sglang/srt/models/grok.py
@@ -35,12 +35,10 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -290,7 +288,6 @@ def __init__(
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -374,8 +371,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             )
             weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 class Grok1ModelForCausalLM(Grok1ForCausalLM):
     """An alias for backward-compatbility."""
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index 8749d1310df..e3e44ea6ffc 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -40,7 +40,6 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py
index e75dc1288b7..f1ae1f57a3d 100644
--- a/python/sglang/srt/models/mixtral.py
+++ b/python/sglang/srt/models/mixtral.py
@@ -34,12 +34,10 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -295,7 +293,6 @@ def __init__(
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -387,7 +384,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             )
             weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 EntryClass = MixtralForCausalLM
diff --git a/python/sglang/srt/models/phi3_small.py b/python/sglang/srt/models/phi3_small.py
index 6340330774d..1e70c7d7874 100644
--- a/python/sglang/srt/models/phi3_small.py
+++ b/python/sglang/srt/models/phi3_small.py
@@ -17,13 +17,11 @@
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
@@ -348,7 +346,6 @@ def __init__(
             quant_config=quant_config,
             prefix="model",
         )
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.vocab_size = config.vocab_size
         self.mup_width_multiplier = config.mup_width_multiplier
         self.lm_head = ParallelLMHead(
@@ -441,7 +438,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 EntryClass = Phi3SmallForCausalLM
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index ca77abeff06..62cd3281d03 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -44,7 +44,6 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py
index 01d7f57e0c3..7a55d50457a 100644
--- a/python/sglang/srt/models/torch_native_llama.py
+++ b/python/sglang/srt/models/torch_native_llama.py
@@ -62,7 +62,6 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader

From 2bd5f8d7a7134ca6a5e50d46ebdb36ae6452fa64 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 3 Dec 2024 19:11:39 -0800
Subject: [PATCH 4/6] remove dup import

---
 python/sglang/srt/model_executor/model_runner.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index da11cabb291..ca52b3e4909 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -60,7 +60,6 @@
     set_cpu_offload_max_bytes,
 )
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model_
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 
 logger = logging.getLogger(__name__)
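With the old call sites gone, every model now depends on the single filter_fn defined in model_runner. As a small illustration (the module names below are hypothetical, modeled on llama-style blocks), the "proj" predicate selects exactly the projection linears:

    # which fully qualified names the "proj" predicate would select
    fqns = [
        "model.layers.0.self_attn.qkv_proj",
        "model.layers.0.self_attn.o_proj",
        "model.layers.0.mlp.gate_up_proj",
        "model.norm",
        "lm_head",
    ]
    selected = [name for name in fqns if "proj" in name]
    print(selected)  # only the three *_proj linears; norm and lm_head are skipped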
From cc06d2c385d35aa96f068f051dcddc722945a1ba Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 4 Dec 2024 12:35:37 -0800
Subject: [PATCH 5/6] fix typo

---
 python/sglang/srt/layers/torchao_utils.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py
index 6d4a419ea29..69100fd0437 100644
--- a/python/sglang/srt/layers/torchao_utils.py
+++ b/python/sglang/srt/layers/torchao_utils.py
@@ -13,7 +13,7 @@ def apply_torchao_config_to_model_(model: torch.nn.Module, torchao_config: str,
     Args:
        `model`: a model to be quantized based on torchao_config
        `torchao_config` (str): type of quantization and their arguments we want to use to
-        quantize the module, e.g. int4wo-128 means int4 weight only quantization with group_size
+        quantize the model, e.g. int4wo-128 means int4 weight only quantization with group_size
        128
     """
     # Lazy import to suppress some warnings
@@ -27,11 +27,11 @@ def apply_torchao_config_to_model_(model: torch.nn.Module, torchao_config: str,
     from torchao.quantization.observer import PerRow, PerTensor
 
     if torchao_config == "" or torchao_config is None:
-        return module
+        return model
     elif "int8wo" in torchao_config:
-        quantize_(module, int8_weight_only(), filter_fn=filter_fn)
+        quantize_(model, int8_weight_only(), filter_fn=filter_fn)
     elif "int8dq" in torchao_config:
-        quantize_(module, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn)
+        quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn)
     elif "int4wo" in torchao_config:
         group_size = int(torchao_config.split("-")[-1])
         assert group_size in [
@@ -40,13 +40,13 @@ def apply_torchao_config_to_model_(model: torch.nn.Module, torchao_config: str,
         128,
         256,
     ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
-        quantize_(module, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
+        quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
     elif "fp8wo" in torchao_config:
         from torchao.quantization import float8_weight_only
 
         # this requires newer hardware
         # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
-        quantize_(module, float8_weight_only(), filter_fn=filter_fn)
+        quantize_(model, float8_weight_only(), filter_fn=filter_fn)
     elif "fp8dq" in torchao_config:
         granularity = torchao_config.split("-")[-1]
         GRANULARITY_MAP = {
@@ -57,7 +57,7 @@ def apply_torchao_config_to_model_(model: torch.nn.Module, torchao_config: str,
         granularity in GRANULARITY_MAP
     ), f"Supported granularity are: {GRANULARITY_MAP.keys()}, got {granularity}"
     quantize_(
-        module,
+        model,
         float8_dynamic_activation_float8_weight(
             granularity=GRANULARITY_MAP[granularity]
         ),
@@ -66,4 +66,4 @@ def apply_torchao_config_to_model_(model: torch.nn.Module, torchao_config: str,
     else:
         raise ValueError(f"Unexpected config: {torchao_config}")
 
-    return module
+    return model
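After this rename the dispatch reads uniformly over the config strings. A short sketch of the resulting behavior (branching is substring-based, and an empty or None config is a no-op); the fp8 variants are left out here because they need CUDA arch >= 89:

    import torch
    from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model_

    m = torch.nn.Sequential(torch.nn.Linear(64, 64, bias=False))
    m = apply_torchao_config_to_model_(m, "")    # empty config: returned unchanged
    m = apply_torchao_config_to_model_(m, None)  # None is treated the same way
    # "int4wo-128" hits the "int4wo" branch and parses 128 as the group size;
    # an unsupported size such as "int4wo-100" trips the assertion instead.
    apply_torchao_config_to_model_(m, "int8wo")  # default filter: all Linear layers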
From 0d569fe3c29b3f94eb008cb8a3283707c0702357 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 4 Dec 2024 16:57:46 -0800
Subject: [PATCH 6/6] format

---
 python/sglang/srt/layers/torchao_utils.py        | 4 +++-
 python/sglang/srt/model_executor/model_runner.py | 6 ++++--
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py
index 69100fd0437..3f886221cca 100644
--- a/python/sglang/srt/layers/torchao_utils.py
+++ b/python/sglang/srt/layers/torchao_utils.py
@@ -7,7 +7,9 @@
 import torch
 
 
-def apply_torchao_config_to_model_(model: torch.nn.Module, torchao_config: str, filter_fn=None):
+def apply_torchao_config_to_model_(
+    model: torch.nn.Module, torchao_config: str, filter_fn=None
+):
     """Quantize a model with torchao quantization specified by torchao_config
 
     Args:
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index ca52b3e4909..6f79afb70ba 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -38,6 +38,7 @@
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
+from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model_
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
@@ -59,7 +60,6 @@
     monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
 )
-from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model_
 
 logger = logging.getLogger(__name__)
@@ -163,7 +163,9 @@ def __init__(
         def filter_fn(module, fqn):
             return "proj" in fqn
 
-        apply_torchao_config_to_model_(self.model, global_server_args_dict["torchao_config"], filter_fn)
+        apply_torchao_config_to_model_(
+            self.model, global_server_args_dict["torchao_config"], filter_fn
+        )
 
         # Init memory pool and attention backends
         if server_args.lora_paths is not None:
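The summary of patch 1 mentions autoquant as possible follow-up work. A hedged sketch of what that could look like, not part of this series: torchao.autoquant defers kernel selection until the first real input, so it would bypass the config-string dispatch entirely (apply_autoquant_ is a hypothetical helper, and a CUDA device is assumed):

    import torch
    import torchao

    def apply_autoquant_(model: torch.nn.Module) -> torch.nn.Module:
        # hypothetical extra branch for a future torchao_config == "autoquant"
        return torchao.autoquant(model)

    m = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda()
    m = apply_autoquant_(torch.compile(m))
    m(torch.randn(1, 64, device="cuda"))  # kernel choices are finalized on first input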