Support DP MLA (#1970)
ispobock authored Nov 16, 2024
1 parent 2f2e074 commit 976bc30
Showing 12 changed files with 395 additions and 63 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pr-test.yml
@@ -244,6 +244,7 @@ jobs:
cd test/srt
python3 test_mla.py
python3 test_mla_fp8.py
python3 test_dp_attention.py
- name: Evaluate data parallelism accuracy (DP=2)
timeout-minutes: 10
10 changes: 7 additions & 3 deletions python/sglang/srt/layers/attention/triton_backend.py
@@ -28,9 +28,13 @@ def __init__(self, model_runner: ModelRunner):

self.decode_attention_fwd = decode_attention_fwd
self.extend_attention_fwd = extend_attention_fwd
self.num_head = (
model_runner.model_config.num_attention_heads // model_runner.tp_size
)

if model_runner.server_args.enable_dp_attention:
self.num_head = model_runner.model_config.num_attention_heads
else:
self.num_head = (
model_runner.model_config.num_attention_heads // model_runner.tp_size
)

if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
self.reduce_dtype = torch.float32
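Note on the head count above: with --enable-dp-attention, each data-parallel rank runs attention over its own requests with the model's full head count instead of sharding heads across tensor-parallel ranks, so the Triton backend stops dividing by tp_size. A minimal standalone sketch of that selection logic (the helper name and explicit arguments are illustrative, not part of the commit):

def select_num_heads(num_attention_heads: int, tp_size: int, enable_dp_attention: bool) -> int:
    # Under DP attention each rank owns every head; under plain TP the
    # heads are split evenly across the tp_size ranks.
    if enable_dp_attention:
        return num_attention_heads
    return num_attention_heads // tp_size

# e.g. a 128-head model on 8 GPUs:
assert select_num_heads(128, 8, enable_dp_attention=True) == 128
assert select_num_heads(128, 8, enable_dp_attention=False) == 16
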
51 changes: 43 additions & 8 deletions python/sglang/srt/managers/data_parallel_controller.py
@@ -81,20 +81,34 @@ def __init__(self, server_args, port_args) -> None:
# Start data parallel workers
base_gpu_id = 0
self.workers = []
scheduler_pipe_readers = []
for dp_rank in range(server_args.dp_size):
tmp_port_args = PortArgs.init_new(server_args)
tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name

send_to = self.launch_tensor_parallel_group(
server_args,
tmp_port_args,
base_gpu_id,
dp_rank,
)

if server_args.enable_dp_attention:
# Share workers for DP and TP
send_to, reader = self.launch_tensor_parallel_process(
server_args,
tmp_port_args,
base_gpu_id,
dp_rank,
)
base_gpu_id += 1
scheduler_pipe_readers.append(reader)
else:
send_to = self.launch_tensor_parallel_group(
server_args,
tmp_port_args,
base_gpu_id,
dp_rank,
)
base_gpu_id += server_args.tp_size
self.workers.append(send_to)
base_gpu_id += server_args.tp_size

for reader in scheduler_pipe_readers:
reader.recv()

def launch_tensor_parallel_group(
self,
@@ -132,6 +146,27 @@ def launch_tensor_parallel_group(

return send_to

def launch_tensor_parallel_process(
self,
server_args: ServerArgs,
port_args: PortArgs,
base_gpu_id: int,
dp_rank: int,
):
reader, writer = mp.Pipe(duplex=False)
gpu_id = base_gpu_id
tp_rank = dp_rank
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
)
proc.start()
send_to = get_zmq_socket(
self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
)

return send_to, reader

def round_robin_scheduler(self, req):
self.workers[self.round_robin_counter].send_pyobj(req)
self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
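With DP attention the controller launches one scheduler process per DP rank (one GPU each) and then blocks on a pipe until every scheduler reports that it is ready. A self-contained sketch of that launch-and-handshake pattern, assuming the child sends a ready message once initialization finishes (function names here are illustrative, not the real sglang entry points):

import multiprocessing as mp


def scheduler_main(rank: int, writer) -> None:
    # Heavy initialization (model load, memory pools) would happen here.
    writer.send({"status": "ready", "rank": rank})  # signal the parent


def launch_schedulers(dp_size: int):
    procs, readers = [], []
    for rank in range(dp_size):
        reader, writer = mp.Pipe(duplex=False)
        proc = mp.Process(target=scheduler_main, args=(rank, writer))
        proc.start()
        procs.append(proc)
        readers.append(reader)
    # Block until every scheduler has finished initializing.
    for reader in readers:
        print(reader.recv())
    return procs


if __name__ == "__main__":
    for p in launch_schedulers(dp_size=2):
        p.join()
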
29 changes: 24 additions & 5 deletions python/sglang/srt/managers/schedule_batch.py
@@ -56,6 +56,7 @@
"disable_mla": ServerArgs.disable_mla,
"torchao_config": ServerArgs.torchao_config,
"disable_nan_detection": ServerArgs.disable_nan_detection,
"enable_dp_attention": ServerArgs.enable_dp_attention,
}


@@ -450,6 +451,9 @@ class ScheduleBatch:
# The sum of all sequence lengths
seq_lens_sum: int = None

# For DP attention
global_num_tokens: Optional[List[int]] = None

# For processing logprobs
return_logprob: bool = False
top_logprobs_nums: Optional[List[int]] = None
@@ -858,6 +862,16 @@ def prepare_encoder_info_decode(self):
# Reset the encoder cached status
self.encoder_cached = [True] * len(self.reqs)

def prepare_for_idle(self):
self.forward_mode = ForwardMode.IDLE
self.input_ids = torch.empty(0, dtype=torch.int32).to(
self.device, non_blocking=True
)
self.seq_lens = torch.empty(0, dtype=torch.int32).to(
self.device, non_blocking=True
)
self.extend_num_tokens = 0

def prepare_for_decode(self, enable_overlap: bool = False):
self.forward_mode = ForwardMode.DECODE

@@ -969,17 +983,18 @@ def merge_batch(self, other: "ScheduleBatch"):
self.has_grammar = self.has_grammar or other.has_grammar

def get_model_worker_batch(self):
if self.forward_mode.is_decode():
if self.forward_mode.is_decode() or self.forward_mode.is_idle():
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
else:
extend_seq_lens = self.extend_lens
extend_prefix_lens = self.prefix_lens
extend_logprob_start_lens = self.extend_logprob_start_lens

if self.has_grammar:
self.sampling_info.grammars = [req.grammar for req in self.reqs]
else:
self.sampling_info.grammars = None
if self.sampling_info is not None:
if self.has_grammar:
self.sampling_info.grammars = [req.grammar for req in self.reqs]
else:
self.sampling_info.grammars = None

global bid
bid += 1
@@ -995,6 +1010,7 @@ def get_model_worker_batch(self):
req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,
global_num_tokens=self.global_num_tokens,
extend_num_tokens=self.extend_num_tokens,
extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens,
@@ -1051,6 +1067,9 @@ class ModelWorkerBatch:
return_logprob: bool
top_logprobs_nums: Optional[List[int]]

# For DP attention
global_num_tokens: Optional[List[int]]

# For extend
extend_num_tokens: Optional[int]
extend_seq_lens: Optional[List[int]]
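An idle batch is simply a batch with zero sequences: empty input_ids and seq_lens tensors and extend_num_tokens = 0, so a worker with nothing scheduled can still join the collective steps of DP attention. A minimal sketch of the same idea outside the real class (TinyBatch is an illustrative stand-in, not ScheduleBatch):

from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class TinyBatch:
    forward_mode: str = "EXTEND"
    input_ids: Optional[torch.Tensor] = None
    seq_lens: Optional[torch.Tensor] = None
    extend_num_tokens: Optional[int] = None
    # Per-rank token counts, filled in by the scheduler before the forward pass.
    global_num_tokens: Optional[List[int]] = None

    def prepare_for_idle(self, device: str = "cpu") -> None:
        self.forward_mode = "IDLE"
        self.input_ids = torch.empty(0, dtype=torch.int32, device=device)
        self.seq_lens = torch.empty(0, dtype=torch.int32, device=device)
        self.extend_num_tokens = 0


batch = TinyBatch()
batch.prepare_for_idle()
assert batch.input_ids.numel() == 0 and batch.extend_num_tokens == 0
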
58 changes: 55 additions & 3 deletions python/sglang/srt/managers/scheduler.py
@@ -110,7 +110,7 @@ def __init__(
# Init inter-process communication
context = zmq.Context(2)

if self.tp_rank == 0:
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
self.recv_from_tokenizer = get_zmq_socket(
context, zmq.PULL, port_args.scheduler_input_ipc_name
)
@@ -347,6 +347,10 @@ def event_loop_normal(self):
self.process_input_requests(recv_reqs)

batch = self.get_next_batch_to_run()

if self.server_args.enable_dp_attention:
batch = self.prepare_dp_attn_batch(batch)

self.cur_batch = batch

if batch:
@@ -361,6 +365,8 @@ def event_loop_overlap(self):
self.update_running_batch()
if not self.running_batch:
break
if self.server_args.enable_dp_attention:
batch = self.prepare_dp_attn_batch(batch)
result = self.run_batch(batch)
self.process_batch_result(batch, result)
else:
@@ -396,8 +402,48 @@ def event_loop_overlap(self):

self.last_batch = batch

def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
# Check if other DP workers have running batches
if local_batch is None:
num_tokens = 0
elif local_batch.forward_mode.is_decode():
num_tokens = local_batch.batch_size()
else:
num_tokens = local_batch.extend_num_tokens

local_num_tokens = torch.tensor(
num_tokens, dtype=torch.int64, device=self.device
)
global_num_tokens = torch.empty(
self.tp_size, dtype=torch.int64, device=self.device
)
torch.distributed.all_gather_into_tensor(
global_num_tokens,
local_num_tokens,
group=self.tp_worker.get_tp_device_group(),
)

if local_batch is None and global_num_tokens.max().item() > 0:
local_batch = self.get_idle_batch()

if local_batch is not None:
local_batch.global_num_tokens = global_num_tokens.tolist()

return local_batch

def get_idle_batch(self):
idle_batch = ScheduleBatch.init_new(
[],
self.req_to_token_pool,
self.token_to_kv_pool,
self.tree_cache,
self.model_config,
)
idle_batch.prepare_for_idle()
return idle_batch

def recv_requests(self):
if self.tp_rank == 0:
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
recv_reqs = []

while True:
@@ -409,7 +455,7 @@ def recv_requests(self):
else:
recv_reqs = None

if self.tp_size != 1:
if self.tp_size != 1 and not self.server_args.enable_dp_attention:
recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
return recv_reqs

@@ -812,6 +858,10 @@ def run_batch(self, batch: ScheduleBatch):
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
model_worker_batch
)
elif batch.forward_mode.is_idle():
model_worker_batch = batch.get_model_worker_batch()
self.tp_worker.forward_batch_idle(model_worker_batch)
return
else:
logits_output = None
if self.skip_tokenizer_init:
@@ -830,6 +880,8 @@ def process_batch_result(self, batch: ScheduleBatch, result):
return ret

def process_batch_result(self, batch: ScheduleBatch, result):
if batch.forward_mode.is_idle():
return
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result)
if batch.is_empty():
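The invariant behind prepare_dp_attn_batch: every DP rank must enter the forward pass together, because DP attention gathers hidden states across ranks. Each rank therefore all-gathers its token count, and a rank with no local work builds an idle batch whenever any peer has tokens to run. A self-contained sketch of that exchange on the gloo backend using dist.all_gather (the commit uses all_gather_into_tensor on the TP device group; the process-group setup below is illustrative):

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def dp_rank_step(rank: int, world_size: int, my_num_tokens: int) -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29533")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Exchange how many tokens each DP rank wants to forward this step.
    local = torch.tensor([my_num_tokens], dtype=torch.int64)
    gathered = [torch.zeros(1, dtype=torch.int64) for _ in range(world_size)]
    dist.all_gather(gathered, local)
    global_num_tokens = [int(t.item()) for t in gathered]

    if my_num_tokens == 0 and max(global_num_tokens) > 0:
        # A peer has work: run an IDLE forward so collectives stay in sync.
        print(f"rank {rank}: idle forward, global_num_tokens={global_num_tokens}")
    elif my_num_tokens > 0:
        print(f"rank {rank}: real forward with {my_num_tokens} tokens")

    dist.destroy_process_group()


if __name__ == "__main__":
    # This step, rank 0 has 5 tokens to run and rank 1 has none.
    tokens = [5, 0]
    procs = [mp.Process(target=dp_rank_step, args=(r, 2, tokens[r])) for r in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
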
7 changes: 7 additions & 0 deletions python/sglang/srt/managers/tp_worker.py
@@ -128,12 +128,19 @@ def get_pad_input_ids_func(self):
def get_tp_cpu_group(self):
return self.model_runner.tp_group.cpu_group

def get_tp_device_group(self):
return self.model_runner.tp_group.device_group

def get_memory_pool(self):
return (
self.model_runner.req_to_token_pool,
self.model_runner.token_to_kv_pool,
)

def forward_batch_idle(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
self.model_runner.forward(forward_batch)

def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
3 changes: 3 additions & 0 deletions python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -88,6 +88,9 @@ def get_pad_input_ids_func(self):
def get_tp_cpu_group(self):
return self.worker.get_tp_cpu_group()

def get_tp_device_group(self):
return self.worker.get_tp_device_group()

def get_memory_pool(self):
return (
self.worker.model_runner.req_to_token_pool,
21 changes: 21 additions & 0 deletions python/sglang/srt/model_executor/forward_batch_info.py
@@ -56,6 +56,8 @@ class ForwardMode(IntEnum):
DECODE = auto()
# Contains both EXTEND and DECODE.
MIXED = auto()
# No sequence to forward. For data parallel attention, some workers will be IDLE if no sequence is allocated.
IDLE = auto()

def is_prefill(self):
return self == ForwardMode.PREFILL
@@ -69,6 +71,9 @@ def is_decode(self):
def is_mixed(self):
return self == ForwardMode.MIXED

def is_idle(self):
return self == ForwardMode.IDLE


@dataclass
class ForwardBatch:
@@ -128,6 +133,10 @@ class ForwardBatch:
# For Qwen2-VL
mrope_positions: torch.Tensor = None

# For DP attention
global_num_tokens: Optional[List[int]] = None
gathered_buffer: Optional[torch.Tensor] = None

def compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch
):
@@ -209,10 +218,22 @@ def init_new(
seq_lens_sum=batch.seq_lens_sum,
return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums,
global_num_tokens=batch.global_num_tokens,
lora_paths=batch.lora_paths,
sampling_info=batch.sampling_info,
)

if ret.global_num_tokens is not None:
max_len = max(ret.global_num_tokens)
ret.gathered_buffer = torch.zeros(
(max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
dtype=model_runner.dtype,
device=device,
)

if ret.forward_mode.is_idle():
return ret

# Init position information
if not ret.forward_mode.is_decode():
ret.positions = torch.concat(
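The gathered_buffer above is sized for the worst case: every rank pads its hidden states up to the largest token count in the group, so the gather needs max(global_num_tokens) rows per rank. A minimal sketch of that sizing (the helper is illustrative; the real code reads the rank count, dtype, and hidden_size from the model runner):

from typing import List

import torch


def alloc_gathered_buffer(
    global_num_tokens: List[int],
    num_ranks: int,
    hidden_size: int,
    dtype: torch.dtype = torch.float16,
    device: str = "cpu",
) -> torch.Tensor:
    # Each rank contributes a chunk padded to the largest token count,
    # so the all-gather output has max_len * num_ranks rows.
    max_len = max(global_num_tokens)
    return torch.zeros((max_len * num_ranks, hidden_size), dtype=dtype, device=device)


# e.g. two attention ranks forwarding 5 and 3 tokens of a 4096-dim model:
buf = alloc_gathered_buffer([5, 3], num_ranks=2, hidden_size=4096)
assert buf.shape == (10, 4096)
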
8 changes: 8 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -141,6 +141,7 @@ def __init__(
"torchao_config": server_args.torchao_config,
"disable_penalizer": server_args.disable_penalizer,
"disable_nan_detection": server_args.disable_nan_detection,
"enable_dp_attention": server_args.enable_dp_attention,
}
)

@@ -592,11 +593,18 @@ def forward_extend(self, forward_batch: ForwardBatch):
get_embedding=True,
)

def forward_idle(self, forward_batch: ForwardBatch):
return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)

def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
if forward_batch.forward_mode.is_decode():
return self.forward_decode(forward_batch)
elif forward_batch.forward_mode.is_extend():
return self.forward_extend(forward_batch)
elif forward_batch.forward_mode.is_idle():
return self.forward_idle(forward_batch)
else:
raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

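forward_idle still calls the model, just with zero tokens: PyTorch ops accept zero-row inputs, so the pass is cheap while any collectives inside the model still execute, and the output is discarded. A tiny standalone check of the zero-token case (the module below is illustrative, not the sglang model interface):

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self, hidden_size: int = 8, vocab_size: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # A zero-length input_ids tensor yields a (0, vocab_size) output;
        # in the real idle path this output is simply thrown away.
        return self.lm_head(self.embed(input_ids))


model = TinyModel()
idle_ids = torch.empty(0, dtype=torch.int64)
assert model(idle_ids).shape == (0, 16)
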