diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 4f05f013f8a..b0dfda3e882 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -111,8 +111,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch( probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0 probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0]) sampled_index = torch.multinomial(probs_sort, num_samples=1) - if probs_idx.device.type == "hpu": - # HPU gather don't support int64 tensors - probs_idx = probs_idx.to(torch.int32) + # int32 range is enough to represent the token ids + probs_idx = probs_idx.to(torch.int32) batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1) return batch_next_token_ids