Simplify pytorch sampling kernel and logit processor #2491

Merged · 3 commits · Dec 16, 2024
182 changes: 98 additions & 84 deletions python/sglang/srt/layers/logits_processor.py
@@ -100,82 +100,9 @@ def __init__(
self.do_tensor_parallel_all_gather = (
not skip_all_gather and get_tensor_model_parallel_world_size() > 1
)

def _get_normalized_prompt_logprobs(
self,
input_token_logprobs: torch.Tensor,
logits_metadata: LogitsMetadata,
):
logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
pruned_lens = torch.tensor(
logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
)

start = torch.zeros_like(pruned_lens)
start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
end = torch.clamp(
start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
)
self.final_logit_softcapping = getattr(
self.config, "final_logit_softcapping", None
)
sum_logp = (
logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
)
normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
return normalized_prompt_logprobs

@staticmethod
def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
max_k = max(logits_metadata.top_logprobs_nums)
ret = all_logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
indices = ret.indices.tolist()

if logits_metadata.forward_mode.is_decode():
output_top_logprobs_val = []
output_top_logprobs_idx = []
for i, k in enumerate(logits_metadata.top_logprobs_nums):
output_top_logprobs_val.append(values[i][:k])
output_top_logprobs_idx.append(indices[i][:k])
return None, None, output_top_logprobs_val, output_top_logprobs_idx
else:
input_top_logprobs_val, input_top_logprobs_idx = [], []
output_top_logprobs_val, output_top_logprobs_idx = [], []

pt = 0
for k, pruned_len in zip(
logits_metadata.top_logprobs_nums,
logits_metadata.extend_logprob_pruned_lens_cpu,
):
if pruned_len <= 0:
input_top_logprobs_val.append([])
input_top_logprobs_idx.append([])
output_top_logprobs_val.append([])
output_top_logprobs_idx.append([])
continue

input_top_logprobs_val.append(
[values[pt + j][:k] for j in range(pruned_len - 1)]
)
input_top_logprobs_idx.append(
[indices[pt + j][:k] for j in range(pruned_len - 1)]
)
output_top_logprobs_val.append(
list(
values[pt + pruned_len - 1][:k],
)
)
output_top_logprobs_idx.append(
list(
indices[pt + pruned_len - 1][:k],
)
)
pt += pruned_len

return (
input_top_logprobs_val,
input_top_logprobs_idx,
output_top_logprobs_val,
output_top_logprobs_idx,
)

def forward(
self,
@@ -201,18 +128,20 @@ def forward(
last_logits = tensor_model_parallel_all_gather(last_logits)
last_logits = last_logits[:, : self.config.vocab_size].float()

if hasattr(self.config, "final_logit_softcapping"):
last_logits.div_(self.config.final_logit_softcapping)
if self.final_logit_softcapping:
last_logits.div_(self.final_logit_softcapping)
torch.tanh(last_logits, out=last_logits)
last_logits.mul_(self.config.final_logit_softcapping)
last_logits.mul_(self.final_logit_softcapping)

# Return only last_logits if logprob is not requested
if not logits_metadata.return_logprob:
return LogitsProcessorOutput(
next_token_logits=last_logits,
)
else:
last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
last_logprobs = self.compute_temp_top_p_normalized_logprobs(
last_logits, logits_metadata
)

if logits_metadata.forward_mode.is_decode():
if logits_metadata.return_top_logprob:
@@ -248,14 +177,17 @@ def forward(
# extra logits that this padding may have produced.
all_logits = all_logits[:, : self.config.vocab_size].float()

if hasattr(self.config, "final_logit_softcapping"):
all_logits.div_(self.config.final_logit_softcapping)
if self.final_logit_softcapping:
all_logits.div_(self.final_logit_softcapping)
torch.tanh(all_logits, out=all_logits)
all_logits.mul_(self.config.final_logit_softcapping)
all_logits.mul_(self.final_logit_softcapping)

all_logprobs = all_logits
del all_logits, hidden_states
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)

all_logprobs = self.compute_temp_top_p_normalized_logprobs(
all_logprobs, logits_metadata
)

# Get the logprob of top-k tokens
if logits_metadata.return_top_logprob:
@@ -309,11 +241,93 @@ def _get_logits(
# GGUF models
logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)

# Optional scaling factor, backported from vLLM 0.4
# Optional scaling factor
if self.logit_scale is not None:
logits.mul_(self.logit_scale) # In-place multiply
return logits

@staticmethod
def _get_normalized_prompt_logprobs(
input_token_logprobs: torch.Tensor,
logits_metadata: LogitsMetadata,
):
logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
pruned_lens = torch.tensor(
logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
)

start = torch.zeros_like(pruned_lens)
start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
end = torch.clamp(
start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
)
sum_logp = (
logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
)
normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
return normalized_prompt_logprobs

@staticmethod
def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
max_k = max(logits_metadata.top_logprobs_nums)
ret = all_logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
indices = ret.indices.tolist()

if logits_metadata.forward_mode.is_decode():
output_top_logprobs_val = []
output_top_logprobs_idx = []
for i, k in enumerate(logits_metadata.top_logprobs_nums):
output_top_logprobs_val.append(values[i][:k])
output_top_logprobs_idx.append(indices[i][:k])
return None, None, output_top_logprobs_val, output_top_logprobs_idx
else:
input_top_logprobs_val, input_top_logprobs_idx = [], []
output_top_logprobs_val, output_top_logprobs_idx = [], []

pt = 0
for k, pruned_len in zip(
logits_metadata.top_logprobs_nums,
logits_metadata.extend_logprob_pruned_lens_cpu,
):
if pruned_len <= 0:
input_top_logprobs_val.append([])
input_top_logprobs_idx.append([])
output_top_logprobs_val.append([])
output_top_logprobs_idx.append([])
continue

input_top_logprobs_val.append(
[values[pt + j][:k] for j in range(pruned_len - 1)]
)
input_top_logprobs_idx.append(
[indices[pt + j][:k] for j in range(pruned_len - 1)]
)
output_top_logprobs_val.append(
list(
values[pt + pruned_len - 1][:k],
)
)
output_top_logprobs_idx.append(
list(
indices[pt + pruned_len - 1][:k],
)
)
pt += pruned_len

return (
input_top_logprobs_val,
input_top_logprobs_idx,
output_top_logprobs_val,
output_top_logprobs_idx,
)

@staticmethod
def compute_temp_top_p_normalized_logprobs(
last_logits: torch.Tensor, logits_metadata: LogitsMetadata
) -> torch.Tensor:
return torch.nn.functional.log_softmax(last_logits, dim=-1)


def test():
all_logprobs = torch.tensor(
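Note on the logits_processor.py hunks above: `final_logit_softcapping` is now read once in `__init__` via `getattr(self.config, "final_logit_softcapping", None)` and gated by truthiness rather than `hasattr`, so a config that defines the field as `None` correctly skips the tanh softcap, and both the decode and extend paths compute logprobs through the new `compute_temp_top_p_normalized_logprobs` hook (which for now is a plain `log_softmax`). A minimal sketch of that path; the tensor shape and the softcap value of 30.0 are illustrative, not taken from the PR.

```python
import torch
from typing import Optional


def softcap_and_logprobs(logits: torch.Tensor, softcap: Optional[float]) -> torch.Tensor:
    # Truthiness check mirrors `if self.final_logit_softcapping:` above;
    # None (field missing from the config) or 0.0 both skip the softcap.
    if softcap:
        logits = torch.tanh(logits / softcap) * softcap
    # Stand-in for compute_temp_top_p_normalized_logprobs, which at this
    # point in the PR reduces to a plain log_softmax.
    return torch.nn.functional.log_softmax(logits, dim=-1)


print(softcap_and_logprobs(torch.randn(2, 8), 30.0).shape)  # torch.Size([2, 8])
```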
32 changes: 27 additions & 5 deletions python/sglang/srt/layers/sampler.py
@@ -51,7 +51,6 @@ def forward(
# Post process logits
logits.div_(sampling_info.temperatures)
probs = torch.softmax(logits, dim=-1)
logits = None
del logits

if global_server_args_dict["sampling_backend"] == "flashinfer":
@@ -84,6 +83,7 @@ def forward(
sampling_info.top_ks,
sampling_info.top_ps,
sampling_info.min_ps,
sampling_info.need_min_p_sampling,
)
else:
raise ValueError(
@@ -98,20 +98,42 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
top_ks: torch.Tensor,
top_ps: torch.Tensor,
min_ps: torch.Tensor,
need_min_p_sampling: bool,
):
"""A top-k, top-p and min-p sampling implementation with native pytorch operations."""
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
min_p_thresholds = probs_sort[:, 0] * min_ps
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
probs_sort[
torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
>= top_ks.view(-1, 1)
] = 0.0
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0

if need_min_p_sampling:
min_p_thresholds = probs_sort[:, 0] * min_ps
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0

sampled_index = torch.multinomial(probs_sort, num_samples=1)
# int32 range is enough to represent the token ids
probs_idx = probs_idx.to(torch.int32)
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
return batch_next_token_ids


def top_p_normalize_probs(
probs: torch.Tensor,
top_ps: torch.Tensor,
):
if global_server_args_dict["sampling_backend"] == "flashinfer":
return top_p_renorm_prob(probs, top_ps)
elif global_server_args_dict["sampling_backend"] == "pytorch":
# See also top_k_top_p_min_p_sampling_from_probs_torch
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
else:
raise ValueError(
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
)
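In sampler.py, the native PyTorch path now takes a `need_min_p_sampling` flag so the min-p threshold is only computed when a request actually uses min-p, the pre-multinomial rescaling is dropped (`torch.multinomial` accepts unnormalized, non-negative weights), and the sorted indices are gathered as int32. The new `top_p_normalize_probs` helper renormalizes probabilities within the top-p nucleus for the logprob path; below is a standalone copy of its pure-PyTorch branch (the flashinfer branch is not exercised here), with made-up inputs, showing each row summing to 1 after renormalization.

```python
import torch


def top_p_normalize_probs_torch(probs: torch.Tensor, top_ps: torch.Tensor) -> torch.Tensor:
    # Sort descending, zero out everything outside the top-p nucleus,
    # renormalize, then scatter back to the original vocabulary order.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)


probs = torch.softmax(torch.randn(2, 16), dim=-1)
renormed = top_p_normalize_probs_torch(probs, torch.tensor([0.9, 0.5]))
print(renormed.sum(dim=-1))  # each row sums to ~1.0 after renormalization
```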
10 changes: 3 additions & 7 deletions python/sglang/srt/managers/schedule_batch.py
@@ -1086,9 +1086,9 @@ def merge_batch(self, other: "ScheduleBatch"):
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
self.reqs.extend(other.reqs)

self.return_logprob = self.return_logprob or other.return_logprob
self.has_stream = self.has_stream or other.has_stream
self.has_grammar = self.has_grammar or other.has_grammar
self.return_logprob |= other.return_logprob
self.has_stream |= other.has_stream
self.has_grammar |= other.has_grammar

def get_model_worker_batch(self):
if self.forward_mode.is_decode() or self.forward_mode.is_idle():
@@ -1115,7 +1115,6 @@ def get_model_worker_batch(self):
seq_lens=self.seq_lens,
out_cache_loc=self.out_cache_loc,
seq_lens_sum=self.seq_lens_sum,
req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,
global_num_tokens=self.global_num_tokens,
@@ -1170,9 +1169,6 @@ class ModelWorkerBatch:
# The sum of all sequence lengths
seq_lens_sum: int

# The memory pool operation records
req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]

# For logprob
return_logprob: bool
top_logprobs_nums: Optional[List[int]]
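The schedule_batch.py change is mostly cosmetic: the merged-batch flags are combined with in-place `|=` instead of `x = x or y`, and the `req_to_token_pool_records` field (and the call that populated it) is dropped from `ModelWorkerBatch`. For bool operands the two spellings agree; a throwaway check, purely illustrative:

```python
# For real booleans, `x |= y` gives the same result as `x = x or y`.
for a in (False, True):
    for b in (False, True):
        x = a
        x |= b
        assert x == (a or b)
print("|= matches `or` for bool operands")
```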
14 changes: 8 additions & 6 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -387,19 +387,21 @@ def replay(self, forward_batch: ForwardBatch):

# Extract logprobs
if forward_batch.return_logprob:
next_token_logprobs = torch.nn.functional.log_softmax(
next_token_logits, dim=-1
logits_metadata = LogitsMetadata(
forward_mode=ForwardMode.DECODE,
top_logprobs_nums=forward_batch.top_logprobs_nums,
)
next_token_logprobs = (
LogitsProcessor.compute_temp_top_p_normalized_logprobs(
next_token_logits, logits_metadata
)
)
logits_output = LogitsProcessorOutput(
next_token_logits=next_token_logits,
next_token_logprobs=next_token_logprobs,
)
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
if return_top_logprob:
logits_metadata = LogitsMetadata(
forward_mode=ForwardMode.DECODE,
top_logprobs_nums=forward_batch.top_logprobs_nums,
)
(
logits_output.output_top_logprobs_val,
logits_output.output_top_logprobs_idx,
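In cuda_graph_runner.py, the replay path now builds a decode-mode `LogitsMetadata` up front and reuses it, computing logprobs through `LogitsProcessor.compute_temp_top_p_normalized_logprobs` instead of an inlined `log_softmax`, which keeps the CUDA-graph path in step with the main logits processor. A rough sketch of that decode-path extraction, assuming the hook still reduces to `log_softmax`; the `decode_logprobs` helper and its inputs are hypothetical.

```python
import torch
from typing import List, Optional, Tuple


def decode_logprobs(
    next_token_logits: torch.Tensor, top_logprobs_nums: List[int]
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    # Stand-in for compute_temp_top_p_normalized_logprobs (currently log_softmax).
    logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)
    # Mirrors the `return_top_logprob = any(x > 0 ...)` gate in the hunk above.
    if any(k > 0 for k in top_logprobs_nums):
        top = logprobs.topk(max(top_logprobs_nums), dim=-1)
        return logprobs, top.values, top.indices
    return logprobs, None, None


logprobs, top_vals, top_idx = decode_logprobs(torch.randn(2, 32), [3, 0])
print(logprobs.shape, None if top_vals is None else top_vals.shape)
```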
5 changes: 0 additions & 5 deletions python/sglang/srt/server_args.py
@@ -698,11 +698,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
)
parser.add_argument(
"--disable-nan-detection",
action="store_true",
help="Disable the NaN detection for better performance.",
)
parser.add_argument(
"--disable-overlap-schedule",
action="store_true",