Commit

add hybrid kv
Ying1123 committed Sep 16, 2024
1 parent ab4a83b commit 71c8afe
Showing 16 changed files with 2,322 additions and 32 deletions.
16 changes: 12 additions & 4 deletions python/sglang/bench_latency.py
@@ -115,7 +115,7 @@ def from_cli_args(cls, args: argparse.Namespace):
         )


-def load_model(server_args, tp_rank):
+def load_model(server_args, tp_rank, sp_rank: int = 0):
     suppress_other_loggers()
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

@@ -130,6 +130,8 @@ def load_model(server_args, tp_rank):
         gpu_id=tp_rank,
         tp_rank=tp_rank,
         tp_size=server_args.tp_size,
+        sp_rank=sp_rank,
+        sp_size=server_args.sp_size,
         nccl_port=28888,
         server_args=server_args,
     )
@@ -206,6 +208,8 @@ def extend(reqs, model_runner):
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool=model_runner.token_to_kv_pool,
         tree_cache=None,
+        sp_size=model_runner.sp_size,
+        sp_rank=model_runner.sp_rank,
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size)
     sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND)
@@ -225,11 +229,12 @@ def correctness_test(
     server_args,
     bench_args,
     tp_rank,
+    sp_rank=0,
 ):
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

     # Load the model
-    model_runner, tokenizer = load_model(server_args, tp_rank)
+    model_runner, tokenizer = load_model(server_args, tp_rank, sp_rank)

     # Prepare inputs
     input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
@@ -336,11 +341,12 @@ def latency_test(
     server_args,
     bench_args,
     tp_rank,
+    sp_rank=0,
 ):
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

     # Load the model
-    model_runner, tokenizer = load_model(server_args, tp_rank)
+    model_runner, tokenizer = load_model(server_args, tp_rank, sp_rank)

     # Prepare inputs for warm up
     reqs = prepare_synthetic_inputs_for_latency_test(
@@ -458,16 +464,18 @@ def main(server_args, bench_args):
         )

     if server_args.tp_size == 1:
-        work_func(server_args, bench_args, 0)
+        work_func(server_args, bench_args, 0, 0)
     else:
         workers = []
         for tp_rank in range(server_args.tp_size):
+            sp_rank = tp_rank % server_args.sp_size
             proc = multiprocessing.Process(
                 target=work_func,
                 args=(
                     server_args,
                     bench_args,
                     tp_rank,
+                    sp_rank,
                 ),
             )
             proc.start()
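As an aside (not part of the diff), here is a minimal sketch of the rank mapping produced by the new sp_rank = tp_rank % server_args.sp_size line, assuming the TP size 8 / SP size 2 example used in parallel_state.py below; tp_size and sp_size here are hypothetical stand-ins for the values carried by server_args:

# Illustrative values only; in bench_latency.py these come from server_args.
tp_size, sp_size = 8, 2
mapping = [(tp_rank, tp_rank % sp_size) for tp_rank in range(tp_size)]
print(mapping)
# [(0, 0), (1, 1), (2, 0), (3, 1), (4, 0), (5, 1), (6, 0), (7, 1)]
# Adjacent TP ranks receive SP ranks 0 and 1, so each pair of adjacent GPUs forms
# one SP group, matching the grouping built in parallel_state.py below.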
1 change: 1 addition & 0 deletions python/sglang/srt/layers/parallel_utils/__init__.py
@@ -0,0 +1 @@
+from .parallel_state import *
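The wildcard import re-exports the helpers defined in parallel_state.py (shown next) at the package level. A minimal usage sketch, not part of the diff; the import site is hypothetical:

# Hypothetical consumer code; the names below are defined in parallel_state.py
# and re-exported by the wildcard import above.
from sglang.srt.layers.parallel_utils import (
    get_kv_tensor_model_parallel_rank,
    get_sequence_parallel_rank,
    get_sequence_parallel_world_size,
)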
96 changes: 96 additions & 0 deletions python/sglang/srt/layers/parallel_utils/parallel_state.py
@@ -0,0 +1,96 @@
+from typing import List, Optional
+
+import torch
+from vllm.distributed import initialize_model_parallel as vllm_initialize_model_parallel
+from vllm.distributed.parallel_state import (
+    GroupCoordinator,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    get_world_group,
+    init_model_parallel_group,
+)
+
+_SP: Optional[GroupCoordinator] = None
+
+
+def get_sp_group():
+    assert _SP is not None, "sequence parallel group is not initialized"
+    return _SP
+
+
+def init_sequence_parallel_group(
+    group_ranks: List[List[int]], local_rank: int, backend: str
+) -> GroupCoordinator:
+    return GroupCoordinator(
+        group_ranks=group_ranks,
+        local_rank=local_rank,
+        torch_distributed_backend=backend,
+        use_pynccl=True,
+    )
+
+
+def initialize_model_parallel(
+    tensor_model_parallel_size: int = 1,
+    pipeline_model_parallel_size: int = 1,
+    sequence_parallel_size: int = 1,
+    backend: Optional[str] = None,
+) -> None:
+    """
+    Initialize model parallel groups and sequence parallel groups.
+    For sequence parallelism, we partition each TP group into SP groups and assign
+    GPUs with adjacent ranks to the same SP group. For example, with TP size 8
+    and SP size 2, we have 1 TP group and 4 SP groups:
+        SP groups:
+            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
+        Their KV TP ranks:
+            [ 0,  0], [ 1,  1], [ 2,  2], [ 3,  3]
+    Because we replicate KV heads within an SP group, the KV TP size is 4 (8 // 2),
+    and the two GPUs in each SP group share one KV TP rank (0 to 3 across groups).
+    """
+    assert torch.distributed.is_initialized()
+    world_size: int = torch.distributed.get_world_size()
+    backend = backend or torch.distributed.get_backend(get_world_group().device_group)
+
+    num_sequence_parallel_groups: int = world_size // sequence_parallel_size
+    global _SP
+    assert _SP is None, "sequence parallel group is already initialized"
+    group_ranks = []
+    for i in range(num_sequence_parallel_groups):
+        ranks = list(
+            range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)
+        )
+        group_ranks.append(ranks)
+    _SP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend)
+
+    vllm_initialize_model_parallel(
+        tensor_model_parallel_size, pipeline_model_parallel_size, backend
+    )
+
+
+def sequence_parallel_is_initialized():
+    return _SP is not None
+
+
+def get_sequence_parallel_world_size():
+    return get_sp_group().world_size
+
+
+def get_sequence_parallel_rank():
+    return get_sp_group().rank_in_group
+
+
+def get_sequence_parallel_global_rank():
+    return get_tensor_model_parallel_rank()
+
+
+# NOTE: For sequence parallelism, Q tensors are partitioned along the head dimension,
+# while K/V tensors are partitioned along the head dimension across TP and along the
+# sequence dimension across SP. Their effective TP size and rank are therefore
+# adjusted as below.
+def get_kv_tensor_model_parallel_world_size():
+    return get_tensor_model_parallel_world_size() // get_sequence_parallel_world_size()
+
+
+def get_kv_tensor_model_parallel_rank():
+    return get_tensor_model_parallel_rank() // get_sequence_parallel_world_size()
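To make the rank arithmetic above concrete, here is a small self-contained sketch (illustrative only; it recomputes the docstring's TP size 8 / SP size 2 example with plain integers instead of the torch.distributed-backed helpers, and assumes the world size equals the TP size, i.e. no pipeline parallelism):

# Recompute the docstring's example without torch.distributed.
tp_size, sp_size = 8, 2

# SP groups of adjacent ranks, as built in initialize_model_parallel().
sp_groups = [
    list(range(i * sp_size, (i + 1) * sp_size)) for i in range(tp_size // sp_size)
]
print(sp_groups)    # [[0, 1], [2, 3], [4, 5], [6, 7]]

# KV TP size and ranks, mirroring get_kv_tensor_model_parallel_world_size()/_rank().
kv_tp_size = tp_size // sp_size
kv_tp_ranks = [tp_rank // sp_size for tp_rank in range(tp_size)]
print(kv_tp_size)   # 4
print(kv_tp_ranks)  # [0, 0, 1, 1, 2, 2, 3, 3]

# Both GPUs in an SP group share one KV TP rank: KV heads are replicated within the
# group, while the K/V sequence dimension is split across its members.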
(Diffs for the remaining 13 changed files are not shown.)