From 3d65c528df673615868cadcf36a7b27de7c95635 Mon Sep 17 00:00:00 2001 From: Bo Date: Fri, 10 May 2024 05:45:44 +0800 Subject: [PATCH] . --- RWKV_v6_demo_cuda_bf16.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/RWKV_v6_demo_cuda_bf16.py b/RWKV_v6_demo_cuda_bf16.py index ace39a4..fd8c5c8 100644 --- a/RWKV_v6_demo_cuda_bf16.py +++ b/RWKV_v6_demo_cuda_bf16.py @@ -121,6 +121,7 @@ def sample_logits(out, temperature=1.0, top_p=0.8): args.n_layer = 32 args.n_embd = 2560 args.vocab_size = 65536 +args.head_size = 64 context = "\nElon Musk's favorite" # context = "\n北京" @@ -145,6 +146,7 @@ def __init__(self, args): self.n_head = w['blocks.0.att.time_faaaa'].shape[0] self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head + assert self.head_size == args.head_size self.w = types.SimpleNamespace() # set self.w from w self.w.blocks = {}