diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 2ff8b64d11d..efcc3a3a467 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -244,6 +244,7 @@ jobs: cd test/srt python3 test_mla.py python3 test_mla_fp8.py + python3 test_dp_attention.py - name: Evaluate data parallelism accuracy (DP=2) timeout-minutes: 10 diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 47b8d3cd56d..0c99d1ec4fa 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -28,9 +28,13 @@ def __init__(self, model_runner: ModelRunner): self.decode_attention_fwd = decode_attention_fwd self.extend_attention_fwd = extend_attention_fwd - self.num_head = ( - model_runner.model_config.num_attention_heads // model_runner.tp_size - ) + + if model_runner.server_args.enable_dp_attention: + self.num_head = model_runner.model_config.num_attention_heads + else: + self.num_head = ( + model_runner.model_config.num_attention_heads // model_runner.tp_size + ) if global_server_args_dict.get("triton_attention_reduce_in_fp32", False): self.reduce_dtype = torch.float32 diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 95af3deccca..472d9174e4d 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -81,20 +81,34 @@ def __init__(self, server_args, port_args) -> None: # Start data parallel workers base_gpu_id = 0 self.workers = [] + scheduler_pipe_readers = [] for dp_rank in range(server_args.dp_size): tmp_port_args = PortArgs.init_new(server_args) tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name - send_to = self.launch_tensor_parallel_group( - server_args, - tmp_port_args, - base_gpu_id, - dp_rank, - ) - + if server_args.enable_dp_attention: + # Share workers for DP and TP + send_to, reader = self.launch_tensor_parallel_process( + server_args, + tmp_port_args, + base_gpu_id, + dp_rank, + ) + base_gpu_id += 1 + scheduler_pipe_readers.append(reader) + else: + send_to = self.launch_tensor_parallel_group( + server_args, + tmp_port_args, + base_gpu_id, + dp_rank, + ) + base_gpu_id += server_args.tp_size self.workers.append(send_to) - base_gpu_id += server_args.tp_size + + for reader in scheduler_pipe_readers: + reader.recv() def launch_tensor_parallel_group( self, @@ -132,6 +146,27 @@ def launch_tensor_parallel_group( return send_to + def launch_tensor_parallel_process( + self, + server_args: ServerArgs, + port_args: PortArgs, + base_gpu_id: int, + dp_rank: int, + ): + reader, writer = mp.Pipe(duplex=False) + gpu_id = base_gpu_id + tp_rank = dp_rank + proc = mp.Process( + target=run_scheduler_process, + args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer), + ) + proc.start() + send_to = get_zmq_socket( + self.context, zmq.PUSH, port_args.scheduler_input_ipc_name + ) + + return send_to, reader + def round_robin_scheduler(self, req): self.workers[self.round_robin_counter].send_pyobj(req) self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index cc72663cdc7..109f3bf6f61 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ 
-56,6 +56,7 @@ "disable_mla": ServerArgs.disable_mla, "torchao_config": ServerArgs.torchao_config, "disable_nan_detection": ServerArgs.disable_nan_detection, + "enable_dp_attention": ServerArgs.enable_dp_attention, } @@ -450,6 +451,9 @@ class ScheduleBatch: # The sum of all sequence lengths seq_lens_sum: int = None + # For DP attention + global_num_tokens: Optional[List[int]] = None + # For processing logprobs return_logprob: bool = False top_logprobs_nums: Optional[List[int]] = None @@ -858,6 +862,16 @@ def prepare_encoder_info_decode(self): # Reset the encoder cached status self.encoder_cached = [True] * len(self.reqs) + def prepare_for_idle(self): + self.forward_mode = ForwardMode.IDLE + self.input_ids = torch.empty(0, dtype=torch.int32).to( + self.device, non_blocking=True + ) + self.seq_lens = torch.empty(0, dtype=torch.int32).to( + self.device, non_blocking=True + ) + self.extend_num_tokens = 0 + def prepare_for_decode(self, enable_overlap: bool = False): self.forward_mode = ForwardMode.DECODE @@ -969,17 +983,18 @@ def merge_batch(self, other: "ScheduleBatch"): self.has_grammar = self.has_grammar or other.has_grammar def get_model_worker_batch(self): - if self.forward_mode.is_decode(): + if self.forward_mode.is_decode() or self.forward_mode.is_idle(): extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None else: extend_seq_lens = self.extend_lens extend_prefix_lens = self.prefix_lens extend_logprob_start_lens = self.extend_logprob_start_lens - if self.has_grammar: - self.sampling_info.grammars = [req.grammar for req in self.reqs] - else: - self.sampling_info.grammars = None + if self.sampling_info is not None: + if self.has_grammar: + self.sampling_info.grammars = [req.grammar for req in self.reqs] + else: + self.sampling_info.grammars = None global bid bid += 1 @@ -995,6 +1010,7 @@ def get_model_worker_batch(self): req_to_token_pool_records=self.req_to_token_pool.get_write_records(), return_logprob=self.return_logprob, top_logprobs_nums=self.top_logprobs_nums, + global_num_tokens=self.global_num_tokens, extend_num_tokens=self.extend_num_tokens, extend_seq_lens=extend_seq_lens, extend_prefix_lens=extend_prefix_lens, @@ -1051,6 +1067,9 @@ class ModelWorkerBatch: return_logprob: bool top_logprobs_nums: Optional[List[int]] + # For DP attention + global_num_tokens: Optional[List[int]] + # For extend extend_num_tokens: Optional[int] extend_seq_lens: Optional[List[int]] diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 7232fc2a744..2bdf4cda7c9 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -110,7 +110,7 @@ def __init__( # Init inter-process communication context = zmq.Context(2) - if self.tp_rank == 0: + if self.tp_rank == 0 or self.server_args.enable_dp_attention: self.recv_from_tokenizer = get_zmq_socket( context, zmq.PULL, port_args.scheduler_input_ipc_name ) @@ -347,6 +347,10 @@ def event_loop_normal(self): self.process_input_requests(recv_reqs) batch = self.get_next_batch_to_run() + + if self.server_args.enable_dp_attention: + batch = self.prepare_dp_attn_batch(batch) + self.cur_batch = batch if batch: @@ -361,6 +365,8 @@ def event_loop_normal(self): self.update_running_batch() if not self.running_batch: break + if self.server_args.enable_dp_attention: + batch = self.prepare_dp_attn_batch(batch) result = self.run_batch(batch) self.process_batch_result(batch, result) else: @@ -396,8 +402,48 @@ def event_loop_overlap(self): self.last_batch = batch + def 
prepare_dp_attn_batch(self, local_batch: ScheduleBatch): + # Check if other DP workers have running batches + if local_batch is None: + num_tokens = 0 + elif local_batch.forward_mode.is_decode(): + num_tokens = local_batch.batch_size() + else: + num_tokens = local_batch.extend_num_tokens + + local_num_tokens = torch.tensor( + num_tokens, dtype=torch.int64, device=self.device + ) + global_num_tokens = torch.empty( + self.tp_size, dtype=torch.int64, device=self.device + ) + torch.distributed.all_gather_into_tensor( + global_num_tokens, + local_num_tokens, + group=self.tp_worker.get_tp_device_group(), + ) + + if local_batch is None and global_num_tokens.max().item() > 0: + local_batch = self.get_idle_batch() + + if local_batch is not None: + local_batch.global_num_tokens = global_num_tokens.tolist() + + return local_batch + + def get_idle_batch(self): + idle_batch = ScheduleBatch.init_new( + [], + self.req_to_token_pool, + self.token_to_kv_pool, + self.tree_cache, + self.model_config, + ) + idle_batch.prepare_for_idle() + return idle_batch + def recv_requests(self): - if self.tp_rank == 0: + if self.tp_rank == 0 or self.server_args.enable_dp_attention: recv_reqs = [] while True: @@ -409,7 +455,7 @@ def recv_requests(self): else: recv_reqs = None - if self.tp_size != 1: + if self.tp_size != 1 and not self.server_args.enable_dp_attention: recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group) return recv_reqs @@ -812,6 +858,10 @@ def run_batch(self, batch: ScheduleBatch): logits_output, next_token_ids = self.tp_worker.forward_batch_generation( model_worker_batch ) + elif batch.forward_mode.is_idle(): + model_worker_batch = batch.get_model_worker_batch() + self.tp_worker.forward_batch_idle(model_worker_batch) + return else: logits_output = None if self.skip_tokenizer_init: @@ -830,6 +880,8 @@ def run_batch(self, batch: ScheduleBatch): return ret def process_batch_result(self, batch: ScheduleBatch, result): + if batch.forward_mode.is_idle(): + return if batch.forward_mode.is_decode(): self.process_batch_result_decode(batch, result) if batch.is_empty(): diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 8bec1a18c93..361febfac54 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -128,12 +128,19 @@ def get_pad_input_ids_func(self): def get_tp_cpu_group(self): return self.model_runner.tp_group.cpu_group + def get_tp_device_group(self): + return self.model_runner.tp_group.device_group + def get_memory_pool(self): return ( self.model_runner.req_to_token_pool, self.model_runner.token_to_kv_pool, ) + def forward_batch_idle(self, model_worker_batch: ModelWorkerBatch): + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + self.model_runner.forward(forward_batch) + def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) logits_output = self.model_runner.forward(forward_batch) diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 9200612e879..21264f1a975 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -88,6 +88,9 @@ def get_pad_input_ids_func(self): def get_tp_cpu_group(self): return self.worker.get_tp_cpu_group() + def get_tp_device_group(self): + return self.worker.get_tp_device_group() + def 
get_memory_pool(self): return ( self.worker.model_runner.req_to_token_pool, diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 8bd5f197a00..3381c92117c 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -56,6 +56,8 @@ class ForwardMode(IntEnum): DECODE = auto() # Contains both EXTEND and DECODE. MIXED = auto() + # No sequence to forward. For data parallel attention, some workers wil be IDLE if no sequence allocated. + IDLE = auto() def is_prefill(self): return self == ForwardMode.PREFILL @@ -69,6 +71,9 @@ def is_decode(self): def is_mixed(self): return self == ForwardMode.MIXED + def is_idle(self): + return self == ForwardMode.IDLE + @dataclass class ForwardBatch: @@ -128,6 +133,10 @@ class ForwardBatch: # For Qwen2-VL mrope_positions: torch.Tensor = None + # For DP attention + global_num_tokens: Optional[List[int]] = None + gathered_buffer: Optional[torch.Tensor] = None + def compute_mrope_positions( self, model_runner: ModelRunner, batch: ModelWorkerBatch ): @@ -209,10 +218,22 @@ def init_new( seq_lens_sum=batch.seq_lens_sum, return_logprob=batch.return_logprob, top_logprobs_nums=batch.top_logprobs_nums, + global_num_tokens=batch.global_num_tokens, lora_paths=batch.lora_paths, sampling_info=batch.sampling_info, ) + if ret.global_num_tokens is not None: + max_len = max(ret.global_num_tokens) + ret.gathered_buffer = torch.zeros( + (max_len * model_runner.tp_size, model_runner.model_config.hidden_size), + dtype=model_runner.dtype, + device=device, + ) + + if ret.forward_mode.is_idle(): + return ret + # Init position information if not ret.forward_mode.is_decode(): ret.positions = torch.concat( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 98b7cfdc761..5cde1e942ff 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -141,6 +141,7 @@ def __init__( "torchao_config": server_args.torchao_config, "disable_penalizer": server_args.disable_penalizer, "disable_nan_detection": server_args.disable_nan_detection, + "enable_dp_attention": server_args.enable_dp_attention, } ) @@ -592,11 +593,18 @@ def forward_extend(self, forward_batch: ForwardBatch): get_embedding=True, ) + def forward_idle(self, forward_batch: ForwardBatch): + return self.model.forward( + forward_batch.input_ids, forward_batch.positions, forward_batch + ) + def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput: if forward_batch.forward_mode.is_decode(): return self.forward_decode(forward_batch) elif forward_batch.forward_mode.is_extend(): return self.forward_extend(forward_batch) + elif forward_batch.forward_mode.is_idle(): + return self.forward_idle(forward_batch) else: raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}") diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 00ba0dcc596..0540d310dc3 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -22,7 +22,9 @@ from torch import nn from transformers import PretrainedConfig from vllm.distributed import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_tp_group, tensor_model_parallel_all_reduce, ) from vllm.model_executor.layers.fused_moe import FusedMoE @@ -338,6 +340,7 @@ def __init__( cache_config=None, quant_config: 
Optional[QuantizationConfig] = None, layer_id=None, + use_dp=False, ) -> None: super().__init__() self.layer_id = layer_id @@ -351,29 +354,80 @@ def __init__( self.num_heads = num_heads tp_size = get_tensor_model_parallel_world_size() assert num_heads % tp_size == 0 - self.num_local_heads = num_heads // tp_size + self.num_local_heads = num_heads if use_dp else num_heads // tp_size self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - if self.q_lora_rank is not None: - self.q_a_proj = ReplicatedLinear( - self.hidden_size, - self.q_lora_rank, + if use_dp: + # For data parallel attention + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ReplicatedLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + ) + else: + self.q_proj = ReplicatedLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + ) + self.kv_b_proj = ReplicatedLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, quant_config=quant_config, ) - self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear( - q_lora_rank, - self.num_heads * self.qk_head_dim, + # O projection. + self.o_proj = ReplicatedLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, bias=False, quant_config=quant_config, ) else: - self.q_proj = ColumnParallelLinear( + # For tensor parallel attention + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + ) + else: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + ) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + ) + # O projection. + self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, self.hidden_size, - self.num_heads * self.qk_head_dim, bias=False, quant_config=quant_config, ) @@ -385,19 +439,6 @@ def __init__( quant_config=quant_config, ) self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) - self.kv_b_proj = ColumnParallelLinear( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - bias=False, - quant_config=quant_config, - ) - # O projection. 
- self.o_proj = RowParallelLinear( - self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - ) rope_scaling["rope_type"] = "deepseek_yarn" self.rotary_emb = get_rope( qk_rope_head_dim, @@ -491,6 +532,36 @@ def forward( return output +def all_gather( + input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group +): + if world_size == 1: + return input_tensor + + all_lens = forward_batch.global_num_tokens + max_len = max(forward_batch.global_num_tokens) + + padded_tensor = torch.nn.functional.pad( + input_tensor, (0, 0, 0, max_len - input_tensor.shape[0]) + ) + + torch.distributed.all_gather_into_tensor( + forward_batch.gathered_buffer, padded_tensor, group=group + ) + + gathered_tensors = torch.concat( + [ + forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]] + for i in range(world_size) + ] + ) + + start_index = 0 if rank == 0 else sum(all_lens[:rank]) + end_index = start_index + all_lens[rank] + + return gathered_tensors, start_index, end_index + + class DeepseekV2DecoderLayer(nn.Module): def __init__( @@ -505,6 +576,14 @@ def __init__( rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + self.enable_dp_attention = ( + not global_server_args_dict["disable_mla"] + and global_server_args_dict["enable_dp_attention"] + ) + if self.enable_dp_attention: + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_group = get_tp_group().device_group if not global_server_args_dict["disable_mla"]: self.self_attn = DeepseekV2AttentionMLA( config=config, @@ -523,6 +602,7 @@ def __init__( cache_config=cache_config, quant_config=quant_config, layer_id=layer_id, + use_dp=self.enable_dp_attention, ) else: self.self_attn = DeepseekV2Attention( @@ -569,20 +649,32 @@ def forward( residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - forward_batch=forward_batch, - ) + if not forward_batch.forward_mode.is_idle(): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) - hidden_states = self.mlp(hidden_states) + if self.enable_dp_attention: + hidden_states, start_idx, end_idx = all_gather( + hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group + ) + hidden_states = self.mlp(hidden_states) + hidden_states = hidden_states[start_idx:end_idx] + else: + hidden_states = self.mlp(hidden_states) + return hidden_states, residual @@ -603,6 +695,7 @@ def __init__( self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + enable_tp=not global_server_args_dict["enable_dp_attention"], ) self.layers = nn.ModuleList( [ @@ -630,7 +723,8 @@ def forward( hidden_states, residual = layer( positions, 
hidden_states, forward_batch, residual ) - hidden_states, _ = self.norm(hidden_states, residual) + if not forward_batch.forward_mode.is_idle(): + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -646,10 +740,18 @@ def __init__( self.config = config self.quant_config = quant_config self.model = DeepseekV2Model(config, cache_config, quant_config) - self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config - ) - self.logits_processor = LogitsProcessor(config) + if global_server_args_dict["enable_dp_attention"]: + self.lm_head = ReplicatedLinear( + config.hidden_size, + config.vocab_size, + bias=False, + ) + self.logits_processor = LogitsProcessor(config, skip_all_gather=True) + else: + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) + self.logits_processor = LogitsProcessor(config) @torch.no_grad() def forward( @@ -659,9 +761,10 @@ def forward( forward_batch: ForwardBatch, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch) - return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch - ) + if not forward_batch.forward_mode.is_idle(): + return self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, forward_batch + ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 5aac2c581b8..cb4d191926c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -129,6 +129,7 @@ class ServerArgs: disable_nan_detection: bool = False enable_overlap_schedule: bool = False enable_mixed_chunk: bool = False + enable_dp_attention: bool = False enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: int = 160 @@ -203,6 +204,16 @@ def __post_init__(self): if self.sampling_backend is None: self.sampling_backend = "flashinfer" + if self.enable_dp_attention: + self.dp_size = self.tp_size + self.chunked_prefill_size = self.chunked_prefill_size // 2 + self.disable_cuda_graph = True + self.enable_overlap_schedule = False + logger.warning( + f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE workload issue. " + "The CUDA graph is disabled." + ) + if self.enable_overlap_schedule: logger.warning( "Overlap scheduler mode is enabled. This is an experimental feature. " @@ -669,6 +680,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enabling mixing prefill and decode in a batch when using chunked prefill.", ) + parser.add_argument( + "--enable-dp-attention", + action="store_true", + help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. 
Currently only DeepSeek-V2 is supported.", + ) parser.add_argument( "--enable-torch-compile", action="store_true", diff --git a/test/srt/test_dp_attention.py b/test/srt/test_dp_attention.py new file mode 100644 index 00000000000..4cfdac228d5 --- /dev/null +++ b/test/srt/test_dp_attention.py @@ -0,0 +1,63 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_child_process +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestDPAttention(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "2", + "--dp", + "2", + "--enable-dp-attention", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid, include_self=True) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.5 + + def test_mgsm_en(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mgsm_en", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.8 + + +if __name__ == "__main__": + unittest.main()
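
How the patch fits together, in three short illustrative sketches. The sketches are Python, self-contained, and hedged: helper names such as plan_scheduler_processes are hypothetical and are not part of sglang.

With --enable-dp-attention, the DataParallelController no longer launches a full tensor-parallel group per DP replica. It starts one scheduler process per GPU, and that process acts as both DP rank i and TP rank i (server_args forces dp_size == tp_size), so --tp 2 --dp 2 needs two GPUs instead of four; the controller then blocks on a pipe from each scheduler until it reports ready. A minimal sketch of the resulting process layout, assuming the classic branch assigns one GPU per TP rank (only the "+= tp_size" stride is visible in the diff):

from typing import List, Tuple


def plan_scheduler_processes(
    tp_size: int, dp_size: int, enable_dp_attention: bool
) -> List[Tuple[int, int, int]]:
    """Return (dp_rank, tp_rank, gpu_id) for each scheduler process.

    Hypothetical helper mirroring the launch loop in
    DataParallelController.__init__; it is not part of sglang.
    """
    procs = []
    base_gpu_id = 0
    for dp_rank in range(dp_size):
        if enable_dp_attention:
            # One process per GPU doubles as DP rank dp_rank and TP rank dp_rank;
            # the controller later waits on a pipe until each scheduler is ready.
            procs.append((dp_rank, dp_rank, base_gpu_id))
            base_gpu_id += 1
        else:
            # Classic layout: a full TP group per DP replica. The per-rank gpu_id
            # shown here is an assumption; launch_tensor_parallel_group's body is
            # unchanged by this patch.
            for tp_rank in range(tp_size):
                procs.append((dp_rank, tp_rank, base_gpu_id + tp_rank))
            base_gpu_id += tp_size
    return procs


if __name__ == "__main__":
    # The layout exercised by the new CI test: --tp 2 --dp 2 --enable-dp-attention.
    print(plan_scheduler_processes(tp_size=2, dp_size=2, enable_dp_attention=True))
    # [(0, 0, 0), (1, 1, 1)] -> two GPUs in total rather than tp_size * dp_size = 4.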
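
Because the co-located workers share one tensor-parallel MLP, every worker has to enter the forward pass together even when it has no requests. Scheduler.prepare_dp_attn_batch therefore all-gathers each worker's token count (batch size for decode, extend_num_tokens for prefill) over the TP device group before every step; a worker with nothing to run but at least one busy peer builds an idle batch (ForwardMode.IDLE, empty input_ids and seq_lens) so it still joins the collectives. A sketch of that decision, assuming torch.distributed is already initialized and tp_group is the shared device group; dp_step_decision is a hypothetical helper:

import torch
import torch.distributed as dist


def dp_step_decision(local_num_tokens: int, tp_group, device: str = "cuda"):
    """Decide what this DP-attention worker runs in the next iteration.

    Sketch of Scheduler.prepare_dp_attn_batch under the assumption of an
    initialized NCCL process group shared by the co-located workers.
    """
    world_size = dist.get_world_size(tp_group)
    local = torch.tensor(local_num_tokens, dtype=torch.int64, device=device)
    global_num_tokens = torch.empty(world_size, dtype=torch.int64, device=device)
    # Same collective as the patch: every worker learns every peer's token count.
    dist.all_gather_into_tensor(global_num_tokens, local, group=tp_group)

    if global_num_tokens.max().item() == 0:
        return "skip", []  # no worker has work; nobody enters the forward pass
    if local_num_tokens == 0:
        # A busy peer exists: run an idle batch (ForwardMode.IDLE, empty inputs)
        # purely to participate in the MLP all-gather collectives.
        return "idle", global_num_tokens.tolist()
    return "run", global_num_tokens.tolist()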
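
Inside DeepseekV2DecoderLayer, attention runs data-parallel (each worker keeps all heads and uses ReplicatedLinear projections) while the MoE/MLP stays tensor-parallel. Each worker's post-attention hidden states are therefore padded to the largest per-worker token count, all-gathered into forward_batch.gathered_buffer, depadded before the MLP, and sliced back to the worker's own span afterwards. The single-process simulation below mirrors that round trip without torch.distributed; it is illustrative only:

from typing import List

import torch
import torch.nn.functional as F


def simulate_dp_mlp_all_gather(local_hidden: List[torch.Tensor]) -> List[torch.Tensor]:
    """Single-process simulation of the padded all-gather around the TP MLP.

    local_hidden[r] is DP rank r's post-attention hidden states with shape
    (num_tokens_r, hidden_size); token counts can differ across ranks.
    """
    all_lens = [t.shape[0] for t in local_hidden]
    max_len = max(all_lens)
    world_size = len(local_hidden)

    # 1) Every rank pads to max_len rows so the collective has a fixed shape
    #    (the patch uses torch.distributed.all_gather_into_tensor into the
    #    preallocated forward_batch.gathered_buffer on the TP device group).
    gathered_buffer = torch.cat(
        [F.pad(t, (0, 0, 0, max_len - t.shape[0])) for t in local_hidden]
    )

    # 2) Strip the padding so the tensor-parallel MLP only sees real tokens.
    mlp_input = torch.cat(
        [gathered_buffer[r * max_len : r * max_len + all_lens[r]] for r in range(world_size)]
    )

    # (The real MLP would run here on mlp_input, sharded across the TP ranks.)

    # 3) Afterwards each rank slices its own contiguous span back out.
    outputs = []
    for rank in range(world_size):
        start = sum(all_lens[:rank])
        outputs.append(mlp_input[start : start + all_lens[rank]])
    return outputs


if __name__ == "__main__":
    hidden = [torch.randn(n, 8) for n in (3, 0, 5)]  # rank 1 is on an idle step
    for rank, out in enumerate(simulate_dp_mlp_all_gather(hidden)):
        assert torch.equal(out, hidden[rank])  # each rank recovers exactly its tokens
        print(f"rank {rank}: {out.shape[0]} tokens round-tripped")

The same shape bookkeeping explains why ForwardBatch.init_new allocates gathered_buffer with max(global_num_tokens) * tp_size rows, and why a worker on an idle step still contributes an all-padding chunk to the collective.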