Skip to content

Commit

Permalink
fix
Browse files: browse the repository at this point in the history
  • Loading branch information
www committed Jul 3, 2024
1 parent 95d3671 commit bad2d6c
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions RWKV_v6_demo_cuda_bf16.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,9 @@ def sample_logits(logits, temperature:float=1.0, top_p:float=1.0, top_k:int=0):

if top_p > 0:
idx = torch.where(probs == cutoff)[0]
- probs[idx] = cutoff + (top_p - torch.sum(probs).item()) / len(idx)
- # assert abs(torch.sum(probs).item() - top_p) < 1e-6
+ if len(idx) > 0:
+ probs[idx] = cutoff + (top_p - torch.sum(probs).item()) / len(idx)
+ # assert abs(torch.sum(probs).item() - top_p) < 1e-6

if temperature != 1.0:
probs = probs ** (1.0 / temperature)
Expand Down Expand Up @@ -287,6 +288,7 @@ def decode(self, tokens):

out, state = model.forward(token, state)

+ torch.cuda.synchronize()
t1 = time.perf_counter()
min_time = min(min_time, t1 - t0)
min_time_all = min(min_time_all, t1 - t00)
Expand Down

0 comments on commit bad2d6c

Please sign in to comment.