Skip to content

Commit

Permalink
move apply_torchao_config_ to model_runner (#2342)
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryzh168 authored Dec 5, 2024
1 parent d693ec0 commit 9cc733b
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 71 deletions.
58 changes: 17 additions & 41 deletions python/sglang/srt/layers/torchao_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
import torch


def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
"""Quantize a Tensor with torchao quantization specified by torchao_config
def apply_torchao_config_to_model_(
model: torch.nn.Module, torchao_config: str, filter_fn=None
):
"""Quantize a modelwith torchao quantization specified by torchao_config
Args:
`param`: weight parameter of the linear module
`torchao_config`: type of quantization and their arguments we want to use to
quantize the Tensor, e.g. int4wo-128 means int4 weight only quantization with group_size
`model`: a model to be quantized based on torchao_config
`torchao_config` (str): type of quantization and their arguments we want to use to
quantize the model, e.g. int4wo-128 means int4 weight only quantization with group_size
128
"""
# Lazy import to suppress some warnings
Expand All @@ -26,12 +28,12 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
)
from torchao.quantization.observer import PerRow, PerTensor

dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
dummy_linear.weight = param
if "int8wo" in torchao_config:
quantize_(dummy_linear, int8_weight_only())
if torchao_config == "" or torchao_config is None:
return model
elif "int8wo" in torchao_config:
quantize_(model, int8_weight_only(), filter_fn=filter_fn)
elif "int8dq" in torchao_config:
quantize_(dummy_linear, int8_dynamic_activation_int8_weight())
quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn)
elif "int4wo" in torchao_config:
group_size = int(torchao_config.split("-")[-1])
assert group_size in [
Expand All @@ -40,13 +42,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
128,
256,
], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
quantize_(dummy_linear, int4_weight_only(group_size=group_size))
quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
elif "fp8wo" in torchao_config:
from torchao.quantization import float8_weight_only

# this requires newer hardware
# [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
quantize_(dummy_linear, float8_weight_only())
quantize_(model, float8_weight_only(), filter_fn=filter_fn)
elif "fp8dq" in torchao_config:
granularity = torchao_config.split("-")[-1]
GRANULARITY_MAP = {
Expand All @@ -57,39 +59,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
granularity in GRANULARITY_MAP
), f"Supported granularity are: {GRANULARITY_MAP.keys()}, got {granularity}"
quantize_(
dummy_linear,
model,
float8_dynamic_activation_float8_weight(
granularity=GRANULARITY_MAP[granularity]
),
filter_fn=filter_fn,
)
else:
raise ValueError(f"Unexpected config: {torchao_config}")

return dummy_linear.weight


def apply_torchao_config_(
self: torch.nn.Module,
params_dict: Dict[str, torch.Tensor],
param_suffixes: Set[str],
) -> None:
"""A util function used for quantizing the weight parameters after they are loaded if
self.torchao_config is specified
Args:
`self`: the model we want to quantize
`params_dict`: dictionary mapping from param_name to the parameter Tensor
`param_suffixes`: a set of suffixes, we'll quantize the Tensor matching these suffixes
Returns:
None, the `params_dict` is modified inplace and the weights of `self` model are quantized
"""
if self.torchao_config:
for param_suffix in param_suffixes:
for name in params_dict:
param = params_dict[name]
if param_suffix in name and param.ndim == 2:
params_dict[name] = torchao_quantize_param_data(
param, self.torchao_config
)
self.load_state_dict(params_dict, assign=True)
return model
8 changes: 8 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model_
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
Expand Down Expand Up @@ -159,6 +160,13 @@ def __init__(
else:
self.torch_tp_applied = False

def filter_fn(module, fqn):
return "proj" in fqn

apply_torchao_config_to_model_(
self.model, global_server_args_dict["torchao_config"], filter_fn
)

# Init memory pool and attention backends
if server_args.lora_paths is not None:
self.init_lora_manager()
Expand Down
5 changes: 0 additions & 5 deletions python/sglang/srt/models/grok.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,10 @@
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.loader import DefaultModelLoader
from sglang.srt.model_loader.weight_utils import default_weight_loader
Expand Down Expand Up @@ -290,7 +288,6 @@ def __init__(
super().__init__()
self.config = config
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.model = Grok1Model(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
Expand Down Expand Up @@ -374,8 +371,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
)
weight_loader(param, loaded_weight)

apply_torchao_config_(self, params_dict, set(["proj.weight"]))


class Grok1ModelForCausalLM(Grok1ForCausalLM):
"""An alias for backward-compatbility."""
Expand Down
5 changes: 0 additions & 5 deletions python/sglang/srt/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,10 @@
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import make_layers
Expand Down Expand Up @@ -304,7 +302,6 @@ def __init__(
super().__init__()
self.config = config
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.model = LlamaModel(config, quant_config=quant_config)
# Llama 3.2 1B Insturct set tie_word_embeddings to True
# Llama 3.1 8B Insturct set tie_word_embeddings to False
Expand Down Expand Up @@ -424,8 +421,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)

apply_torchao_config_(self, params_dict, set(["proj.weight"]))

def get_weights_by_name(
self, name: str, truncate_size: int = 100, tp_size: int = 1
) -> Optional[torch.Tensor]:
Expand Down
5 changes: 0 additions & 5 deletions python/sglang/srt/models/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,10 @@
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader

Expand Down Expand Up @@ -295,7 +293,6 @@ def __init__(
super().__init__()
self.config = config
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
Expand Down Expand Up @@ -387,7 +384,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
)
weight_loader(param, loaded_weight)

apply_torchao_config_(self, params_dict, set(["proj.weight"]))


EntryClass = MixtralForCausalLM
5 changes: 0 additions & 5 deletions python/sglang/srt/models/phi3_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import make_layers
Expand Down Expand Up @@ -348,7 +346,6 @@ def __init__(
quant_config=quant_config,
prefix="model",
)
self.torchao_config = global_server_args_dict["torchao_config"]
self.vocab_size = config.vocab_size
self.mup_width_multiplier = config.mup_width_multiplier
self.lm_head = ParallelLMHead(
Expand Down Expand Up @@ -441,7 +438,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)

apply_torchao_config_(self, params_dict, set(["proj.weight"]))


EntryClass = Phi3SmallForCausalLM
5 changes: 0 additions & 5 deletions python/sglang/srt/models/qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,10 @@
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader

Expand Down Expand Up @@ -352,7 +350,6 @@ def __init__(
super().__init__()
self.config = config
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.model = Qwen2MoeModel(config, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
Expand Down Expand Up @@ -445,7 +442,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
)
weight_loader(param, loaded_weight)

apply_torchao_config_(self, params_dict, set(["proj.weight"]))


EntryClass = Qwen2MoeForCausalLM
5 changes: 0 additions & 5 deletions python/sglang/srt/models/torch_native_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,10 @@
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader

Expand Down Expand Up @@ -392,7 +390,6 @@ def __init__(
super().__init__()
self.config = config
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.supports_torch_tp = True
self.model = LlamaModel(config, quant_config=quant_config)
if self.config.tie_word_embeddings:
Expand Down Expand Up @@ -503,8 +500,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)

apply_torchao_config_(self, params_dict, set(["proj.weight"]))


class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
pass
Expand Down

0 comments on commit 9cc733b

Please sign in to comment.