Add get weights by parameter name for llama (sgl-project#2266)
zhaochenyang20 authored Nov 30, 2024
1 parent 7d5d1d3 commit 7d1485d
Showing 12 changed files with 337 additions and 17 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/pr-test.yml
@@ -128,6 +128,7 @@ jobs:
           python3 test_mla_fp8.py
           python3 test_dp_attention.py

   performance-test-1-gpu-part-1:
+    if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
     runs-on: 1-gpu-runner
@@ -242,6 +243,7 @@ jobs:
           cd test/srt
           python3 test_eval_accuracy_large.py

   accuracy-test-2-gpu:
+    if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
     runs-on: 2-gpu-runner
2 changes: 1 addition & 1 deletion 3rdparty/amd/profiling/PROFILING.md
@@ -421,5 +421,5 @@ index 62d1ff9..6ecd78c 100644
 3. Modify the included server.sh by removing "loadTracer.sh" before python command and launch script ./server.sh in one terminal inside the docker container.
 4. Similar to step 6 in RPD profiling section, but remove the last 2 lines in client.sh, which converted rpd file into csv and json files. Run modified client.sh for PyTorch profiling.
-=======
+-------
 - [Torch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
11 changes: 11 additions & 0 deletions python/sglang/srt/managers/io_struct.py
@@ -365,6 +365,17 @@ class UpdateWeightFromDiskReqOutput:
     message: str


+@dataclass
+class GetWeightsByNameReqInput:
+    name: str
+    truncate_size: int = 100
+
+
+@dataclass
+class GetWeightsByNameReqOutput:
+    parameter: list
+
+
 @dataclass
 class AbortReq:
     # The request id
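These two dataclasses form a matched request/response pair. As a rough sketch of the round trip (the pickle calls stand in for ZMQ's send_pyobj/recv_pyobj, which serialize objects the same way; the parameter values are made up):

import pickle
from dataclasses import dataclass


@dataclass
class GetWeightsByNameReqInput:
    name: str
    truncate_size: int = 100


@dataclass
class GetWeightsByNameReqOutput:
    parameter: list


# Tokenizer manager -> scheduler: ask for a named parameter, truncated to 4 rows.
req = GetWeightsByNameReqInput("model.embed_tokens.weight", truncate_size=4)
wire = pickle.dumps(req)  # roughly what send_pyobj does under the hood

# Scheduler -> tokenizer manager: reply with the truncated values as plain lists.
received = pickle.loads(wire)
resp = GetWeightsByNameReqOutput(parameter=[[0.01, -0.02], [0.03, 0.04]])
assert received.name == req.name and len(resp.parameter) <= received.truncate_size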
9 changes: 9 additions & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -38,6 +38,8 @@
     BatchTokenIDOut,
     CloseSessionReqInput,
     FlushCacheReq,
+    GetWeightsByNameReqInput,
+    GetWeightsByNameReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -511,6 +513,9 @@ def process_input_requests(self, recv_reqs: List):
                 self.send_to_tokenizer.send_pyobj(
                     UpdateWeightFromDiskReqOutput(success, message)
                 )
+            elif isinstance(recv_req, GetWeightsByNameReqInput):
+                parameter = self.get_weights_by_name(recv_req)
+                self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
             elif isinstance(recv_req, ProfileReq):
                 if recv_req == ProfileReq.START_PROFILE:
                     self.start_profile()
@@ -1373,6 +1378,10 @@ def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
             logger.error(message)
         return success, message

+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        parameter = self.tp_worker.get_weights_by_name(recv_req)
+        return parameter
+
     def start_profile(self) -> None:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
30 changes: 30 additions & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -46,6 +46,8 @@
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
+    GetWeightsByNameReqInput,
+    GetWeightsByNameReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -454,6 +456,23 @@ async def update_weights_from_disk(
         else:
             return False, "Another update is in progress. Please try again later."

+    async def get_weights_by_name(
+        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
+    ):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        self.send_to_scheduler.send_pyobj(obj)
+        self.get_weights_by_name_result = asyncio.Future()
+        if self.server_args.dp_size == 1:
+            result = await self.get_weights_by_name_result
+            return result.parameter
+        else:
+            self.get_weights_by_name_tmp = []
+            result = await self.get_weights_by_name_result
+            all_parameters = [r.parameter for r in result]
+            return all_parameters
+
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -527,6 +546,7 @@ async def handle_loop(self):
                 BatchEmbeddingOut,
                 BatchTokenIDOut,
                 UpdateWeightFromDiskReqOutput,
+                GetWeightsByNameReqOutput,
             ] = await self.recv_from_detokenizer.recv_pyobj()

             if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
@@ -538,6 +558,16 @@ async def handle_loop(self):
                     if len(self.model_update_tmp) == self.server_args.dp_size:
                         self.model_update_result.set_result(self.model_update_tmp)
                 continue
+            elif isinstance(recv_obj, GetWeightsByNameReqOutput):
+                if self.server_args.dp_size == 1:
+                    self.get_weights_by_name_result.set_result(recv_obj)
+                else:
+                    self.get_weights_by_name_tmp.append(recv_obj)
+                    if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
+                        self.get_weights_by_name_result.set_result(
+                            self.get_weights_by_name_tmp
+                        )
+                continue
             elif isinstance(recv_obj, OpenSessionReqOutput):
                 self.session_futures[recv_obj.session_id].set_result(
                     recv_obj.session_id
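The dp_size branching above implements a small fan-in: each query gets one asyncio.Future, and with multiple data-parallel replicas the partial replies are buffered in a temporary list until every replica has answered. A standalone sketch of that pattern, with the ZMQ hop replaced by inline simulated replies:

import asyncio


async def fan_in(dp_size: int):
    result_future = asyncio.get_running_loop().create_future()
    tmp = []

    def on_reply(reply):
        # Mirrors handle_loop: resolve immediately for dp_size == 1,
        # otherwise buffer until all replicas have reported.
        if dp_size == 1:
            result_future.set_result(reply)
        else:
            tmp.append(reply)
            if len(tmp) == dp_size:
                result_future.set_result(tmp)

    for rank in range(dp_size):  # simulate one reply per replica
        on_reply({"rank": rank, "parameter": [0.0]})

    return await result_future


print(asyncio.run(fan_in(2)))  # -> one buffered entry per replica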
11 changes: 10 additions & 1 deletion python/sglang/srt/managers/tp_worker.py
@@ -19,7 +19,10 @@

 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
+from sglang.srt.managers.io_struct import (
+    GetWeightsByNameReqInput,
+    UpdateWeightFromDiskReqInput,
+)
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
@@ -160,3 +163,9 @@ def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
             recv_req.model_path, recv_req.load_format
         )
         return success, message
+
+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        parameter = self.model_runner.get_weights_by_name(
+            recv_req.name, recv_req.truncate_size
+        )
+        return parameter
8 changes: 7 additions & 1 deletion python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -23,7 +23,10 @@
 import psutil
 import torch

-from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
+from sglang.srt.managers.io_struct import (
+    GetWeightsByNameReqInput,
+    UpdateWeightFromDiskReqInput,
+)
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
@@ -208,6 +211,9 @@ def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         success, message = self.worker.update_weights_from_disk(recv_req)
         return success, message

+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        return self.worker.get_weights_by_name(recv_req)
+
     def __delete__(self):
         self.input_queue.put((None, None))
         self.copy_queue.put((None, None, None))
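Both worker classes simply forward the request inward, so a weight query descends through four layers before touching actual tensors. A toy sketch of that delegation chain, using the real method names but hypothetical Stub* classes with trivial bodies:

from dataclasses import dataclass


@dataclass
class GetWeightsByNameReqInput:
    name: str
    truncate_size: int = 100


class StubModel:  # stands in for LlamaForCausalLM
    def get_weights_by_name(self, name, truncate_size, tp_size=1):
        return [0.0] * truncate_size


class StubModelRunner:
    def __init__(self):
        self.model, self.tp_size = StubModel(), 1

    def get_weights_by_name(self, name, truncate_size=100):
        return self.model.get_weights_by_name(name, truncate_size, tp_size=self.tp_size)


class StubTpWorker:  # TpModelWorker in the diff
    def __init__(self):
        self.model_runner = StubModelRunner()

    def get_weights_by_name(self, recv_req):
        return self.model_runner.get_weights_by_name(recv_req.name, recv_req.truncate_size)


class StubTpWorkerClient:  # the overlap-thread wrapper in the diff
    def __init__(self):
        self.worker = StubTpWorker()

    def get_weights_by_name(self, recv_req):
        return self.worker.get_weights_by_name(recv_req)


req = GetWeightsByNameReqInput("model.layers.0.mlp.down_proj.weight", truncate_size=3)
assert StubTpWorkerClient().get_weights_by_name(req) == [0.0, 0.0, 0.0]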
22 changes: 18 additions & 4 deletions python/sglang/srt/model_executor/model_runner.py
@@ -20,13 +20,10 @@
 import json
 import logging
 import pkgutil
-import time
-from functools import lru_cache
-from tokenize import tabsize
-from typing import Any, Optional, Type, Union
+from typing import Optional, Type

 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
@@ -403,6 +400,23 @@ def model_load_weights(model, iter):
         logger.info("Update weights end.")
         return True, "Succeeded to update model weights."

+    def get_weights_by_name(
+        self, name: str, truncate_size: int = 100
+    ) -> Optional[list]:
+        """Get a parameter's weights by name, similar to `get_parameter` in
+        Hugging Face. Only used in unit tests; performance is unoptimized.
+        For optimized performance, please use torch.save and torch.load.
+        """
+        # TODO: (chenyang) Add support for Qwen models.
+        try:
+            return self.model.get_weights_by_name(
+                name, truncate_size, tp_size=self.tp_size
+            )
+        except Exception as e:
+            logger.error(f"Error when getting parameter {name}: {e}")
+            return None
+
     def init_lora_manager(self):
         self.lora_manager = LoRAManager(
             base_model=self.model,
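As the docstring notes, this path exists for unit tests that cross-check served weights against a reference checkpoint. A hypothetical test along those lines; `runner` is assumed to be an already-constructed ModelRunner, and the model name and tolerance are illustrative:

import torch
from transformers import AutoModelForCausalLM

# Reference weights straight from Hugging Face.
hf_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

name = "model.embed_tokens.weight"
truncate_size = 16

expected = hf_model.get_parameter(name)[:truncate_size].to(torch.float32)
actual = runner.get_weights_by_name(name, truncate_size)  # list of lists, or None on error

assert actual is not None
assert torch.allclose(torch.tensor(actual), expected, atol=1e-5)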
94 changes: 84 additions & 10 deletions python/sglang/srt/models/llama.py
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""

+import logging
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
@@ -45,6 +46,8 @@
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import make_layers

+logger = logging.getLogger(__name__)
+

 class LlamaMLP(nn.Module):
     def __init__(
@@ -305,6 +308,14 @@ def __init__(
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]

     @torch.no_grad()
     def forward(
@@ -349,15 +360,7 @@ def get_module_name(self, name):
         return params_mapping.get(name, name)

     def get_module_name_from_weight_name(self, name):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id, num_shard)
-            ("qkv_proj", "q_proj", "q", 3),
-            ("qkv_proj", "k_proj", "k", 3),
-            ("qkv_proj", "v_proj", "v", 3),
-            ("gate_up_proj", "gate_proj", 0, 2),
-            ("gate_up_proj", "up_proj", 1, 2),
-        ]
-        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+        for param_name, weight_name, shard_id in self.stacked_params_mapping:
+            num_shard = 3 if shard_id in ("q", "k", "v") else 2
             if weight_name in name:
                 return (
                     name.replace(weight_name, param_name)[: -len(".weight")],
@@ -370,6 +373,7 @@ def get_num_params(self):
         return len(params_dict)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        embed_tokens_weight = None
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -378,6 +382,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
+
         params_dict = dict(self.named_parameters())

         load_tie_word_embeddings = (
@@ -425,10 +430,79 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
             param = self.lm_head.weight
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
-            weight_loader(param, embed_tokens_weight)
+            if embed_tokens_weight is not None:
+                weight_loader(param, embed_tokens_weight)

         apply_torchao_config_(self, params_dict, set(["proj.weight"]))

+    def get_weights_by_name(
+        self, name: str, truncate_size: int = 100, tp_size: int = 1
+    ) -> Optional[list]:
+        """Get a parameter's weights by name, similar to `get_parameter` in
+        Hugging Face. Only used in unit tests; performance is unoptimized.
+        For optimized performance, please use torch.save and torch.load.
+        """
+        try:
+            mapped_name = name
+            mapped_shard_id = None
+            for param_name, weight_name, shard_id in self.stacked_params_mapping:
+                if weight_name in name:
+                    mapped_name = name.replace(weight_name, param_name)
+                    mapped_shard_id = shard_id
+                    break
+            params_dict = dict(self.named_parameters())
+            if mapped_name in params_dict:
+                param = params_dict[mapped_name]
+                if mapped_shard_id is not None:
+                    if mapped_shard_id in ["q", "k", "v"]:
+                        num_heads = self.config.num_attention_heads // tp_size
+                        num_kv_heads = self.config.num_key_value_heads // tp_size
+                        head_dim = (
+                            self.config.hidden_size // self.config.num_attention_heads
+                        )
+                        if mapped_shard_id == "q":
+                            offset = 0
+                            size = num_heads * head_dim
+                        elif mapped_shard_id == "k":
+                            offset = num_heads * head_dim
+                            size = num_kv_heads * head_dim
+                        elif mapped_shard_id == "v":
+                            offset = (num_heads + num_kv_heads) * head_dim
+                            size = num_kv_heads * head_dim
+                        weight = param.data.narrow(0, offset, size)
+                    elif mapped_shard_id in [0, 1]:
+                        intermediate_size = self.config.intermediate_size
+                        slice_size = intermediate_size // tp_size
+                        if mapped_shard_id == 0:  # gate_proj
+                            offset = 0
+                            size = slice_size
+                        elif mapped_shard_id == 1:  # up_proj
+                            offset = slice_size
+                            size = slice_size
+                        weight = param.data.narrow(0, offset, size)
+                    else:
+                        weight = param.data
+                else:
+                    weight = param.data
+                if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
+                    gathered_weights = [
+                        torch.zeros_like(weight) for _ in range(tp_size)
+                    ]
+                    torch.distributed.all_gather(gathered_weights, weight)
+                    weight = torch.cat(gathered_weights, dim=1)
+                return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
+            else:
+                return None
+        except Exception as e:
+            logger.error(
+                f"Error getting weights by name {name} in LlamaForCausalLM: {e}"
+            )
+            return None

 class Phi3ForCausalLM(LlamaForCausalLM):
     pass
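The slicing above is easiest to follow with concrete numbers. The fused qkv_proj weight stacks the per-rank Q, K, and V rows along dim 0, so recovering one logical weight is a narrow() at the right row offset; row-parallel layers (o_proj, down_proj) are sharded along dim 1 instead, which is why the method all-gathers those shards and concatenates on dim=1. A small self-check assuming a Llama-3-8B-like shape (hidden size 4096, 32 query heads, 8 KV heads, intermediate size 14336) with tp_size=2:

# Per-rank shapes under tensor parallelism of 2.
tp_size = 2
num_heads = 32 // tp_size      # 16 query heads per rank
num_kv_heads = 8 // tp_size    # 4 KV heads per rank
head_dim = 4096 // 32          # 128

# Row layout of the fused qkv_proj weight on each rank: [Q | K | V] along dim 0.
q_offset, q_size = 0, num_heads * head_dim
k_offset, k_size = q_offset + q_size, num_kv_heads * head_dim
v_offset, v_size = k_offset + k_size, num_kv_heads * head_dim

# These are exactly the narrow(0, offset, size) calls the method issues:
assert (q_offset, q_size) == (0, 2048)    # q_proj rows
assert (k_offset, k_size) == (2048, 512)  # k_proj rows
assert (v_offset, v_size) == (2560, 512)  # v_proj rows

# gate_up_proj stacks [gate | up], each intermediate_size // tp_size rows.
slice_size = 14336 // tp_size
assert (0, slice_size) == (0, 7168)              # gate_proj rows
assert (slice_size, slice_size) == (7168, 7168)  # up_proj rows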
