From 06aa1d472ba95434af11a8eca737815257dabb73 Mon Sep 17 00:00:00 2001
From: Bo
Date: Mon, 13 May 2024 00:32:29 +0800
Subject: [PATCH] works for win10

---
 RWKV_v6_demo_cuda_bf16.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/RWKV_v6_demo_cuda_bf16.py b/RWKV_v6_demo_cuda_bf16.py
index 0537c7e..75c3d5c 100644
--- a/RWKV_v6_demo_cuda_bf16.py
+++ b/RWKV_v6_demo_cuda_bf16.py
@@ -18,13 +18,15 @@
 
 MyModule = torch.jit.ScriptModule
 MyFunction = torch.jit.script_method
+MyStatic = torch.jit.script
 
 ########################################################################################################
 
 args = types.SimpleNamespace()
 args.tokenizer = "tokenizer/rwkv_vocab_v20230424.txt"
-args.MODEL_NAME = '/home/rwkv/rwkv-final-v6-2.1-3b'
+# args.MODEL_NAME = '/home/rwkv/rwkv-final-v6-2.1-3b'
 # args.MODEL_NAME = '/mnt/program/rwkv-final-v6-2.1-3b'
+args.MODEL_NAME = 'E:/RWKV-Runner/models/rwkv-final-v6-2.1-3b'
 args.n_layer = 32
 args.n_embd = 2560
 args.vocab_size = 65536
@@ -125,7 +127,10 @@ def time_mixing__(H:int, N:int, x, x_prev, state, maa_x, maa_wkvrg, tm_w1, tm_w2
     out = torch.nn.functional.group_norm(out.view(1, H*N), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).view(H*N) # same as gn(x/8, eps=1e-5)
     return ow @ (out * g), x, state
-time_mixing = torch.compile(time_mixing__, mode="max-autotune", fullgraph=True, dynamic=False)
+try:
+    time_mixing = torch.compile(time_mixing__, mode="max-autotune", fullgraph=True, dynamic=False)
+except:
+    time_mixing = torch.jit.script(time_mixing__)
 
 ########################################################################################################
 
@@ -138,11 +143,14 @@ def channel_mixing__(x, x_prev, time_maa_k, time_maa_r, kw, vw, rw):
     k = torch.relu(kw @ k) ** 2
     return r * (vw @ k), x
-channel_mixing = torch.compile(channel_mixing__, mode="max-autotune", fullgraph=True, dynamic=False)
+try:
+    channel_mixing = torch.compile(channel_mixing__, mode="max-autotune", fullgraph=True, dynamic=False)
+except:
+    channel_mixing = torch.jit.script(channel_mixing__)
 
 ########################################################################################################
 
-@torch.jit.script
+@MyStatic
 def sample_logits(logits, temperature:float=1.0, top_p:float=1.0, top_k:int=0):
     probs = F.softmax(logits.float(), dim=-1)
     sorted_probs, sorted_ids = torch.sort(probs, descending=True)
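
Note: the core of this patch is the "compile if you can, TorchScript otherwise"
pattern, since torch.compile's Inductor/Triton backend is often unavailable on
Windows 10. The standalone sketch below shows the same pattern outside the
patched file; fused_op and fused_op_fast are placeholder names, not identifiers
from RWKV_v6_demo_cuda_bf16.py.

    import torch

    def fused_op(x: torch.Tensor) -> torch.Tensor:
        # stand-in for a per-token kernel such as time_mixing__/channel_mixing__
        return torch.relu(x) ** 2

    try:
        # preferred path: let Inductor fuse and autotune the kernel
        fused_op_fast = torch.compile(fused_op, mode="max-autotune", fullgraph=True, dynamic=False)
    except Exception:
        # fallback when compilation is unsupported; note that on some platforms
        # torch.compile only fails at the first call, in which case the call
        # site would need the same guard instead
        fused_op_fast = torch.jit.script(fused_op)

    print(fused_op_fast(torch.randn(8)))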