diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py
index 910309da973..1fdda4fad45 100644
--- a/python/sglang/srt/layers/torchao_utils.py
+++ b/python/sglang/srt/layers/torchao_utils.py
@@ -26,11 +26,12 @@ def apply_torchao_config_to_model(
         quantize_,
     )
     from torchao.quantization.observer import PerRow, PerTensor
+    from torchao.quantization.quant_api import _is_linear
 
     if filter_fn is None:
 
         def filter_fn(module, fqn):
-            return "proj" in fqn
+            return _is_linear(module) and "proj" in fqn
 
     if torchao_config == "" or torchao_config is None:
         return model
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index db024c5c7fb..a3f62f250ea 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -157,6 +157,10 @@ def __init__(
         self.sampler = Sampler()
         self.load_model()
 
+        apply_torchao_config_to_model(
+            self.model, global_server_args_dict["torchao_config"]
+        )
+
         # Apply torch TP if the model supports it
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
         if self.tp_size > 1 and supports_torch_tp:
@@ -165,10 +169,6 @@ def __init__(
         else:
             self.torch_tp_applied = False
 
-        apply_torchao_config_to_model(
-            self.model, global_server_args_dict["torchao_config"]
-        )
-
         # Init memory pool and attention backends
         if server_args.lora_paths is not None:
             self.init_lora_manager()