Support DP MLA #1970

Merged
merged 16 commits into from
Nov 16, 2024
1 change: 1 addition & 0 deletions .github/workflows/pr-test.yml
@@ -244,6 +244,7 @@ jobs:
cd test/srt
python3 test_mla.py
python3 test_mla_fp8.py
python3 test_dp_attention.py

- name: Evaluate data parallelism accuracy (DP=2)
timeout-minutes: 10
10 changes: 7 additions & 3 deletions python/sglang/srt/layers/attention/triton_backend.py
@@ -28,9 +28,13 @@ def __init__(self, model_runner: ModelRunner):

self.decode_attention_fwd = decode_attention_fwd
self.extend_attention_fwd = extend_attention_fwd
self.num_head = (
model_runner.model_config.num_attention_heads // model_runner.tp_size
)

if model_runner.server_args.enable_dp_attention:
self.num_head = model_runner.model_config.num_attention_heads
else:
self.num_head = (
model_runner.model_config.num_attention_heads // model_runner.tp_size
)

if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
self.reduce_dtype = torch.float32
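
A note on the hunk above: with enable_dp_attention, each scheduler acts as a full data-parallel replica for attention rather than a tensor-parallel shard, so the Triton backend keeps all attention heads local instead of dividing them by tp_size. A minimal sketch of that selection, with illustrative names (not the real ModelRunner/ServerArgs API):

def heads_per_device(num_attention_heads: int, tp_size: int, enable_dp_attention: bool) -> int:
    # Under DP attention each rank runs attention unsharded and owns every head;
    # under plain tensor parallelism the heads are split across the TP group.
    if enable_dp_attention:
        return num_attention_heads
    return num_attention_heads // tp_size

# Example: a 128-head model with tp_size=8.
assert heads_per_device(128, 8, enable_dp_attention=True) == 128
assert heads_per_device(128, 8, enable_dp_attention=False) == 16
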
51 changes: 43 additions & 8 deletions python/sglang/srt/managers/data_parallel_controller.py
@@ -81,20 +81,34 @@ def __init__(self, server_args, port_args) -> None:
# Start data parallel workers
base_gpu_id = 0
self.workers = []
scheduler_pipe_readers = []
for dp_rank in range(server_args.dp_size):
tmp_port_args = PortArgs.init_new(server_args)
tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name

send_to = self.launch_tensor_parallel_group(
server_args,
tmp_port_args,
base_gpu_id,
dp_rank,
)

if server_args.enable_dp_attention:
# Share workers for DP and TP
send_to, reader = self.launch_tensor_parallel_process(
server_args,
tmp_port_args,
base_gpu_id,
dp_rank,
)
base_gpu_id += 1
scheduler_pipe_readers.append(reader)
else:
send_to = self.launch_tensor_parallel_group(
server_args,
tmp_port_args,
base_gpu_id,
dp_rank,
)
base_gpu_id += server_args.tp_size
self.workers.append(send_to)
base_gpu_id += server_args.tp_size

for reader in scheduler_pipe_readers:
reader.recv()

def launch_tensor_parallel_group(
self,
@@ -132,6 +146,27 @@ def launch_tensor_parallel_group(

return send_to

def launch_tensor_parallel_process(
self,
server_args: ServerArgs,
port_args: PortArgs,
base_gpu_id: int,
dp_rank: int,
):
reader, writer = mp.Pipe(duplex=False)
gpu_id = base_gpu_id
tp_rank = dp_rank
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
)
proc.start()
send_to = get_zmq_socket(
self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
)

return send_to, reader

def round_robin_scheduler(self, req):
self.workers[self.round_robin_counter].send_pyobj(req)
self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
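
With DP attention enabled, the controller above no longer launches a tensor-parallel group per DP rank; it spawns one scheduler process per rank, each bound to a single GPU, and then blocks on a pipe until every scheduler reports that initialization finished. A rough standalone sketch of that spawn-then-wait pattern (the worker body is a placeholder, not the real run_scheduler_process):

import multiprocessing as mp

def dummy_scheduler(dp_rank, writer):
    # The real process would construct a Scheduler here before signaling readiness.
    writer.send({"status": "ready", "dp_rank": dp_rank})

def launch_dp_schedulers(dp_size):
    readers = []
    for dp_rank in range(dp_size):
        reader, writer = mp.Pipe(duplex=False)
        mp.Process(target=dummy_scheduler, args=(dp_rank, writer)).start()
        readers.append(reader)
    # Mirror the controller: only proceed once every rank has checked in.
    return [reader.recv() for reader in readers]

if __name__ == "__main__":
    print(launch_dp_schedulers(2))
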
29 changes: 24 additions & 5 deletions python/sglang/srt/managers/schedule_batch.py
@@ -56,6 +56,7 @@
"disable_mla": ServerArgs.disable_mla,
"torchao_config": ServerArgs.torchao_config,
"disable_nan_detection": ServerArgs.disable_nan_detection,
"enable_dp_attention": ServerArgs.enable_dp_attention,
}


@@ -450,6 +451,9 @@ class ScheduleBatch:
# The sum of all sequence lengths
seq_lens_sum: int = None

# For DP attention
global_num_tokens: Optional[List[int]] = None

# For processing logprobs
return_logprob: bool = False
top_logprobs_nums: Optional[List[int]] = None
@@ -858,6 +862,16 @@ def prepare_encoder_info_decode(self):
# Reset the encoder cached status
self.encoder_cached = [True] * len(self.reqs)

def prepare_for_idle(self):
self.forward_mode = ForwardMode.IDLE
self.input_ids = torch.empty(0, dtype=torch.int32).to(
self.device, non_blocking=True
)
self.seq_lens = torch.empty(0, dtype=torch.int32).to(
self.device, non_blocking=True
)
self.extend_num_tokens = 0

def prepare_for_decode(self, enable_overlap: bool = False):
self.forward_mode = ForwardMode.DECODE

@@ -969,17 +983,18 @@ def merge_batch(self, other: "ScheduleBatch"):
self.has_grammar = self.has_grammar or other.has_grammar

def get_model_worker_batch(self):
if self.forward_mode.is_decode():
if self.forward_mode.is_decode() or self.forward_mode.is_idle():
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
else:
extend_seq_lens = self.extend_lens
extend_prefix_lens = self.prefix_lens
extend_logprob_start_lens = self.extend_logprob_start_lens

if self.has_grammar:
self.sampling_info.grammars = [req.grammar for req in self.reqs]
else:
self.sampling_info.grammars = None
if self.sampling_info is not None:
if self.has_grammar:
self.sampling_info.grammars = [req.grammar for req in self.reqs]
else:
self.sampling_info.grammars = None

global bid
bid += 1
@@ -995,6 +1010,7 @@ def get_model_worker_batch(self):
req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,
global_num_tokens=self.global_num_tokens,
extend_num_tokens=self.extend_num_tokens,
extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens,
@@ -1051,6 +1067,9 @@ class ModelWorkerBatch:
return_logprob: bool
top_logprobs_nums: Optional[List[int]]

# For DP attention
global_num_tokens: Optional[List[int]]

# For extend
extend_num_tokens: Optional[int]
extend_seq_lens: Optional[List[int]]
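
prepare_for_idle above builds a batch with no requests: empty input_ids and seq_lens, zero extend tokens, and ForwardMode.IDLE. This lets a DP rank that has nothing scheduled still hand the model a (trivial) batch so it keeps participating in collectives issued by ranks that do have work. A toy illustration of the same idea, using made-up class names rather than the real ScheduleBatch:

from dataclasses import dataclass, field
from enum import IntEnum, auto
import torch

class Mode(IntEnum):
    EXTEND = auto()
    DECODE = auto()
    IDLE = auto()  # nothing to run locally; forward only to stay in lockstep

@dataclass
class TinyBatch:
    forward_mode: Mode = Mode.EXTEND
    input_ids: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.int32))
    seq_lens: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.int32))
    extend_num_tokens: int = 0

    def prepare_for_idle(self):
        self.forward_mode = Mode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int32)
        self.seq_lens = torch.empty(0, dtype=torch.int32)
        self.extend_num_tokens = 0

b = TinyBatch()
b.prepare_for_idle()
print(b.forward_mode.name, b.input_ids.numel(), b.seq_lens.numel())  # IDLE 0 0
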
58 changes: 55 additions & 3 deletions python/sglang/srt/managers/scheduler.py
@@ -110,7 +110,7 @@ def __init__(
# Init inter-process communication
context = zmq.Context(2)

if self.tp_rank == 0:
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
self.recv_from_tokenizer = get_zmq_socket(
context, zmq.PULL, port_args.scheduler_input_ipc_name
)
@@ -347,6 +347,10 @@ def event_loop_normal(self):
self.process_input_requests(recv_reqs)

batch = self.get_next_batch_to_run()

if self.server_args.enable_dp_attention:
batch = self.prepare_dp_attn_batch(batch)

self.cur_batch = batch

if batch:
@@ -361,6 +365,8 @@
self.update_running_batch()
if not self.running_batch:
break
if self.server_args.enable_dp_attention:
batch = self.prepare_dp_attn_batch(batch)
result = self.run_batch(batch)
self.process_batch_result(batch, result)
else:
@@ -396,8 +402,48 @@ def event_loop_overlap(self):

self.last_batch = batch

def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
# Check if other DP workers have running batches
if local_batch is None:
num_tokens = 0
elif local_batch.forward_mode.is_decode():
num_tokens = local_batch.batch_size()
else:
num_tokens = local_batch.extend_num_tokens

local_num_tokens = torch.tensor(
num_tokens, dtype=torch.int64, device=self.device
)
global_num_tokens = torch.empty(
self.tp_size, dtype=torch.int64, device=self.device
)
torch.distributed.all_gather_into_tensor(
global_num_tokens,
local_num_tokens,
group=self.tp_worker.get_tp_device_group(),
)

if local_batch is None and global_num_tokens.max().item() > 0:
local_batch = self.get_idle_batch()

if local_batch is not None:
local_batch.global_num_tokens = global_num_tokens.tolist()

return local_batch

def get_idle_batch(self):
idle_batch = ScheduleBatch.init_new(
[],
self.req_to_token_pool,
self.token_to_kv_pool,
self.tree_cache,
self.model_config,
)
idle_batch.prepare_for_idle()
return idle_batch

def recv_requests(self):
if self.tp_rank == 0:
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
recv_reqs = []

while True:
@@ -409,7 +455,7 @@ def recv_requests(self):
else:
recv_reqs = None

if self.tp_size != 1:
if self.tp_size != 1 and not self.server_args.enable_dp_attention:
recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
return recv_reqs

@@ -812,6 +858,10 @@ def run_batch(self, batch: ScheduleBatch):
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
model_worker_batch
)
elif batch.forward_mode.is_idle():
model_worker_batch = batch.get_model_worker_batch()
self.tp_worker.forward_batch_idle(model_worker_batch)
return
else:
logits_output = None
if self.skip_tokenizer_init:
@@ -830,6 +880,8 @@
return ret

def process_batch_result(self, batch: ScheduleBatch, result):
if batch.forward_mode.is_idle():
return
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result)
if batch.is_empty():
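
The heart of prepare_dp_attn_batch above is a single all-gather of per-rank token counts: every DP rank learns how much work its peers have, and a rank with nothing to run still schedules an idle forward whenever any peer is busy, so model-level collectives stay in lockstep. A standalone sketch of that synchronization (run under torchrun with two processes; the gloo backend and helper names are assumptions for illustration):

import torch
import torch.distributed as dist

def sync_num_tokens(num_tokens: int, world_size: int) -> list:
    local = torch.tensor([num_tokens], dtype=torch.int64)
    gathered = torch.empty(world_size, dtype=torch.int64)
    # One collective: every rank receives every other rank's token count.
    dist.all_gather_into_tensor(gathered, local)
    return gathered.tolist()

if __name__ == "__main__":
    dist.init_process_group("gloo")
    rank, world = dist.get_rank(), dist.get_world_size()
    # Pretend rank 0 has 7 tokens to run while rank 1 has none.
    counts = sync_num_tokens(7 if rank == 0 else 0, world)
    # An idle rank still runs an IDLE forward if any peer has tokens.
    needs_idle_forward = counts[rank] == 0 and max(counts) > 0
    print(rank, counts, needs_idle_forward)
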
7 changes: 7 additions & 0 deletions python/sglang/srt/managers/tp_worker.py
@@ -128,12 +128,19 @@ def get_pad_input_ids_func(self):
def get_tp_cpu_group(self):
return self.model_runner.tp_group.cpu_group

def get_tp_device_group(self):
return self.model_runner.tp_group.device_group

def get_memory_pool(self):
return (
self.model_runner.req_to_token_pool,
self.model_runner.token_to_kv_pool,
)

def forward_batch_idle(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
self.model_runner.forward(forward_batch)

def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
3 changes: 3 additions & 0 deletions python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -88,6 +88,9 @@ def get_pad_input_ids_func(self):
def get_tp_cpu_group(self):
return self.worker.get_tp_cpu_group()

def get_tp_device_group(self):
return self.worker.get_tp_device_group()

def get_memory_pool(self):
return (
self.worker.model_runner.req_to_token_pool,
21 changes: 21 additions & 0 deletions python/sglang/srt/model_executor/forward_batch_info.py
@@ -56,6 +56,8 @@ class ForwardMode(IntEnum):
DECODE = auto()
# Contains both EXTEND and DECODE.
MIXED = auto()
# No sequence to forward. For data parallel attention, some workers will be IDLE if no sequences are allocated.
IDLE = auto()

def is_prefill(self):
return self == ForwardMode.PREFILL
@@ -69,6 +71,9 @@ def is_decode(self):
def is_mixed(self):
return self == ForwardMode.MIXED

def is_idle(self):
return self == ForwardMode.IDLE


@dataclass
class ForwardBatch:
@@ -128,6 +133,10 @@ class ForwardBatch:
# For Qwen2-VL
mrope_positions: torch.Tensor = None

# For DP attention
global_num_tokens: Optional[List[int]] = None
gathered_buffer: Optional[torch.Tensor] = None

def compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch
):
@@ -209,10 +218,22 @@ def init_new(
seq_lens_sum=batch.seq_lens_sum,
return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums,
global_num_tokens=batch.global_num_tokens,
lora_paths=batch.lora_paths,
sampling_info=batch.sampling_info,
)

if ret.global_num_tokens is not None:
max_len = max(ret.global_num_tokens)
ret.gathered_buffer = torch.zeros(
(max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
dtype=model_runner.dtype,
device=device,
)

if ret.forward_mode.is_idle():
return ret

# Init position information
if not ret.forward_mode.is_decode():
ret.positions = torch.concat(
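
The gathered_buffer above is pre-allocated for the worst case: every DP rank is padded to the longest local batch (max of global_num_tokens) and the padded slices from all tp_size ranks are stacked along the token dimension, giving one contiguous tensor into which per-rank hidden states can be all-gathered wherever the model needs a global view. A tiny sketch of just the sizing rule (shapes and dtype here are illustrative):

import torch

def make_gathered_buffer(global_num_tokens, tp_size, hidden_size, dtype=torch.float16):
    # Each rank contributes up to max_len rows after padding, so the shared
    # buffer holds tp_size * max_len rows of hidden states.
    max_len = max(global_num_tokens)
    return torch.zeros((max_len * tp_size, hidden_size), dtype=dtype)

buf = make_gathered_buffer(global_num_tokens=[5, 0, 3, 1], tp_size=4, hidden_size=4096)
print(buf.shape)  # torch.Size([20, 4096])
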
8 changes: 8 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -141,6 +141,7 @@ def __init__(
"torchao_config": server_args.torchao_config,
"disable_penalizer": server_args.disable_penalizer,
"disable_nan_detection": server_args.disable_nan_detection,
"enable_dp_attention": server_args.enable_dp_attention,
}
)

@@ -592,11 +593,18 @@ def forward_extend(self, forward_batch: ForwardBatch):
get_embedding=True,
)

def forward_idle(self, forward_batch: ForwardBatch):
return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)

def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
if forward_batch.forward_mode.is_decode():
return self.forward_decode(forward_batch)
elif forward_batch.forward_mode.is_extend():
return self.forward_extend(forward_batch)
elif forward_batch.forward_mode.is_idle():
return self.forward_idle(forward_batch)
else:
raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
