From 32c3ed803500930e5f39a5941bb021d346dccf23 Mon Sep 17 00:00:00 2001 From: qunyang Date: Thu, 5 Dec 2024 08:01:45 +0200 Subject: [PATCH] always cast probs_idx to int32 in pytorch sampler --- python/sglang/srt/layers/sampler.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 4f05f013f8a..b0dfda3e882 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -111,8 +111,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch( probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0 probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0]) sampled_index = torch.multinomial(probs_sort, num_samples=1) - if probs_idx.device.type == "hpu": - # HPU gather don't support int64 tensors - probs_idx = probs_idx.to(torch.int32) + # int32 range is enough to represent the token ids + probs_idx = probs_idx.to(torch.int32) batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1) return batch_next_token_ids