From 9cc733b38ceb4fc9df0daa6aed7335f2f8a4ba82 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 4 Dec 2024 17:26:42 -0800
Subject: [PATCH] move apply_torchao_config_ to model_runner (#2342)

---
 python/sglang/srt/layers/torchao_utils.py     | 58 ++++++-------------
 .../sglang/srt/model_executor/model_runner.py |  8 +++
 python/sglang/srt/models/grok.py              |  5 --
 python/sglang/srt/models/llama.py             |  5 --
 python/sglang/srt/models/mixtral.py           |  5 --
 python/sglang/srt/models/phi3_small.py        |  5 --
 python/sglang/srt/models/qwen2_moe.py         |  5 --
 .../sglang/srt/models/torch_native_llama.py   |  5 --
 8 files changed, 25 insertions(+), 71 deletions(-)

diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py
index 9395cdf271b..3f886221cca 100644
--- a/python/sglang/srt/layers/torchao_utils.py
+++ b/python/sglang/srt/layers/torchao_utils.py
@@ -7,13 +7,15 @@
 import torch
 
 
-def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
-    """Quantize a Tensor with torchao quantization specified by torchao_config
+def apply_torchao_config_to_model_(
+    model: torch.nn.Module, torchao_config: str, filter_fn=None
+):
+    """Quantize a model with torchao quantization specified by torchao_config
 
     Args:
-        `param`: weight parameter of the linear module
-        `torchao_config`: type of quantization and their arguments we want to use to
-        quantize the Tensor, e.g. int4wo-128 means int4 weight only quantization with group_size
+        `model`: a model to be quantized based on torchao_config
+        `torchao_config` (str): type of quantization and their arguments we want to use to
+        quantize the model, e.g. int4wo-128 means int4 weight only quantization with group_size
         128
     """
     # Lazy import to suppress some warnings
@@ -26,12 +28,12 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
     )
     from torchao.quantization.observer import PerRow, PerTensor
 
-    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
-    dummy_linear.weight = param
-    if "int8wo" in torchao_config:
-        quantize_(dummy_linear, int8_weight_only())
+    if torchao_config == "" or torchao_config is None:
+        return model
+    elif "int8wo" in torchao_config:
+        quantize_(model, int8_weight_only(), filter_fn=filter_fn)
     elif "int8dq" in torchao_config:
-        quantize_(dummy_linear, int8_dynamic_activation_int8_weight())
+        quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn)
     elif "int4wo" in torchao_config:
         group_size = int(torchao_config.split("-")[-1])
         assert group_size in [
@@ -40,13 +42,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
             128,
             256,
         ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
-        quantize_(dummy_linear, int4_weight_only(group_size=group_size))
+        quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
     elif "fp8wo" in torchao_config:
         from torchao.quantization import float8_weight_only
 
         # this requires newer hardware
         # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
-        quantize_(dummy_linear, float8_weight_only())
+        quantize_(model, float8_weight_only(), filter_fn=filter_fn)
     elif "fp8dq" in torchao_config:
         granularity = torchao_config.split("-")[-1]
         GRANULARITY_MAP = {
@@ -57,39 +59,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
             granularity in GRANULARITY_MAP
         ), f"Supported granularity are: {GRANULARITY_MAP.keys()}, got {granularity}"
         quantize_(
-            dummy_linear,
+            model,
             float8_dynamic_activation_float8_weight(
                 granularity=GRANULARITY_MAP[granularity]
             ),
+            filter_fn=filter_fn,
         )
     else:
         raise ValueError(f"Unexpected config: {torchao_config}")
-    return dummy_linear.weight
-
-
-def apply_torchao_config_(
-    self: torch.nn.Module,
-    params_dict: Dict[str, torch.Tensor],
-    param_suffixes: Set[str],
-) -> None:
-    """A util function used for quantizing the weight parameters after they are loaded if
-    self.torchao_config is specified
-
-    Args:
-        `self`: the model we want to quantize
-        `params_dict`: dictionary mapping from param_name to the parameter Tensor
-        `param_suffixes`: a set of suffixes, we'll quantize the Tensor matching these suffixes
-
-    Returns:
-        None, the `params_dict` is modified inplace and the weights of `self` model are quantized
-    """
-    if self.torchao_config:
-        for param_suffix in param_suffixes:
-            for name in params_dict:
-                param = params_dict[name]
-                if param_suffix in name and param.ndim == 2:
-                    params_dict[name] = torchao_quantize_param_data(
-                        param, self.torchao_config
-                    )
-        self.load_state_dict(params_dict, assign=True)
+    return model
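The torchao_utils.py change above replaces the old per-parameter round trip through a dummy nn.Linear with a single in-place pass over the whole module tree. A minimal sketch of calling the new entry point, assuming torchao is installed; the toy model below is illustrative only, not part of this patch:

    import torch

    from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model_

    toy = torch.nn.Sequential(
        torch.nn.Linear(128, 128, bias=False),
        torch.nn.Linear(128, 128, bias=False),
    )

    # int8 weight-only quantization, applied in place; with filter_fn=None,
    # torchao's quantize_ falls back to its default filter (linear modules).
    apply_torchao_config_to_model_(toy, "int8wo", filter_fn=None)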
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index fafb8783e5a..6f79afb70ba 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -38,6 +38,7 @@
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
+from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model_
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
@@ -159,6 +160,13 @@ def __init__(
         else:
             self.torch_tp_applied = False
 
+        def filter_fn(module, fqn):
+            return "proj" in fqn
+
+        apply_torchao_config_to_model_(
+            self.model, global_server_args_dict["torchao_config"], filter_fn
+        )
+
         # Init memory pool and attention backends
         if server_args.lora_paths is not None:
             self.init_lora_manager()
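With this hook, ModelRunner quantizes the fully constructed model once, right after weight loading, and the fqn-based filter replaces the old `param_suffixes={"proj.weight"}` matching. A small sketch of the filter semantics; the module names are typical transformer fqns used for illustration, not taken from this patch:

    # quantize_ walks every submodule and passes its dotted path as `fqn`.
    def filter_fn(module, fqn):
        return "proj" in fqn

    for fqn in [
        "model.layers.0.self_attn.qkv_proj",  # True  -> quantized
        "model.layers.0.mlp.down_proj",       # True  -> quantized
        "lm_head",                            # False -> left in high precision
    ]:
        print(fqn, filter_fn(None, fqn))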
diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py
index 956f73b1482..2b52e2b1bcc 100644
--- a/python/sglang/srt/models/grok.py
+++ b/python/sglang/srt/models/grok.py
@@ -35,12 +35,10 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -290,7 +288,6 @@ def __init__(
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -374,8 +371,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             )
             weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 class Grok1ModelForCausalLM(Grok1ForCausalLM):
     """An alias for backward-compatibility."""
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index 61409a9eaeb..e3e44ea6ffc 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -36,12 +36,10 @@
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
@@ -304,7 +302,6 @@ def __init__(
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = LlamaModel(config, quant_config=quant_config)
         # Llama 3.2 1B Instruct set tie_word_embeddings to True
         # Llama 3.1 8B Instruct set tie_word_embeddings to False
@@ -424,8 +421,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
     def get_weights_by_name(
         self, name: str, truncate_size: int = 100, tp_size: int = 1
     ) -> Optional[torch.Tensor]:
diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py
index e75dc1288b7..f1ae1f57a3d 100644
--- a/python/sglang/srt/models/mixtral.py
+++ b/python/sglang/srt/models/mixtral.py
@@ -34,12 +34,10 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -295,7 +293,6 @@ def __init__(
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -387,7 +384,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             )
             weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 EntryClass = MixtralForCausalLM
diff --git a/python/sglang/srt/models/phi3_small.py b/python/sglang/srt/models/phi3_small.py
index 6340330774d..1e70c7d7874 100644
--- a/python/sglang/srt/models/phi3_small.py
+++ b/python/sglang/srt/models/phi3_small.py
@@ -17,13 +17,11 @@
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
@@ -348,7 +346,6 @@ def __init__(
             quant_config=quant_config,
             prefix="model",
         )
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.vocab_size = config.vocab_size
         self.mup_width_multiplier = config.mup_width_multiplier
         self.lm_head = ParallelLMHead(
@@ -441,7 +438,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 EntryClass = Phi3SmallForCausalLM
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index 0094cb8c3e2..62cd3281d03 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -40,12 +40,10 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -352,7 +350,6 @@ def __init__(
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = Qwen2MoeModel(config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size, config.hidden_size, quant_config=quant_config
@@ -445,7 +442,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             )
             weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 EntryClass = Qwen2MoeForCausalLM
diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py
index 25e555484a7..7a55d50457a 100644
--- a/python/sglang/srt/models/torch_native_llama.py
+++ b/python/sglang/srt/models/torch_native_llama.py
@@ -58,12 +58,10 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -392,7 +390,6 @@ def __init__(
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.supports_torch_tp = True
         self.model = LlamaModel(config, quant_config=quant_config)
         if self.config.tie_word_embeddings:
@@ -503,8 +500,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
     pass
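For reference, the config strings accepted by apply_torchao_config_to_model_, as parsed by the branches in torchao_utils.py above; the value normally arrives from the server's torchao_config argument via global_server_args_dict:

    # ""  or None         -> no-op, the model is returned unchanged
    # "int8wo"            -> int8 weight-only
    # "int8dq"            -> int8 dynamic activation + int8 weight
    # "int4wo-<gs>"       -> int4 weight-only; <gs> must be 32, 64, 128, or 256
    # "fp8wo"             -> fp8 weight-only (requires CUDA arch >= 89)
    # "fp8dq-per_tensor"  -> fp8 dynamic activation + fp8 weight, per-tensor scales
    # "fp8dq-per_row"     -> fp8 dynamic activation + fp8 weight, per-row scales
    assert int("int4wo-128".split("-")[-1]) == 128
    assert "fp8dq-per_row".split("-")[-1] == "per_row"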