Skip to content

Commit

Permalink
always cast probs_idx to int32 in pytorch sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
YangQun1 committed Dec 5, 2024
1 parent 79546a8 commit 0ee1f26
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions python/sglang/srt/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
sampled_index = torch.multinomial(probs_sort, num_samples=1)
if probs_idx.device.type == "hpu":
# HPU gather don't support int64 tensors
probs_idx = probs_idx.to(torch.int32)
# int32 range is enough to represent the token ids
probs_idx = probs_idx.to(torch.int32)
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
return batch_next_token_ids

0 comments on commit 0ee1f26

Please sign in to comment.