Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
www committed Feb 3, 2024
1 parent 4caa6e9 commit 3ae18bb
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions API_DEMO_CHAT.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,18 @@ def run_rnn(ctx):

for i in range(99999):
for n in occurrence:
out[n] -= 0 + occurrence[n] * 1.0 # repetition penalty
out[n] -= GEN_alpha_presence + occurrence[n] * GEN_alpha_frequency # repetition penalty
out[0] -= 1e10 # disable END_OF_TEXT

token = pipeline.sample_logits(out, temperature=1.0, top_p=0.3)
token = pipeline.sample_logits(out, temperature=GEN_TEMP, top_p=GEN_TOP_P)

out, model_state = model.forward([token], model_state)
model_tokens += [token]

out_tokens += [token]

for xxx in occurrence:
occurrence[xxx] *= GEN_penalty_decay
occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)

tmp = pipeline.decode(out_tokens[out_last:])
Expand Down

0 comments on commit 3ae18bb

Please sign in to comment.