www committed May 10, 2024
1 parent af821db commit 854d1e3
Showing 1 changed file with 32 additions and 34 deletions.
66 changes: 32 additions & 34 deletions RWKV_v6_demo_cuda_bf16.py
@@ -125,7 +125,7 @@ def sample_logits(logits, temperature:float=1.0, top_p:float=1.0, top_k:int=0):
 args.vocab_size = 65536
 args.head_size = 64
 
-context = "\nElon Musk's favorite"
+context = "A few light taps upon the pane made her turn to the window. It had begun to snow again."
 # context = "\n北京"
 NUM_TRIALS = 3
 LENGTH_PER_TRIAL = 100
@@ -140,26 +140,26 @@ def __init__(self, args):
         self.n_layer = args.n_layer
         self.eval()
 
-        self.w = torch.load(args.MODEL_NAME + '.pth', map_location='cuda')
-        w = self.w
-        w['emb.weight'] = F.layer_norm(w['emb.weight'], (args.n_embd,), weight=w['blocks.0.ln0.weight'], bias=w['blocks.0.ln0.bias'])
+        self.z = torch.load(args.MODEL_NAME + '.pth', map_location='cuda')
+        z = self.z
+        z['emb.weight'] = F.layer_norm(z['emb.weight'], (args.n_embd,), weight=z['blocks.0.ln0.weight'], bias=z['blocks.0.ln0.bias'])
 
-        keys = list(w.keys())
+        keys = list(z.keys())
         for k in keys:
-            if '.time_' in k: w[k] = w[k].squeeze()
-            if k.endswith('.time_decay'): w[k] = w[k].float()
-            if k.endswith('.time_faaaa'): w[k] = w[k].unsqueeze(-1).float()
+            if '.time_' in k: z[k] = z[k].squeeze()
+            if k.endswith('.time_decay'): z[k] = z[k].float()
+            if k.endswith('.time_faaaa'): z[k] = z[k].unsqueeze(-1).float()
         for k in keys:
            if k.endswith('maa_w'):
-                w[k.replace('maa_w','maa_wkvrg')] = torch.concat([w[k],w[k.replace('maa_w','maa_k')],w[k.replace('maa_w','maa_v')],w[k.replace('maa_w','maa_r')],w[k.replace('maa_w','maa_g')]]).clone().reshape(5, -1)
-                del w[k]
-                del w[k.replace('maa_w','maa_k')]
-                del w[k.replace('maa_w','maa_v')]
-                del w[k.replace('maa_w','maa_r')]
-                del w[k.replace('maa_w','maa_g')]
+                z[k.replace('maa_w','maa_wkvrg')] = torch.concat([z[k],z[k.replace('maa_w','maa_k')],z[k.replace('maa_w','maa_v')],z[k.replace('maa_w','maa_r')],z[k.replace('maa_w','maa_g')]]).clone().reshape(5, -1)
+                del z[k]
+                del z[k.replace('maa_w','maa_k')]
+                del z[k.replace('maa_w','maa_v')]
+                del z[k.replace('maa_w','maa_r')]
+                del z[k.replace('maa_w','maa_g')]
 
-        self.n_head = w['blocks.0.att.time_faaaa'].shape[0]
-        self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head
+        self.n_head = z['blocks.0.att.time_faaaa'].shape[0]
+        self.head_size = z['blocks.0.ln1.weight'].shape[0] // self.n_head
         assert self.head_size == args.head_size
 
     @MyFunction
@@ -206,32 +206,30 @@ def time_mixing(self, x, x_prev, state, maa_x, maa_wkvrg, tm_w1, tm_w2, td_w1, t
 
     @MyFunction
     def forward(self, token:int, state:List[torch.Tensor]):
-        with torch.no_grad():
-            x = self.w['emb.weight'][token]
+        with torch.no_grad():
+            z = self.z
+            x = z['emb.weight'][token]
             for i in range(self.n_layer):
                 bbb = f'blocks.{i}.'
                 att = f'blocks.{i}.att.'
                 ffn = f'blocks.{i}.ffn.'
 
-                xx = F.layer_norm(x, (self.n_embd,), weight=self.w[bbb+'ln1.weight'], bias=self.w[bbb+'ln1.bias'])
-                xx, x_out, s_out = self.time_mixing(xx, state[i*3+0], state[i*3+1],
-                    self.w[att+'time_maa_x'], self.w[att+'time_maa_wkvrg'], self.w[att+'time_maa_w1'], self.w[att+'time_maa_w2'],
-                    self.w[att+'time_decay_w1'], self.w[att+'time_decay_w2'], self.w[att+'time_faaaa'], self.w[att+'time_decay'],
-                    self.w[att+'key.weight'], self.w[att+'value.weight'], self.w[att+'receptance.weight'], self.w[att+'gate.weight'], self.w[att+'output.weight'],
-                    self.w[att+'ln_x.weight'], self.w[att+'ln_x.bias'])
+                xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln1.weight'], bias=z[bbb+'ln1.bias'])
+                xx, state[i*3+0], state[i*3+1] = self.time_mixing(xx, state[i*3+0], state[i*3+1],
+                    z[att+'time_maa_x'], z[att+'time_maa_wkvrg'], z[att+'time_maa_w1'], z[att+'time_maa_w2'],
+                    z[att+'time_decay_w1'], z[att+'time_decay_w2'], z[att+'time_faaaa'], z[att+'time_decay'],
+                    z[att+'key.weight'], z[att+'value.weight'], z[att+'receptance.weight'], z[att+'gate.weight'], z[att+'output.weight'],
+                    z[att+'ln_x.weight'], z[att+'ln_x.bias'])
                 x = x + xx
-                state[i*3+0] = x_out
-                state[i*3+1] = s_out
-
-                xx = F.layer_norm(x, (self.n_embd,), weight=self.w[bbb+'ln2.weight'], bias=self.w[bbb+'ln2.bias'])
-                xx, x_out = self.channel_mixing(xx, state[i*3+2],
-                    self.w[ffn+'time_maa_k'], self.w[ffn+'time_maa_r'],
-                    self.w[ffn+'key.weight'], self.w[ffn+'value.weight'], self.w[ffn+'receptance.weight'])
+
+                xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])
+                xx, state[i*3+2] = self.channel_mixing(xx, state[i*3+2],
+                    z[ffn+'time_maa_k'], z[ffn+'time_maa_r'],
+                    z[ffn+'key.weight'], z[ffn+'value.weight'], z[ffn+'receptance.weight'])
                 x = x + xx
-                state[i*3+2] = x_out
 
-            x = F.layer_norm(x, (self.n_embd,), weight=self.w['ln_out.weight'], bias=self.w['ln_out.bias'])
-            x = self.w['head.weight'] @ x
+            x = F.layer_norm(x, (self.n_embd,), weight=z['ln_out.weight'], bias=z['ln_out.bias'])
+            x = z['head.weight'] @ x
             return x, state
 
 print(f'\nUsing CUDA bf16. Loading {args.MODEL_NAME} ...')
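
Note (not part of the commit): apart from the new demo prompt, this refactor is behavior-preserving. forward() still takes a single token id plus the flat per-layer state list and returns (logits, state); only the weight-dict name (w becomes z) and the write-back style (tuple unpacking into state[i*3+...] instead of temporary x_out/s_out variables) change. Below is a minimal, hypothetical sketch of how such a step function is typically driven in the RWKV demos; model, tokenizer, and the state shapes/dtypes are assumptions for illustration, not code from this diff.

import torch

def generate(model, tokenizer, sample_logits, context, length, temperature=1.0, top_p=0.85):
    # Hypothetical usage sketch: drive the token-by-token forward() shown in the diff.
    # Assumed state layout: 3 tensors per layer -- time-mix token shift,
    # per-head wkv state, channel-mix token shift (shapes/dtypes are guesses).
    state = []
    for _ in range(model.n_layer):
        state.append(torch.zeros(model.n_embd, dtype=torch.bfloat16, device='cuda'))
        state.append(torch.zeros(model.n_head, model.head_size, model.head_size,
                                 dtype=torch.float32, device='cuda'))
        state.append(torch.zeros(model.n_embd, dtype=torch.bfloat16, device='cuda'))

    for t in tokenizer.encode(context):      # prefill the prompt one token at a time
        logits, state = model.forward(t, state)

    out = []
    for _ in range(length):                  # then sample autoregressively
        t = sample_logits(logits, temperature=temperature, top_p=top_p)
        out.append(t)
        logits, state = model.forward(t, state)
    return tokenizer.decode(out)

Writing the updated states in place via tuple unpacking keeps the loop body shorter without changing what ends up in state, which is presumably why the temporaries were dropped.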