move apply_torchao_config_ to model_runner #2342

Merged · 6 commits · Dec 5, 2024
58 changes: 17 additions & 41 deletions python/sglang/srt/layers/torchao_utils.py
@@ -7,13 +7,15 @@
import torch


def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
"""Quantize a Tensor with torchao quantization specified by torchao_config
def apply_torchao_config_to_model_(
model: torch.nn.Module, torchao_config: str, filter_fn=None
):
"""Quantize a modelwith torchao quantization specified by torchao_config
Args:
`param`: weight parameter of the linear module
`torchao_config`: type of quantization and their arguments we want to use to
quantize the Tensor, e.g. int4wo-128 means int4 weight only quantization with group_size
`model`: a model to be quantized based on torchao_config
`torchao_config` (str): type of quantization and their arguments we want to use to
quantize the model, e.g. int4wo-128 means int4 weight only quantization with group_size
128
"""
# Lazy import to suppress some warnings
@@ -26,12 +28,12 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
)
from torchao.quantization.observer import PerRow, PerTensor

dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
dummy_linear.weight = param
if "int8wo" in torchao_config:
quantize_(dummy_linear, int8_weight_only())
if torchao_config == "" or torchao_config is None:
return model
elif "int8wo" in torchao_config:
quantize_(model, int8_weight_only(), filter_fn=filter_fn)
elif "int8dq" in torchao_config:
quantize_(dummy_linear, int8_dynamic_activation_int8_weight())
quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn)
elif "int4wo" in torchao_config:
group_size = int(torchao_config.split("-")[-1])
assert group_size in [
@@ -40,13 +42,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
128,
256,
], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
quantize_(dummy_linear, int4_weight_only(group_size=group_size))
quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
elif "fp8wo" in torchao_config:
from torchao.quantization import float8_weight_only

# this requires newer hardware
# [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
quantize_(dummy_linear, float8_weight_only())
quantize_(model, float8_weight_only(), filter_fn=filter_fn)
elif "fp8dq" in torchao_config:
granularity = torchao_config.split("-")[-1]
GRANULARITY_MAP = {
@@ -57,39 +59,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
granularity in GRANULARITY_MAP
), f"Supported granularity are: {GRANULARITY_MAP.keys()}, got {granularity}"
quantize_(
dummy_linear,
model,
float8_dynamic_activation_float8_weight(
granularity=GRANULARITY_MAP[granularity]
),
filter_fn=filter_fn,
)
else:
raise ValueError(f"Unexpected config: {torchao_config}")

return dummy_linear.weight


def apply_torchao_config_(
self: torch.nn.Module,
params_dict: Dict[str, torch.Tensor],
param_suffixes: Set[str],
) -> None:
"""A util function used for quantizing the weight parameters after they are loaded if
self.torchao_config is specified
Args:
`self`: the model we want to quantize
`params_dict`: dictionary mapping from param_name to the parameter Tensor
`param_suffixes`: a set of suffixes, we'll quantize the Tensor matching these suffixes
Returns:
None, the `params_dict` is modified inplace and the weights of `self` model are quantized
"""
if self.torchao_config:
for param_suffix in param_suffixes:
for name in params_dict:
param = params_dict[name]
if param_suffix in name and param.ndim == 2:
params_dict[name] = torchao_quantize_param_data(
param, self.torchao_config
)
self.load_state_dict(params_dict, assign=True)
return model
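
For reference, a minimal usage sketch of the new entry point follows. It assumes torchao is installed; the toy module and the `int8wo` config are illustrative choices, not part of this PR.

```python
# Minimal sketch, assuming torchao is installed. ToyModel and the "int8wo"
# config are illustrative; only the function and filter shape come from the PR.
import torch

from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model_


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(256, 256, bias=False)  # name contains "proj"
        self.gate = torch.nn.Linear(256, 256, bias=False)    # name does not

    def forward(self, x):
        return self.gate(self.q_proj(x))


model = ToyModel().eval()

# Quantizes the model in place; only submodules whose fully-qualified name
# contains "proj" are touched, mirroring the filter ModelRunner now passes.
apply_torchao_config_to_model_(
    model, "int8wo", filter_fn=lambda module, fqn: "proj" in fqn
)
```

Note that an empty or `None` `torchao_config` is now an explicit no-op, so callers no longer need to guard the call themselves.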
8 changes: 8 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -38,6 +38,7 @@
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model_
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
@@ -159,6 +160,13 @@ def __init__(
else:
self.torch_tp_applied = False

def filter_fn(module, fqn):
return "proj" in fqn

apply_torchao_config_to_model_(
self.model, global_server_args_dict["torchao_config"], filter_fn
)

# Init memory pool and attention backends
if server_args.lora_paths is not None:
self.init_lora_manager()
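
The net effect of the model_runner change is that quantization now happens once, centrally, right after the model is constructed, instead of at the end of each model's `load_weights`. A small illustration of which submodules the "proj" filter selects (the fully qualified names below are typical Llama-style examples, assumed rather than taken from the diff):

```python
# Illustration only: the example FQNs are assumed Llama-style module names.
def filter_fn(module, fqn):
    return "proj" in fqn


example_fqns = [
    "model.layers.0.self_attn.qkv_proj",  # selected -> quantized
    "model.layers.0.mlp.down_proj",       # selected -> quantized
    "model.layers.0.input_layernorm",     # skipped
    "lm_head",                            # skipped
]
for fqn in example_fqns:
    print(fqn, "->", "quantize" if filter_fn(None, fqn) else "skip")
```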
5 changes: 0 additions & 5 deletions python/sglang/srt/models/grok.py
@@ -35,12 +35,10 @@
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.loader import DefaultModelLoader
from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -290,7 +288,6 @@ def __init__(
super().__init__()
self.config = config
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.model = Grok1Model(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
@@ -374,8 +371,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
)
weight_loader(param, loaded_weight)

apply_torchao_config_(self, params_dict, set(["proj.weight"]))


class Grok1ModelForCausalLM(Grok1ForCausalLM):
"""An alias for backward-compatbility."""
5 changes: 0 additions & 5 deletions python/sglang/srt/models/llama.py
@@ -36,12 +36,10 @@
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import make_layers
@@ -304,7 +302,6 @@ def __init__(
super().__init__()
self.config = config
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.model = LlamaModel(config, quant_config=quant_config)
# Llama 3.2 1B Insturct set tie_word_embeddings to True
# Llama 3.1 8B Insturct set tie_word_embeddings to False
@@ -424,8 +421,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)

apply_torchao_config_(self, params_dict, set(["proj.weight"]))

def get_weights_by_name(
self, name: str, truncate_size: int = 100, tp_size: int = 1
) -> Optional[torch.Tensor]:
5 changes: 0 additions & 5 deletions python/sglang/srt/models/mixtral.py
@@ -34,12 +34,10 @@
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader

@@ -295,7 +293,6 @@ def __init__(
super().__init__()
self.config = config
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
@@ -387,7 +384,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
)
weight_loader(param, loaded_weight)

apply_torchao_config_(self, params_dict, set(["proj.weight"]))


EntryClass = MixtralForCausalLM
5 changes: 0 additions & 5 deletions python/sglang/srt/models/phi3_small.py
@@ -17,13 +17,11 @@
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import make_layers
@@ -348,7 +346,6 @@ def __init__(
quant_config=quant_config,
prefix="model",
)
self.torchao_config = global_server_args_dict["torchao_config"]
self.vocab_size = config.vocab_size
self.mup_width_multiplier = config.mup_width_multiplier
self.lm_head = ParallelLMHead(
@@ -441,7 +438,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)

apply_torchao_config_(self, params_dict, set(["proj.weight"]))


EntryClass = Phi3SmallForCausalLM
5 changes: 0 additions & 5 deletions python/sglang/srt/models/qwen2_moe.py
@@ -40,12 +40,10 @@
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader

@@ -352,7 +350,6 @@ def __init__(
super().__init__()
self.config = config
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.model = Qwen2MoeModel(config, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
@@ -445,7 +442,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
)
weight_loader(param, loaded_weight)

apply_torchao_config_(self, params_dict, set(["proj.weight"]))


EntryClass = Qwen2MoeForCausalLM
5 changes: 0 additions & 5 deletions python/sglang/srt/models/torch_native_llama.py
@@ -58,12 +58,10 @@
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader

@@ -392,7 +390,6 @@ def __init__(
super().__init__()
self.config = config
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.supports_torch_tp = True
self.model = LlamaModel(config, quant_config=quant_config)
if self.config.tie_word_embeddings:
@@ -503,8 +500,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)

apply_torchao_config_(self, params_dict, set(["proj.weight"]))


class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
pass