From 959735fc9e38d6507651ba9196aa205430687b05 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 11 Dec 2024 05:21:23 -0800 Subject: [PATCH] Fix model loader for more quantization formats (#2448) --- python/sglang/srt/models/llama.py | 22 ++++++++++++++++++++++ python/sglang/srt/models/qwen2.py | 20 ++++++++++++++++++++ python/sglang/srt/server_args.py | 14 ++++++++++++-- 3 files changed, 54 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index e3e44ea6ffc..71b4ed1b744 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -294,6 +294,28 @@ def forward( class LlamaForCausalLM(nn.Module): + + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ] + # in TP, these weights are partitioned along the column dimension (dim=-1) + column_parallel_weights_modules = [".down_proj.", ".o_proj."] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + def __init__( self, config: LlamaConfig, diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 9383fde4d09..2a20d6c50de 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -267,6 +267,26 @@ def forward( class Qwen2ForCausalLM(nn.Module): + + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + def __init__( self, config: Qwen2Config, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index fe12d961d3a..902f24ebb7f 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -283,7 +283,15 @@ def add_cli_args(parser: argparse.ArgumentParser): "--load-format", type=str, default=ServerArgs.load_format, - choices=["auto", "pt", "safetensors", "npcache", "dummy", "gguf"], + choices=[ + "auto", + "pt", + "safetensors", + "npcache", + "dummy", + "gguf", + "bitsandbytes", + ], help="The format of the model weights to load. " '"auto" will try to load the weights in the safetensors format ' "and fall back to the pytorch bin format if safetensors format " @@ -294,7 +302,9 @@ def add_cli_args(parser: argparse.ArgumentParser): "a numpy cache to speed up the loading. " '"dummy" will initialize the weights with random values, ' "which is mainly for profiling." - '"gguf" will load the weights in the gguf format. ', + '"gguf" will load the weights in the gguf format. ' + '"bitsandbytes" will load the weights using bitsandbytes ' + "quantization.", ) parser.add_argument( "--trust-remote-code",