Simplify pytorch sampling kernel and logit processor (#2491)
merrymercy authored Dec 16, 2024
1 parent 8269947 commit 7a1aecb
Showing 5 changed files with 136 additions and 107 deletions.
182 changes: 98 additions & 84 deletions python/sglang/srt/layers/logits_processor.py
@@ -100,82 +100,9 @@ def __init__(
self.do_tensor_parallel_all_gather = (
not skip_all_gather and get_tensor_model_parallel_world_size() > 1
)

def _get_normalized_prompt_logprobs(
self,
input_token_logprobs: torch.Tensor,
logits_metadata: LogitsMetadata,
):
logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
pruned_lens = torch.tensor(
logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
)

start = torch.zeros_like(pruned_lens)
start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
end = torch.clamp(
start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
self.final_logit_softcapping = getattr(
self.config, "final_logit_softcapping", None
)
sum_logp = (
logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
)
normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
return normalized_prompt_logprobs

@staticmethod
def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
max_k = max(logits_metadata.top_logprobs_nums)
ret = all_logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
indices = ret.indices.tolist()

if logits_metadata.forward_mode.is_decode():
output_top_logprobs_val = []
output_top_logprobs_idx = []
for i, k in enumerate(logits_metadata.top_logprobs_nums):
output_top_logprobs_val.append(values[i][:k])
output_top_logprobs_idx.append(indices[i][:k])
return None, None, output_top_logprobs_val, output_top_logprobs_idx
else:
input_top_logprobs_val, input_top_logprobs_idx = [], []
output_top_logprobs_val, output_top_logprobs_idx = [], []

pt = 0
for k, pruned_len in zip(
logits_metadata.top_logprobs_nums,
logits_metadata.extend_logprob_pruned_lens_cpu,
):
if pruned_len <= 0:
input_top_logprobs_val.append([])
input_top_logprobs_idx.append([])
output_top_logprobs_val.append([])
output_top_logprobs_idx.append([])
continue

input_top_logprobs_val.append(
[values[pt + j][:k] for j in range(pruned_len - 1)]
)
input_top_logprobs_idx.append(
[indices[pt + j][:k] for j in range(pruned_len - 1)]
)
output_top_logprobs_val.append(
list(
values[pt + pruned_len - 1][:k],
)
)
output_top_logprobs_idx.append(
list(
indices[pt + pruned_len - 1][:k],
)
)
pt += pruned_len

return (
input_top_logprobs_val,
input_top_logprobs_idx,
output_top_logprobs_val,
output_top_logprobs_idx,
)

def forward(
self,
@@ -201,18 +128,20 @@ def forward(
last_logits = tensor_model_parallel_all_gather(last_logits)
last_logits = last_logits[:, : self.config.vocab_size].float()

if hasattr(self.config, "final_logit_softcapping"):
last_logits.div_(self.config.final_logit_softcapping)
if self.final_logit_softcapping:
last_logits.div_(self.final_logit_softcapping)
torch.tanh(last_logits, out=last_logits)
last_logits.mul_(self.config.final_logit_softcapping)
last_logits.mul_(self.final_logit_softcapping)

# Return only last_logits if logprob is not requested
if not logits_metadata.return_logprob:
return LogitsProcessorOutput(
next_token_logits=last_logits,
)
else:
last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
last_logprobs = self.compute_temp_top_p_normalized_logprobs(
last_logits, logits_metadata
)

if logits_metadata.forward_mode.is_decode():
if logits_metadata.return_top_logprob:
@@ -248,14 +177,17 @@ def forward(
# extra logits that this padding may have produced.
all_logits = all_logits[:, : self.config.vocab_size].float()

if hasattr(self.config, "final_logit_softcapping"):
all_logits.div_(self.config.final_logit_softcapping)
if self.final_logit_softcapping:
all_logits.div_(self.final_logit_softcapping)
torch.tanh(all_logits, out=all_logits)
all_logits.mul_(self.config.final_logit_softcapping)
all_logits.mul_(self.final_logit_softcapping)

all_logprobs = all_logits
del all_logits, hidden_states
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)

all_logprobs = self.compute_temp_top_p_normalized_logprobs(
all_logprobs, logits_metadata
)

# Get the logprob of top-k tokens
if logits_metadata.return_top_logprob:
@@ -309,11 +241,93 @@ def _get_logits(
# GGUF models
logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)

# Optional scaling factor, backported from vLLM 0.4
# Optional scaling factor
if self.logit_scale is not None:
logits.mul_(self.logit_scale) # In-place multiply
return logits

@staticmethod
def _get_normalized_prompt_logprobs(
input_token_logprobs: torch.Tensor,
logits_metadata: LogitsMetadata,
):
logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
pruned_lens = torch.tensor(
logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
)

start = torch.zeros_like(pruned_lens)
start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
end = torch.clamp(
start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
)
sum_logp = (
logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
)
normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
return normalized_prompt_logprobs
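
For orientation (this sketch is not part of the commit): the cumulative-sum arithmetic above computes a per-request sum of prompt-token logprobs, excluding each request's final pruned position, without a Python loop, and then averages. A minimal CPU sketch with made-up lengths:

    import torch

    # Hypothetical flattened logprobs for two requests with pruned lengths 3 and 4.
    input_token_logprobs = torch.tensor([-1.0, -2.0, -0.5, -0.1, -0.4, -0.3, -0.2])
    pruned_lens = torch.tensor([3, 4])

    cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
    start = torch.zeros_like(pruned_lens)
    start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)                            # [0, 3]
    end = torch.clamp(start + pruned_lens - 2, min=0, max=cumsum.shape[0] - 1)   # [1, 5]
    sum_logp = cumsum[end] - cumsum[start] + input_token_logprobs[start]         # [-3.0, -0.8]
    normalized = sum_logp / (pruned_lens - 1).clamp(min=1)                       # [-1.5, -0.2667]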

@staticmethod
def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
max_k = max(logits_metadata.top_logprobs_nums)
ret = all_logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
indices = ret.indices.tolist()

if logits_metadata.forward_mode.is_decode():
output_top_logprobs_val = []
output_top_logprobs_idx = []
for i, k in enumerate(logits_metadata.top_logprobs_nums):
output_top_logprobs_val.append(values[i][:k])
output_top_logprobs_idx.append(indices[i][:k])
return None, None, output_top_logprobs_val, output_top_logprobs_idx
else:
input_top_logprobs_val, input_top_logprobs_idx = [], []
output_top_logprobs_val, output_top_logprobs_idx = [], []

pt = 0
for k, pruned_len in zip(
logits_metadata.top_logprobs_nums,
logits_metadata.extend_logprob_pruned_lens_cpu,
):
if pruned_len <= 0:
input_top_logprobs_val.append([])
input_top_logprobs_idx.append([])
output_top_logprobs_val.append([])
output_top_logprobs_idx.append([])
continue

input_top_logprobs_val.append(
[values[pt + j][:k] for j in range(pruned_len - 1)]
)
input_top_logprobs_idx.append(
[indices[pt + j][:k] for j in range(pruned_len - 1)]
)
output_top_logprobs_val.append(
list(
values[pt + pruned_len - 1][:k],
)
)
output_top_logprobs_idx.append(
list(
indices[pt + pruned_len - 1][:k],
)
)
pt += pruned_len

return (
input_top_logprobs_val,
input_top_logprobs_idx,
output_top_logprobs_val,
output_top_logprobs_idx,
)
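
To illustrate the extend-mode bookkeeping in get_top_logprobs above (toy numbers, independent of the class): for each request, the first pruned_len - 1 flattened rows hold its prompt (input) top logprobs, the last row holds its output top logprobs, and pt advances by pruned_len per request:

    # Two requests with pruned lengths 3 and 2 produce 5 flattened top-k rows (k = 1 here).
    values = [[-0.1], [-0.2], [-0.3], [-0.4], [-0.5]]
    pruned_lens = [3, 2]

    pt = 0
    for pruned_len in pruned_lens:
        input_rows = values[pt : pt + pruned_len - 1]   # prompt-token rows
        output_row = values[pt + pruned_len - 1]        # generated-token row
        print(input_rows, output_row)
        pt += pruned_len
    # [[-0.1], [-0.2]] [-0.3]
    # [[-0.4]] [-0.5]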

@staticmethod
def compute_temp_top_p_normalized_logprobs(
last_logits: torch.Tensor, logits_metadata: LogitsMetadata
) -> torch.Tensor:
return torch.nn.functional.log_softmax(last_logits, dim=-1)
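
Aside (not from the commit): the in-place div_/tanh/mul_ sequence used in forward() above is final-logit soft-capping, i.e. logits are squashed into (-cap, cap) via cap * tanh(logits / cap) while preserving their ordering. A standalone sketch:

    import torch

    def soft_cap(logits: torch.Tensor, cap: float) -> torch.Tensor:
        # Out-of-place equivalent of: logits.div_(cap); torch.tanh(logits, out=logits); logits.mul_(cap)
        return cap * torch.tanh(logits / cap)

    logits = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
    print(soft_cap(logits, 30.0))   # approx. [-29.9, -4.95, 0.0, 4.95, 29.9]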


def test():
all_logprobs = torch.tensor(
32 changes: 27 additions & 5 deletions python/sglang/srt/layers/sampler.py
@@ -51,7 +51,6 @@ def forward(
# Post process logits
logits.div_(sampling_info.temperatures)
probs = torch.softmax(logits, dim=-1)
logits = None
del logits

if global_server_args_dict["sampling_backend"] == "flashinfer":
@@ -84,6 +83,7 @@ def forward(
sampling_info.top_ks,
sampling_info.top_ps,
sampling_info.min_ps,
sampling_info.need_min_p_sampling,
)
else:
raise ValueError(
@@ -98,20 +98,42 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
top_ks: torch.Tensor,
top_ps: torch.Tensor,
min_ps: torch.Tensor,
need_min_p_sampling: bool,
):
"""A top-k, top-p and min-p sampling implementation with native pytorch operations."""
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
min_p_thresholds = probs_sort[:, 0] * min_ps
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
probs_sort[
torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
>= top_ks.view(-1, 1)
] = 0.0
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0

if need_min_p_sampling:
min_p_thresholds = probs_sort[:, 0] * min_ps
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0

sampled_index = torch.multinomial(probs_sort, num_samples=1)
# int32 range is enough to represent the token ids
probs_idx = probs_idx.to(torch.int32)
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
return batch_next_token_ids
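
A usage sketch for the function above (toy CPU tensors, values made up). torch.multinomial accepts unnormalized non-negative weights, which is presumably why the explicit renormalization in the previous version could be dropped:

    import torch

    probs = torch.softmax(torch.randn(2, 8), dim=-1)   # batch of 2 requests, vocab size 8
    top_ks = torch.tensor([5, 3])                       # per-request top-k
    top_ps = torch.tensor([0.9, 0.8])                   # per-request top-p
    min_ps = torch.tensor([0.0, 0.05])                  # per-request min-p

    next_ids = top_k_top_p_min_p_sampling_from_probs_torch(
        probs, top_ks, top_ps, min_ps, need_min_p_sampling=True
    )
    print(next_ids.shape, next_ids.dtype)               # torch.Size([2]) torch.int32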


def top_p_normalize_probs(
probs: torch.Tensor,
top_ps: torch.Tensor,
):
if global_server_args_dict["sampling_backend"] == "flashinfer":
return top_p_renorm_prob(probs, top_ps)
elif global_server_args_dict["sampling_backend"] == "pytorch":
# See also top_k_top_p_min_p_sampling_from_probs_torch
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
else:
raise ValueError(
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
)
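
A worked toy example of the pytorch branch above (made-up numbers): probabilities outside the top-p nucleus are zeroed, the survivors are rescaled to sum to one, and the result is scattered back to the original vocabulary order:

    import torch

    probs = torch.tensor([[0.05, 0.60, 0.25, 0.10]])
    top_ps = torch.tensor([0.7])

    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)      # [[0.60, 0.25, 0.10, 0.05]]
    probs_sum = torch.cumsum(probs_sort, dim=-1)                      # [[0.60, 0.85, 0.95, 1.00]]
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0   # keeps 0.60 and 0.25
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    out = torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
    # out ≈ [[0.0000, 0.7059, 0.2941, 0.0000]]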
10 changes: 3 additions & 7 deletions python/sglang/srt/managers/schedule_batch.py
@@ -1086,9 +1086,9 @@ def merge_batch(self, other: "ScheduleBatch"):
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
self.reqs.extend(other.reqs)

self.return_logprob = self.return_logprob or other.return_logprob
self.has_stream = self.has_stream or other.has_stream
self.has_grammar = self.has_grammar or other.has_grammar
self.return_logprob |= other.return_logprob
self.has_stream |= other.has_stream
self.has_grammar |= other.has_grammar

def get_model_worker_batch(self):
if self.forward_mode.is_decode() or self.forward_mode.is_idle():
@@ -1115,7 +1115,6 @@ def get_model_worker_batch(self):
seq_lens=self.seq_lens,
out_cache_loc=self.out_cache_loc,
seq_lens_sum=self.seq_lens_sum,
req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,
global_num_tokens=self.global_num_tokens,
@@ -1170,9 +1169,6 @@ class ModelWorkerBatch:
# The sum of all sequence lengths
seq_lens_sum: int

# The memory pool operation records
req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]

# For logprob
return_logprob: bool
top_logprobs_nums: Optional[List[int]]
14 changes: 8 additions & 6 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -387,19 +387,21 @@ def replay(self, forward_batch: ForwardBatch):

# Extract logprobs
if forward_batch.return_logprob:
next_token_logprobs = torch.nn.functional.log_softmax(
next_token_logits, dim=-1
logits_metadata = LogitsMetadata(
forward_mode=ForwardMode.DECODE,
top_logprobs_nums=forward_batch.top_logprobs_nums,
)
next_token_logprobs = (
LogitsProcessor.compute_temp_top_p_normalized_logprobs(
next_token_logits, logits_metadata
)
)
logits_output = LogitsProcessorOutput(
next_token_logits=next_token_logits,
next_token_logprobs=next_token_logprobs,
)
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
if return_top_logprob:
logits_metadata = LogitsMetadata(
forward_mode=ForwardMode.DECODE,
top_logprobs_nums=forward_batch.top_logprobs_nums,
)
(
logits_output.output_top_logprobs_val,
logits_output.output_top_logprobs_idx,
5 changes: 0 additions & 5 deletions python/sglang/srt/server_args.py
@@ -698,11 +698,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
)
parser.add_argument(
"--disable-nan-detection",
action="store_true",
help="Disable the NaN detection for better performance.",
)
parser.add_argument(
"--disable-overlap-schedule",
action="store_true",