From d180fa5da9ed79924eb6ccb96f36a63278141bcc Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Wed, 4 Dec 2024 18:33:44 -0800
Subject: [PATCH] [Minor] Code style improvements

---
 python/sglang/srt/layers/torchao_utils.py     | 12 +++++++-----
 .../srt/model_executor/cuda_graph_runner.py   |  2 +-
 .../sglang/srt/model_executor/model_runner.py | 18 ++++++++----------
 3 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py
index 3f886221cca..910309da973 100644
--- a/python/sglang/srt/layers/torchao_utils.py
+++ b/python/sglang/srt/layers/torchao_utils.py
@@ -2,12 +2,10 @@
 Common utilities for torchao.
 """
 
-from typing import Dict, Set
-
 import torch
 
 
-def apply_torchao_config_to_model_(
+def apply_torchao_config_to_model(
     model: torch.nn.Module, torchao_config: str, filter_fn=None
 ):
     """Quantize a modelwith torchao quantization specified by torchao_config
@@ -21,6 +19,7 @@ def apply_torchao_config_to_model_(
     # Lazy import to suppress some warnings
     from torchao.quantization import (
         float8_dynamic_activation_float8_weight,
+        float8_weight_only,
         int4_weight_only,
         int8_dynamic_activation_int8_weight,
         int8_weight_only,
@@ -28,6 +27,11 @@ def apply_torchao_config_to_model_(
     )
     from torchao.quantization.observer import PerRow, PerTensor
 
+    if filter_fn is None:
+
+        def filter_fn(module, fqn):
+            return "proj" in fqn
+
     if torchao_config == "" or torchao_config is None:
         return model
     elif "int8wo" in torchao_config:
@@ -44,8 +48,6 @@ def apply_torchao_config_to_model_(
         ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
         quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
     elif "fp8wo" in torchao_config:
-        from torchao.quantization import float8_weight_only
-
         # this requires newer hardware
         # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
         quantize_(model, float8_weight_only(), filter_fn=filter_fn)
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index dd26a77ad65..3aac4965a5d 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -47,7 +47,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
         if "FusedMoE" in sub.__class__.__name__:
             if batch_size == 1:
                 # The performance of torch.compile on this layer is not always good when bs > 1,
-                # so we decide to skip it for now.
+                # so we decide to only use torch.compile when bs = 1
                 sub._forward_method = fused_moe_forward_native
             else:
                 sub._forward_method = sub.forward_native
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 6f79afb70ba..4eaedbccbff 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -27,7 +27,6 @@
     initialize_model_parallel,
     set_custom_all_reduce,
 )
-from vllm.distributed.parallel_state import in_the_same_node_as
 
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
@@ -38,7 +37,7 @@
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
-from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model_
+from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
@@ -112,11 +111,13 @@ def __init__(
         )
 
         if self.is_multimodal:
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
-            )
             server_args.chunked_prefill_size = -1
             self.mem_fraction_static *= 0.95
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static} "
+                f"and turn off chunked prefill "
+                f"because this is a multimodal model."
+            )
             # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
@@ -160,11 +161,8 @@ def __init__(
         else:
             self.torch_tp_applied = False
 
-        def filter_fn(module, fqn):
-            return "proj" in fqn
-
-        apply_torchao_config_to_model_(
-            self.model, global_server_args_dict["torchao_config"], filter_fn
+        apply_torchao_config_to_model(
+            self.model, global_server_args_dict["torchao_config"]
         )
 
         # Init memory pool and attention backends
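
Usage note: after this patch, callers of apply_torchao_config_to_model no longer need to define or pass a filter_fn; when it is omitted, the utility falls back to the default filter that matches modules whose fully qualified name contains "proj". The sketch below illustrates the new call path; it assumes torchao is installed, and the toy TinyBlock module and the "int8wo" config string are only illustrative, not part of this patch.

    import torch

    from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model


    class TinyBlock(torch.nn.Module):
        """Toy module whose layer name matches the default "proj" filter."""

        def __init__(self):
            super().__init__()
            self.up_proj = torch.nn.Linear(64, 64)

        def forward(self, x):
            return self.up_proj(x)


    # Half-precision weights, as is typical for LLM checkpoints (assumption).
    model = TinyBlock().to(dtype=torch.bfloat16)

    # No explicit filter_fn: the built-in default quantizes *_proj layers.
    apply_torchao_config_to_model(model, "int8wo")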