Skip to content

Commit

Permalink
better
Browse files Browse the repository at this point in the history
  • Loading branch information
www committed Feb 1, 2024
1 parent 8d9d69c commit b8aef0c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
2 changes: 1 addition & 1 deletion rwkv_pip_package/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "rwkv"
version = "0.8.23"
version = "0.8.24"
authors = [
{ name="Bo PENG" },
]
Expand Down
19 changes: 14 additions & 5 deletions rwkv_pip_package/src/rwkv/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def decode(self, x):
return self.tokenizer.decode(x)

def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
if temperature == 0:
temperature = 1.0
top_p = 0
probs = F.softmax(logits.float(), dim=-1)
top_k = int(top_k)
# 'privateuseone' is the type of custom devices like `torch_directml.device()`
Expand Down Expand Up @@ -109,11 +112,17 @@ def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, st
all_tokens += [token]
for xxx in occurrence:
occurrence[xxx] *= args.alpha_decay
if self.decode([token]) not in ' \r\n\t,.;?!"\':0123456789+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
if token not in occurrence:
occurrence[token] = 1
else:
occurrence[token] += 1

ttt = self.decode([token])
www = 1
if ttt in ' \t0123456789':
www = 0
# elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
# www = 0.5
if token not in occurrence:
occurrence[token] = www
else:
occurrence[token] += www
# print(occurrence) # debug

# output
Expand Down

0 comments on commit b8aef0c

Please sign in to comment.