From 854d1e33bcef9ef0458de845060e8da82413d6b9 Mon Sep 17 00:00:00 2001
From: Bo
Date: Sat, 11 May 2024 00:55:27 +0800
Subject: [PATCH] Rename weight dict w -> z, update state in place, change demo context

---
 RWKV_v6_demo_cuda_bf16.py | 66 +++++++++++++++++++--------------------
 1 file changed, 32 insertions(+), 34 deletions(-)

diff --git a/RWKV_v6_demo_cuda_bf16.py b/RWKV_v6_demo_cuda_bf16.py
index 286be03..ad01488 100644
--- a/RWKV_v6_demo_cuda_bf16.py
+++ b/RWKV_v6_demo_cuda_bf16.py
@@ -125,7 +125,7 @@ def sample_logits(logits, temperature:float=1.0, top_p:float=1.0, top_k:int=0):
 args.vocab_size = 65536
 args.head_size = 64
 
-context = "\nElon Musk's favorite"
+context = "A few light taps upon the pane made her turn to the window. It had begun to snow again."
 # context = "\n北京"
 NUM_TRIALS = 3
 LENGTH_PER_TRIAL = 100
@@ -140,26 +140,26 @@ def __init__(self, args):
         self.n_layer = args.n_layer
         self.eval()
 
-        self.w = torch.load(args.MODEL_NAME + '.pth', map_location='cuda')
-        w = self.w
-        w['emb.weight'] = F.layer_norm(w['emb.weight'], (args.n_embd,), weight=w['blocks.0.ln0.weight'], bias=w['blocks.0.ln0.bias'])
+        self.z = torch.load(args.MODEL_NAME + '.pth', map_location='cuda')
+        z = self.z
+        z['emb.weight'] = F.layer_norm(z['emb.weight'], (args.n_embd,), weight=z['blocks.0.ln0.weight'], bias=z['blocks.0.ln0.bias'])
 
-        keys = list(w.keys())
+        keys = list(z.keys())
         for k in keys:
-            if '.time_' in k: w[k] = w[k].squeeze()
-            if k.endswith('.time_decay'): w[k] = w[k].float()
-            if k.endswith('.time_faaaa'): w[k] = w[k].unsqueeze(-1).float()
+            if '.time_' in k: z[k] = z[k].squeeze()
+            if k.endswith('.time_decay'): z[k] = z[k].float()
+            if k.endswith('.time_faaaa'): z[k] = z[k].unsqueeze(-1).float()
         for k in keys:
             if k.endswith('maa_w'):
-                w[k.replace('maa_w','maa_wkvrg')] = torch.concat([w[k],w[k.replace('maa_w','maa_k')],w[k.replace('maa_w','maa_v')],w[k.replace('maa_w','maa_r')],w[k.replace('maa_w','maa_g')]]).clone().reshape(5, -1)
-                del w[k]
-                del w[k.replace('maa_w','maa_k')]
-                del w[k.replace('maa_w','maa_v')]
-                del w[k.replace('maa_w','maa_r')]
-                del w[k.replace('maa_w','maa_g')]
+                z[k.replace('maa_w','maa_wkvrg')] = torch.concat([z[k],z[k.replace('maa_w','maa_k')],z[k.replace('maa_w','maa_v')],z[k.replace('maa_w','maa_r')],z[k.replace('maa_w','maa_g')]]).clone().reshape(5, -1)
+                del z[k]
+                del z[k.replace('maa_w','maa_k')]
+                del z[k.replace('maa_w','maa_v')]
+                del z[k.replace('maa_w','maa_r')]
+                del z[k.replace('maa_w','maa_g')]
 
-        self.n_head = w['blocks.0.att.time_faaaa'].shape[0]
-        self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head
+        self.n_head = z['blocks.0.att.time_faaaa'].shape[0]
+        self.head_size = z['blocks.0.ln1.weight'].shape[0] // self.n_head
         assert self.head_size == args.head_size
 
     @MyFunction
@@ -206,32 +206,30 @@ def time_mixing(self, x, x_prev, state, maa_x, maa_wkvrg, tm_w1, tm_w2, td_w1, t
     @MyFunction
     def forward(self, token:int, state:List[torch.Tensor]):
-        with torch.no_grad():
-            x = self.w['emb.weight'][token]
+        with torch.no_grad():
+            z = self.z
+            x = z['emb.weight'][token]
 
             for i in range(self.n_layer):
                 bbb = f'blocks.{i}.'
                 att = f'blocks.{i}.att.'
                 ffn = f'blocks.{i}.ffn.'
 
-                xx = F.layer_norm(x, (self.n_embd,), weight=self.w[bbb+'ln1.weight'], bias=self.w[bbb+'ln1.bias'])
-                xx, x_out, s_out = self.time_mixing(xx, state[i*3+0], state[i*3+1],
-                    self.w[att+'time_maa_x'], self.w[att+'time_maa_wkvrg'], self.w[att+'time_maa_w1'], self.w[att+'time_maa_w2'],
-                    self.w[att+'time_decay_w1'], self.w[att+'time_decay_w2'], self.w[att+'time_faaaa'], self.w[att+'time_decay'],
-                    self.w[att+'key.weight'], self.w[att+'value.weight'], self.w[att+'receptance.weight'], self.w[att+'gate.weight'], self.w[att+'output.weight'],
-                    self.w[att+'ln_x.weight'], self.w[att+'ln_x.bias'])
+                xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln1.weight'], bias=z[bbb+'ln1.bias'])
+                xx, state[i*3+0], state[i*3+1] = self.time_mixing(xx, state[i*3+0], state[i*3+1],
+                    z[att+'time_maa_x'], z[att+'time_maa_wkvrg'], z[att+'time_maa_w1'], z[att+'time_maa_w2'],
+                    z[att+'time_decay_w1'], z[att+'time_decay_w2'], z[att+'time_faaaa'], z[att+'time_decay'],
+                    z[att+'key.weight'], z[att+'value.weight'], z[att+'receptance.weight'], z[att+'gate.weight'], z[att+'output.weight'],
+                    z[att+'ln_x.weight'], z[att+'ln_x.bias'])
                 x = x + xx
-                state[i*3+0] = x_out
-                state[i*3+1] = s_out
-
-                xx = F.layer_norm(x, (self.n_embd,), weight=self.w[bbb+'ln2.weight'], bias=self.w[bbb+'ln2.bias'])
-                xx, x_out = self.channel_mixing(xx, state[i*3+2],
-                    self.w[ffn+'time_maa_k'], self.w[ffn+'time_maa_r'],
-                    self.w[ffn+'key.weight'], self.w[ffn+'value.weight'], self.w[ffn+'receptance.weight'])
+
+                xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])
+                xx, state[i*3+2] = self.channel_mixing(xx, state[i*3+2],
+                    z[ffn+'time_maa_k'], z[ffn+'time_maa_r'],
+                    z[ffn+'key.weight'], z[ffn+'value.weight'], z[ffn+'receptance.weight'])
                 x = x + xx
-                state[i*3+2] = x_out
 
-            x = F.layer_norm(x, (self.n_embd,), weight=self.w['ln_out.weight'], bias=self.w['ln_out.bias'])
-            x = self.w['head.weight'] @ x
+            x = F.layer_norm(x, (self.n_embd,), weight=z['ln_out.weight'], bias=z['ln_out.bias'])
+            x = z['head.weight'] @ x
             return x, state
 
 print(f'\nUsing CUDA bf16. Loading {args.MODEL_NAME} ...')
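
The core of this patch is the new state threading in forward: each layer owns three state
slots (att token-shift at state[i*3+0], att wkv state at state[i*3+1], ffn token-shift at
state[i*3+2]), and the mixer returns are now tuple-unpacked straight into those slots
instead of passing through the x_out/s_out temporaries. A minimal standalone sketch of
that pattern follows; toy_time_mixing, toy_channel_mixing, and the shapes below are
illustrative stand-ins, not code from this file:

import torch
from typing import List

def toy_time_mixing(x, x_prev, s):
    # stand-in for RWKV_RNN.time_mixing: returns (output, new token-shift, new wkv state)
    return x - x_prev, x.clone(), 0.9 * s

def toy_channel_mixing(x, x_prev):
    # stand-in for RWKV_RNN.channel_mixing: returns (output, new token-shift)
    return torch.relu(x - x_prev), x.clone()

n_layer, n_embd, head_size = 2, 8, 4
n_head = n_embd // head_size

# three slots per layer, matching the state[i*3+k] indexing in forward
state: List[torch.Tensor] = []
for _ in range(n_layer):
    state += [torch.zeros(n_embd),                        # att token-shift (x_prev)
              torch.zeros(n_head, head_size, head_size),  # att wkv state
              torch.zeros(n_embd)]                        # ffn token-shift (x_prev)

x = torch.randn(n_embd)
for i in range(n_layer):
    # unpack mixer outputs directly into the state slots, no x_out/s_out temporaries
    xx, state[i*3+0], state[i*3+1] = toy_time_mixing(x, state[i*3+0], state[i*3+1])
    x = x + xx
    xx, state[i*3+2] = toy_channel_mixing(x, state[i*3+2])
    x = x + xx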