diff --git a/RWKV_v6_demo_cuda_bf16.py b/RWKV_v6_demo_cuda_bf16.py
index 7374962..ace39a4 100644
--- a/RWKV_v6_demo_cuda_bf16.py
+++ b/RWKV_v6_demo_cuda_bf16.py
@@ -18,6 +18,8 @@
 MyModule = torch.jit.ScriptModule
 MyFunction = torch.jit.script_method
 
+########################################################################################################
+
 class RWKV_TOKENIZER():
     table: list[list[list[bytes]]]
     good: list[set[int]]
@@ -97,7 +99,7 @@ def printTokens(self, tokens):
 ########################################################################################################
 
 def sample_logits(out, temperature=1.0, top_p=0.8):
-    probs = F.softmax(out, dim=-1).cpu().numpy()
+    probs = F.softmax(out.float(), dim=-1).cpu().numpy()
     sorted_probs = np.sort(probs)[::-1]
     cumulative_probs = np.cumsum(sorted_probs)
     cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
@@ -159,21 +161,23 @@ def __init__(self, args):
                 if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                 here = getattr(here, p)
             setattr(here, last, w[k])
-        
+
     def layer_norm(self, x, w):
         return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)
 
     @MyFunction
     def channel_mixing(self, x, x_prev, time_maa_k, time_maa_r, kw, vw, rw):
         sx = x_prev - x
-        xk = x + sx * time_maa_k
-        xr = x + sx * time_maa_r
-        r = torch.sigmoid(rw @ xr)
-        k = torch.relu(kw @ xk) ** 2
+        k = x + sx * time_maa_k
+        r = x + sx * time_maa_r
+
+        r = torch.sigmoid(rw @ r)
+        k = torch.relu(kw @ k) ** 2
+
         return r * (vw @ k), x
 
     @MyFunction
-    def time_mixing(self, x, x_prev, s, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2, time_faaaa, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
+    def time_mixing(self, x, x_prev, state, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2, time_faaaa, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
         H = self.n_head
         N = self.head_size
 
@@ -181,67 +185,68 @@ def time_mixing(self, x, x_prev, s, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm
         xxx = x + sx * x_maa
         xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1)
         xxx = torch.bmm(xxx, tm_w2).view(5, -1)
-        mw, mk, mv, mr, mg = xxx.unbind(dim=0)
+        w, k, v, r, g = xxx.unbind(dim=0)
 
-        xw = x + sx * (w_maa + mw)
-        xk = x + sx * (k_maa + mk)
-        xv = x + sx * (v_maa + mv)
-        xr = x + sx * (r_maa + mr)
-        xg = x + sx * (g_maa + mg)
+        w = x + sx * (w_maa + w)
+        k = x + sx * (k_maa + k)
+        v = x + sx * (v_maa + v)
+        r = x + sx * (r_maa + r)
+        g = x + sx * (g_maa + g)
 
-        r = (rw @ xr).view(H, 1, N)
-        k = (kw @ xk).view(H, N, 1)
-        v = (vw @ xv).view(H, 1, N)
-        g = F.silu(gw @ xg)
+        w = torch.tanh(w @ td_w1) @ td_w2
+        w = w.float() + time_decay
+        # assert w.dtype == torch.float
+        w = torch.exp(-torch.exp(w))
 
-        a = (k @ v).float()
-        out = r @ (time_faaaa * a + s).to(torch.bfloat16)
+        k = (kw @ k).view(H, N, 1)
+        v = (vw @ v).view(H, 1, N)
+        r = (rw @ r).view(H, 1, N)
+        g = F.silu(gw @ g)
 
-        w = (time_decay + (torch.tanh(xw @ td_w1) @ td_w2).float()).view(H, N, 1)
-        assert w.dtype == torch.float
-        w = torch.exp(-torch.exp(w))
-        s = a + w * s
+        kv = (k @ v).float()
+        out = r @ (time_faaaa * kv + state).to(torch.bfloat16)
+
+        state = kv + w.view(H, N, 1) * state
 
-        out = F.group_norm(out.view(1, H*N), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).view(H*N) * g # same as gn(x/8, eps=1e-5)
-        return ow @ out, x, s
+        out = F.group_norm(out.view(1, H*N), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).view(H*N) # same as gn(x/8, eps=1e-5)
+        return ow @ (out * g), x, state
 
     def forward(self, token, state):
         with torch.no_grad():
-            if state == None:
-                state = [None for _ in range(args.n_layer * 3)]
-                for i in range(args.n_layer): # state: 0=att_xx 1=att_kv 2=ffn_xx
-                    state[i*3+0] = torch.zeros(self.args.n_embd, dtype=torch.bfloat16, requires_grad=False, device="cuda")
-                    state[i*3+1] = torch.zeros((self.n_head, self.head_size, self.head_size), dtype=torch.float, requires_grad=False, device="cuda")
-                    state[i*3+2] = torch.zeros(self.args.n_embd, dtype=torch.bfloat16, requires_grad=False, device="cuda")
-
             x = self.w.emb.weight[token]
             for i in range(self.args.n_layer):
                 att = self.w.blocks[i].att
-                xx, saa, ss = self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state[i*3+0], state[i*3+1],
+                xx, x_out, s_out = self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state[i*3+0], state[i*3+1],
                     att.time_maa_x, att.time_maa_w, att.time_maa_k, att.time_maa_v, att.time_maa_r, att.time_maa_g, att.time_maa_w1, att.time_maa_w2,
                     att.time_decay_w1, att.time_decay_w2, att.time_faaaa, att.time_decay,
                     att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,
                     att.ln_x.weight, att.ln_x.bias)
                 x = x + xx
-                state[i*3+0] = saa
-                state[i*3+1] = ss
+                state[i*3+0] = x_out
+                state[i*3+1] = s_out
 
                 ffn = self.w.blocks[i].ffn
-                xx, ss = self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state[i*3+2],
+                xx, x_out = self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state[i*3+2],
                     ffn.time_maa_k, ffn.time_maa_r,
                     ffn.key.weight, ffn.value.weight, ffn.receptance.weight)
                 x = x + xx
-                state[i*3+2] = ss
+                state[i*3+2] = x_out
 
             x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
-            return x.float(), state
+            return x, state
 
 
 print(f'\nUsing CUDA bf16. Loading {args.MODEL_NAME} ...')
 model = RWKV_RNN(args)
 
 print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
-init_state = None
+
+init_state = [None for _ in range(args.n_layer * 3)]
+for i in range(args.n_layer): # state: 0=att_x_prev 1=att_kv 2=ffn_x_prev
+    init_state[i*3+0] = torch.zeros(args.n_embd, dtype=torch.bfloat16, requires_grad=False, device="cuda")
+    init_state[i*3+1] = torch.zeros((args.n_embd // args.head_size, args.head_size, args.head_size), dtype=torch.float, requires_grad=False, device="cuda")
+    init_state[i*3+2] = torch.zeros(args.n_embd, dtype=torch.bfloat16, requires_grad=False, device="cuda")
+
 for token in tokenizer.encode(context):
     init_out, init_state = model.forward(token, init_state)
 
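Usage sketch: after this patch, forward() no longer accepts state=None (callers must allocate the per-layer state list up front, as the preprocessing loop above now does) and returns bf16 logits, relying on sample_logits to upcast them. A minimal continuation loop under those assumptions, reusing the script's own globals (model, tokenizer, sample_logits, init_out, init_state); the 200-token length and sampling settings are illustrative, not part of the patch:

out, state = init_out, init_state
all_tokens = []
for _ in range(200):  # hypothetical generation length
    token = sample_logits(out, temperature=1.0, top_p=0.8)  # upcasts the bf16 logits internally
    all_tokens.append(token)
    # state is threaded explicitly, 3 entries per layer:
    # i*3+0 = att x_prev (bf16), i*3+1 = att kv (fp32), i*3+2 = ffn x_prev (bf16)
    out, state = model.forward(token, state)
tokenizer.printTokens(all_tokens)  # printTokens is defined on RWKV_TOKENIZER in this script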