Commit

faster
www committed May 10, 2024
1 parent fb54f8c commit af821db
Showing 1 changed file with 26 additions and 36 deletions.
RWKV_v6_demo_cuda_bf16.py: 62 changes (26 additions & 36 deletions)
@@ -5,6 +5,7 @@
 import numpy as np
 np.set_printoptions(precision=4, suppress=True, linewidth=200)
 import types, torch, copy, time
+from typing import List
 torch.backends.cudnn.benchmark = True
 torch.backends.cudnn.allow_tf32 = True
 torch.backends.cuda.matmul.allow_tf32 = True
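Note: the new typing import exists to annotate the rewritten forward further down, since TorchScript requires explicit annotations for container arguments. A minimal sketch of the pattern, assuming (as in BlinkDL's other RWKV demo scripts, and consistent with the class RWKV_RNN(MyModule) header below) that MyModule and MyFunction alias torch.jit.ScriptModule and torch.jit.script_method; this Toy class is not code from the repo:

from typing import List
import torch

class Toy(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, token: int, state: List[torch.Tensor]):
        # Without the List[...] annotation, scripting would assume a plain Tensor.
        return state[token] + 1.0

Calling Toy()(0, [torch.zeros(3)]) returns tensor([1., 1., 1.]).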
@@ -135,9 +136,12 @@ class RWKV_RNN(MyModule):
     def __init__(self, args):
         super().__init__()
         self.args = args
-        self.eval() # set torch to inference mode
+        self.n_embd = args.n_embd
+        self.n_layer = args.n_layer
+        self.eval()
 
-        w = torch.load(args.MODEL_NAME + '.pth', map_location='cuda')
+        self.w = torch.load(args.MODEL_NAME + '.pth', map_location='cuda')
+        w = self.w
         w['emb.weight'] = F.layer_norm(w['emb.weight'], (args.n_embd,), weight=w['blocks.0.ln0.weight'], bias=w['blocks.0.ln0.bias'])
 
         keys = list(w.keys())
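Keeping the loaded checkpoint as self.w works because RWKV .pth checkpoints are already flat dicts keyed by dotted names, so they can be indexed directly without any restructuring. A hypothetical inspection (the filename is a placeholder, not a file from this repo):

import torch

w = torch.load('model.pth', map_location='cpu')  # placeholder path
print(len(w), 'tensors')
for k in list(w.keys())[:5]:
    print(k, tuple(w[k].shape))  # e.g. blocks.0.att.time_decay, ...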
@@ -158,25 +162,6 @@ def __init__(self, args):
         self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head
         assert self.head_size == args.head_size
 
-        self.w = types.SimpleNamespace() # set self.w from w
-        self.w.blocks = {}
-        for k in w.keys(): # example: "blocks.0.att.time_decay" => self.w.blocks[0].att.time_decay
-            parts = k.split('.')
-            last = parts.pop()
-            here = self.w
-            for p in parts:
-                if p.isdigit():
-                    p = int(p)
-                    if p not in here: here[p] = types.SimpleNamespace()
-                    here = here[p]
-                else:
-                    if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
-                    here = getattr(here, p)
-            setattr(here, last, w[k])
-
-    def layer_norm(self, x, w):
-        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)
-
     @MyFunction
     def channel_mixing(self, x, x_prev, time_maa_k, time_maa_r, kw, vw, rw):
         sx = x_prev - x
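The deleted block built a nested SimpleNamespace tree purely for attribute-style access. Dropping it is what enables the change in the next hunk: a tree of SimpleNamespace objects is not a scriptable TorchScript type, while a flat dict of str to Tensor held as a module attribute is, so forward can now be compiled with MyFunction. A sketch of that claim (an inference about the commit's intent, not code from the repo):

from typing import Dict
import torch

class Flat(torch.jit.ScriptModule):
    def __init__(self, w: Dict[str, torch.Tensor]):
        super().__init__()
        self.w = w  # inferred as Dict[str, Tensor], usable inside scripted methods

    @torch.jit.script_method
    def forward(self, i: int) -> torch.Tensor:
        # f-string key construction is supported by TorchScript
        return self.w[f'blocks.{i}.att.time_decay']

m = Flat({'blocks.0.att.time_decay': torch.zeros(4)})
print(m(0))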
@@ -219,29 +204,34 @@ def time_mixing(self, x, x_prev, state, maa_x, maa_wkvrg, tm_w1, tm_w2, td_w1, t
         out = F.group_norm(out.view(1, H*N), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).view(H*N) # same as gn(x/8, eps=1e-5)
         return ow @ (out * g), x, state
 
-    def forward(self, token, state):
+    @MyFunction
+    def forward(self, token:int, state:List[torch.Tensor]):
         with torch.no_grad():
-            x = self.w.emb.weight[token]
-            for i in range(self.args.n_layer):
-
-                att = self.w.blocks[i].att
-                xx, x_out, s_out = self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state[i*3+0], state[i*3+1],
-                    att.time_maa_x, att.time_maa_wkvrg, att.time_maa_w1, att.time_maa_w2,
-                    att.time_decay_w1, att.time_decay_w2, att.time_faaaa, att.time_decay,
-                    att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,
-                    att.ln_x.weight, att.ln_x.bias)
+            x = self.w['emb.weight'][token]
+            for i in range(self.n_layer):
+                bbb = f'blocks.{i}.'
+                att = f'blocks.{i}.att.'
+                ffn = f'blocks.{i}.ffn.'
+
+                xx = F.layer_norm(x, (self.n_embd,), weight=self.w[bbb+'ln1.weight'], bias=self.w[bbb+'ln1.bias'])
+                xx, x_out, s_out = self.time_mixing(xx, state[i*3+0], state[i*3+1],
+                    self.w[att+'time_maa_x'], self.w[att+'time_maa_wkvrg'], self.w[att+'time_maa_w1'], self.w[att+'time_maa_w2'],
+                    self.w[att+'time_decay_w1'], self.w[att+'time_decay_w2'], self.w[att+'time_faaaa'], self.w[att+'time_decay'],
+                    self.w[att+'key.weight'], self.w[att+'value.weight'], self.w[att+'receptance.weight'], self.w[att+'gate.weight'], self.w[att+'output.weight'],
+                    self.w[att+'ln_x.weight'], self.w[att+'ln_x.bias'])
                 x = x + xx
                 state[i*3+0] = x_out
                 state[i*3+1] = s_out
 
-                ffn = self.w.blocks[i].ffn
-                xx, x_out = self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state[i*3+2],
-                    ffn.time_maa_k, ffn.time_maa_r,
-                    ffn.key.weight, ffn.value.weight, ffn.receptance.weight)
+                xx = F.layer_norm(x, (self.n_embd,), weight=self.w[bbb+'ln2.weight'], bias=self.w[bbb+'ln2.bias'])
+                xx, x_out = self.channel_mixing(xx, state[i*3+2],
+                    self.w[ffn+'time_maa_k'], self.w[ffn+'time_maa_r'],
+                    self.w[ffn+'key.weight'], self.w[ffn+'value.weight'], self.w[ffn+'receptance.weight'])
                 x = x + xx
                 state[i*3+2] = x_out
 
-            x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
+            x = F.layer_norm(x, (self.n_embd,), weight=self.w['ln_out.weight'], bias=self.w['ln_out.bias'])
+            x = self.w['head.weight'] @ x
             return x, state
 
 print(f'\nUsing CUDA bf16. Loading {args.MODEL_NAME} ...')
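A note on the eps = 64e-5 in the group_norm context line of time_mixing, whose comment reads "same as gn(x/8, eps=1e-5)": dividing the input by 8 divides its per-group variance by 64, so (x/8 - mu/8) / sqrt(var/64 + 1e-5) equals (x - mu) / sqrt(var + 64e-5), and the two calls match exactly. A quick numerical check of that identity:

import torch
import torch.nn.functional as F

x = torch.randn(1, 2048)
a = F.group_norm(x, num_groups=32, eps=64e-5)
b = F.group_norm(x / 8, num_groups=32, eps=1e-5)
print(torch.allclose(a, b, atol=1e-5))  # True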
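For reference, the rewritten forward expects state as a flat List[torch.Tensor] with three entries per layer, indexed i*3+0 through i*3+2. A hedged sketch of a compatible initialization (the dimensions are hypothetical and the shapes and dtypes are assumptions; the demo's real init code sits outside this diff):

import torch

n_layer, n_embd, head_size = 24, 2048, 64  # illustrative values only

state = []
for _ in range(n_layer):
    state.append(torch.zeros(n_embd, dtype=torch.bfloat16, device='cuda'))  # i*3+0: time_mixing token shift
    state.append(torch.zeros(n_embd // head_size, head_size, head_size,
                             dtype=torch.float32, device='cuda'))           # i*3+1: wkv attention state
    state.append(torch.zeros(n_embd, dtype=torch.bfloat16, device='cuda'))  # i*3+2: channel_mixing token shift

Each call then threads the list through: out, state = model.forward(token, state).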