Skip to content

Commit

Permalink
works for win10
Browse files Browse the repository at this point in the history
  • Loading branch information
www committed May 12, 2024
1 parent 843f8e5 commit 06aa1d4
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions RWKV_v6_demo_cuda_bf16.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@

MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
MyStatic = torch.jit.script

########################################################################################################

args = types.SimpleNamespace()
args.tokenizer = "tokenizer/rwkv_vocab_v20230424.txt"
args.MODEL_NAME = '/home/rwkv/rwkv-final-v6-2.1-3b'
# args.MODEL_NAME = '/home/rwkv/rwkv-final-v6-2.1-3b'
# args.MODEL_NAME = '/mnt/program/rwkv-final-v6-2.1-3b'
args.MODEL_NAME = 'E:/RWKV-Runner/models/rwkv-final-v6-2.1-3b'
args.n_layer = 32
args.n_embd = 2560
args.vocab_size = 65536
Expand Down Expand Up @@ -125,7 +127,10 @@ def time_mixing__(H:int, N:int, x, x_prev, state, maa_x, maa_wkvrg, tm_w1, tm_w2

out = torch.nn.functional.group_norm(out.view(1, H*N), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).view(H*N) # same as gn(x/8, eps=1e-5)
return ow @ (out * g), x, state
time_mixing = torch.compile(time_mixing__, mode="max-autotune", fullgraph=True, dynamic=False)
try:
time_mixing = torch.compile(time_mixing__, mode="max-autotune", fullgraph=True, dynamic=False)
except:
time_mixing = torch.jit.script(time_mixing__)

########################################################################################################

Expand All @@ -138,11 +143,14 @@ def channel_mixing__(x, x_prev, time_maa_k, time_maa_r, kw, vw, rw):
k = torch.relu(kw @ k) ** 2

return r * (vw @ k), x
channel_mixing = torch.compile(channel_mixing__, mode="max-autotune", fullgraph=True, dynamic=False)
try:
channel_mixing = torch.compile(channel_mixing__, mode="max-autotune", fullgraph=True, dynamic=False)
except:
channel_mixing = torch.jit.script(channel_mixing__)

########################################################################################################

@torch.jit.script
@MyStatic
def sample_logits(logits, temperature:float=1.0, top_p:float=1.0, top_k:int=0):
probs = F.softmax(logits.float(), dim=-1)
sorted_probs, sorted_ids = torch.sort(probs, descending=True)
Expand Down

0 comments on commit 06aa1d4

Please sign in to comment.