[Minor] Code style improvements #2355

Merged
1 commit merged on Dec 5, 2024
python/sglang/srt/layers/torchao_utils.py (12 changes: 7 additions & 5 deletions)

@@ -2,12 +2,10 @@
 Common utilities for torchao.
 """

-from typing import Dict, Set
-
 import torch


-def apply_torchao_config_to_model_(
+def apply_torchao_config_to_model(
     model: torch.nn.Module, torchao_config: str, filter_fn=None
 ):
     """Quantize a model with torchao quantization specified by torchao_config
@@ -21,13 +19,19 @@ def apply_torchao_config_to_model_(
     # Lazy import to suppress some warnings
     from torchao.quantization import (
         float8_dynamic_activation_float8_weight,
+        float8_weight_only,
         int4_weight_only,
         int8_dynamic_activation_int8_weight,
         int8_weight_only,
         quantize_,
     )
     from torchao.quantization.observer import PerRow, PerTensor

+    if filter_fn is None:
+
+        def filter_fn(module, fqn):
+            return "proj" in fqn
+
     if torchao_config == "" or torchao_config is None:
         return model
     elif "int8wo" in torchao_config:
@@ -44,8 +48,6 @@ def apply_torchao_config_to_model_(
         ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
         quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
     elif "fp8wo" in torchao_config:
-        from torchao.quantization import float8_weight_only
-
         # this requires newer hardware
         # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
         quantize_(model, float8_weight_only(), filter_fn=filter_fn)
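With the rename and the built-in default filter, callers that omit filter_fn keep the previous behavior: only submodules whose fully qualified name contains "proj" are quantized. A minimal sketch of the new call pattern, assuming torchao is installed; the ToyModel module and the "int8wo" config string are illustrative, not part of this PR:

```python
import torch

from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(256, 256)  # matched by the default "proj" filter
        self.norm = torch.nn.LayerNorm(256)      # left unquantized

    def forward(self, x):
        return self.norm(self.q_proj(x))


model = ToyModel()
# No explicit filter_fn: the function falls back to quantizing only the
# submodules whose fully qualified name contains "proj".
apply_torchao_config_to_model(model, "int8wo")
```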
python/sglang/srt/model_executor/cuda_graph_runner.py (2 changes: 1 addition & 1 deletion)

@@ -47,7 +47,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
                 if "FusedMoE" in sub.__class__.__name__:
                     if batch_size == 1:
                         # The performance of torch.compile on this layer is not always good when bs > 1,
-                        # so we decide to skip it for now.
+                        # so we decide to only use torch.compile when bs = 1.
                         sub._forward_method = fused_moe_forward_native
                     else:
                         sub._forward_method = sub.forward_native
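The updated comment states the actual policy: torch.compile is applied to this layer only when the batch size is 1. As a standalone illustration of batch-size-gated compilation (the maybe_compile helper below is hypothetical, not SGLang code):

```python
import torch


def maybe_compile(fn, batch_size: int):
    # Compile only for batch size 1, where torch.compile is consistently
    # beneficial for this layer; keep the eager implementation otherwise.
    return torch.compile(fn) if batch_size == 1 else fn


layer = torch.nn.Linear(16, 16)
forward = maybe_compile(layer.forward, batch_size=1)
out = forward(torch.randn(1, 16))
```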
python/sglang/srt/model_executor/model_runner.py (18 changes: 8 additions & 10 deletions)

@@ -27,7 +27,6 @@
     initialize_model_parallel,
     set_custom_all_reduce,
 )
-from vllm.distributed.parallel_state import in_the_same_node_as

 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
@@ -38,7 +37,7 @@
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
-from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model_
+from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
@@ -112,11 +111,13 @@ def __init__(
         )

         if self.is_multimodal:
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
-            )
             server_args.chunked_prefill_size = -1
             self.mem_fraction_static *= 0.95
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static} "
+                f"and turn off chunked prefill "
+                f"because this is a multimodal model."
+            )
             # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
@@ -160,11 +161,8 @@ def __init__(
         else:
             self.torch_tp_applied = False

-        def filter_fn(module, fqn):
-            return "proj" in fqn
-
-        apply_torchao_config_to_model_(
-            self.model, global_server_args_dict["torchao_config"], filter_fn
+        apply_torchao_config_to_model(
+            self.model, global_server_args_dict["torchao_config"]
         )

         # Init memory pool and attention backends
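Because apply_torchao_config_to_model now supplies the "proj" default, the ModelRunner call site above no longer defines its own filter_fn. The parameter remains available for callers that want a different module selection; a hedged sketch, assuming torchao is installed (ToyBlock and custom_filter are illustrative only):

```python
import torch

from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model


class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(128, 128)  # matched by the default "proj" filter
        self.mlp = torch.nn.Linear(128, 128)     # matched only by the custom filter

    def forward(self, x):
        return self.mlp(self.q_proj(x))


def custom_filter(module: torch.nn.Module, fqn: str) -> bool:
    # Quantize projection layers (the default behavior) plus MLP submodules.
    return "proj" in fqn or "mlp" in fqn


model = ToyBlock()
apply_torchao_config_to_model(model, "int8wo", filter_fn=custom_filter)
```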