Commit 89bd9ad

www committed May 9, 2024
1 parent 686a78e commit 89bd9ad
Showing 1 changed file with 44 additions and 39 deletions.
83 changes: 44 additions & 39 deletions RWKV_v6_demo_cuda_bf16.py
@@ -18,6 +18,8 @@
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method

+ ########################################################################################################
+
class RWKV_TOKENIZER():
table: list[list[list[bytes]]]
good: list[set[int]]
@@ -97,7 +99,7 @@ def printTokens(self, tokens):
########################################################################################################

def sample_logits(out, temperature=1.0, top_p=0.8):
- probs = F.softmax(out, dim=-1).cpu().numpy()
+ probs = F.softmax(out.float(), dim=-1).cpu().numpy()
sorted_probs = np.sort(probs)[::-1]
cumulative_probs = np.cumsum(sorted_probs)
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
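
Note on the one-line change above: NumPy has no bfloat16 dtype, so calling .cpu().numpy() on bf16 logits raises a TypeError; the added .float() casts to fp32 before the conversion. A minimal standalone sketch (not part of the commit):

import torch
import torch.nn.functional as F

logits = torch.randn(8, dtype=torch.bfloat16)      # stand-in for the model output
# F.softmax(logits, dim=-1).numpy()                # would fail: NumPy cannot hold bf16
probs = F.softmax(logits.float(), dim=-1).numpy()  # cast to fp32 first, then convert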
@@ -159,89 +161,92 @@ def __init__(self, args):
if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
here = getattr(here, p)
setattr(here, last, w[k])

def layer_norm(self, x, w):
return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)

@MyFunction
def channel_mixing(self, x, x_prev, time_maa_k, time_maa_r, kw, vw, rw):
sx = x_prev - x
- xk = x + sx * time_maa_k
- xr = x + sx * time_maa_r
- r = torch.sigmoid(rw @ xr)
- k = torch.relu(kw @ xk) ** 2
+ k = x + sx * time_maa_k
+ r = x + sx * time_maa_r
+
+ r = torch.sigmoid(rw @ r)
+ k = torch.relu(kw @ k) ** 2
+
return r * (vw @ k), x
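
For orientation, a standalone sketch of the channel-mixing step above with toy dimensions (the sizes and random weights below are illustrative assumptions, not values from the model):

import torch

D, FFN = 4, 16                                          # toy embedding and FFN sizes
x, x_prev = torch.randn(D), torch.randn(D)              # current and previous token activations
time_maa_k, time_maa_r = torch.rand(D), torch.rand(D)   # learned token-shift mixing factors
kw, vw, rw = torch.randn(FFN, D), torch.randn(D, FFN), torch.randn(D, D)

sx = x_prev - x
k = x + sx * time_maa_k          # shift-mix the input for the key path
r = x + sx * time_maa_r          # and separately for the receptance path
r = torch.sigmoid(rw @ r)        # receptance gate in (0, 1)
k = torch.relu(kw @ k) ** 2      # squared-ReLU feed-forward
out = r * (vw @ k)               # gated output, shape (D,)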

@MyFunction
- def time_mixing(self, x, x_prev, s, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2, time_faaaa, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
+ def time_mixing(self, x, x_prev, state, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2, time_faaaa, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
H = self.n_head
N = self.head_size

sx = x_prev - x
xxx = x + sx * x_maa
xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1)
xxx = torch.bmm(xxx, tm_w2).view(5, -1)
- mw, mk, mv, mr, mg = xxx.unbind(dim=0)
+ w, k, v, r, g = xxx.unbind(dim=0)

- xw = x + sx * (w_maa + mw)
- xk = x + sx * (k_maa + mk)
- xv = x + sx * (v_maa + mv)
- xr = x + sx * (r_maa + mr)
- xg = x + sx * (g_maa + mg)
+ w = x + sx * (w_maa + w)
+ k = x + sx * (k_maa + k)
+ v = x + sx * (v_maa + v)
+ r = x + sx * (r_maa + r)
+ g = x + sx * (g_maa + g)

- r = (rw @ xr).view(H, 1, N)
- k = (kw @ xk).view(H, N, 1)
- v = (vw @ xv).view(H, 1, N)
- g = F.silu(gw @ xg)
+ w = torch.tanh(w @ td_w1) @ td_w2
+ w = w.float() + time_decay
+ # assert w.dtype == torch.float
+ w = torch.exp(-torch.exp(w))

- a = (k @ v).float()
- out = r @ (time_faaaa * a + s).to(torch.bfloat16)
+ k = (kw @ k).view(H, N, 1)
+ v = (vw @ v).view(H, 1, N)
+ r = (rw @ r).view(H, 1, N)
+ g = F.silu(gw @ g)

- w = (time_decay + (torch.tanh(xw @ td_w1) @ td_w2).float()).view(H, N, 1)
- assert w.dtype == torch.float
- w = torch.exp(-torch.exp(w))
- s = a + w * s
+ kv = (k @ v).float()
+ out = r @ (time_faaaa * kv + state).to(torch.bfloat16)

+ state = kv + w.view(H, N, 1) * state

- out = F.group_norm(out.view(1, H*N), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).view(H*N) * g # same as gn(x/8, eps=1e-5)
- return ow @ out, x, s
+ out = F.group_norm(out.view(1, H*N), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).view(H*N) # same as gn(x/8, eps=1e-5)
+ return ow @ (out * g), x, state
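
The refactor above mostly renames variables and reorders casts; the per-head recurrence itself is unchanged. A standalone fp32 sketch of what one token step computes, with toy sizes (head count H and head size N chosen arbitrarily here):

import torch

H, N = 2, 4                          # toy head count and head size
r = torch.randn(H, 1, N)             # receptance, one row vector per head
k = torch.randn(H, N, 1)             # key, one column vector per head
v = torch.randn(H, 1, N)             # value, one row vector per head
w = torch.rand(H, N, 1)              # per-channel decay in (0, 1), i.e. exp(-exp(...))
u = torch.randn(H, N, 1)             # "time_faaaa" bonus applied to the current token
state = torch.zeros(H, N, N)         # per-head state carried from the previous token

kv = k @ v                           # outer product, (H, N, N)
out = r @ (u * kv + state)           # read with the current-token bonus, (H, 1, N)
state = kv + w * state               # decay the old state and add the new contribution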

def forward(self, token, state):
with torch.no_grad():
if state == None:
state = [None for _ in range(args.n_layer * 3)]
for i in range(args.n_layer): # state: 0=att_xx 1=att_kv 2=ffn_xx
state[i*3+0] = torch.zeros(self.args.n_embd, dtype=torch.bfloat16, requires_grad=False, device="cuda")
state[i*3+1] = torch.zeros((self.n_head, self.head_size, self.head_size), dtype=torch.float, requires_grad=False, device="cuda")
state[i*3+2] = torch.zeros(self.args.n_embd, dtype=torch.bfloat16, requires_grad=False, device="cuda")

x = self.w.emb.weight[token]
for i in range(self.args.n_layer):

att = self.w.blocks[i].att
- xx, saa, ss = self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state[i*3+0], state[i*3+1],
+ xx, x_out, s_out = self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state[i*3+0], state[i*3+1],
att.time_maa_x, att.time_maa_w, att.time_maa_k, att.time_maa_v, att.time_maa_r, att.time_maa_g, att.time_maa_w1, att.time_maa_w2,
att.time_decay_w1, att.time_decay_w2, att.time_faaaa, att.time_decay,
att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,
att.ln_x.weight, att.ln_x.bias)
x = x + xx
- state[i*3+0] = saa
- state[i*3+1] = ss
+ state[i*3+0] = x_out
+ state[i*3+1] = s_out

ffn = self.w.blocks[i].ffn
- xx, ss = self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state[i*3+2],
+ xx, x_out = self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state[i*3+2],
ffn.time_maa_k, ffn.time_maa_r,
ffn.key.weight, ffn.value.weight, ffn.receptance.weight)
x = x + xx
- state[i*3+2] = ss
+ state[i*3+2] = x_out

x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
- return x.float(), state
+ return x, state

print(f'\nUsing CUDA bf16. Loading {args.MODEL_NAME} ...')
model = RWKV_RNN(args)

print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
- init_state = None

+ init_state = [None for _ in range(args.n_layer * 3)]
+ for i in range(args.n_layer): # state: 0=att_x_prev 1=att_kv 2=ffn_x_prev
+ init_state[i*3+0] = torch.zeros(args.n_embd, dtype=torch.bfloat16, requires_grad=False, device="cuda")
+ init_state[i*3+1] = torch.zeros((args.n_embd // args.head_size, args.head_size, args.head_size), dtype=torch.float, requires_grad=False, device="cuda")
+ init_state[i*3+2] = torch.zeros(args.n_embd, dtype=torch.bfloat16, requires_grad=False, device="cuda")

for token in tokenizer.encode(context):
init_out, init_state = model.forward(token, init_state)
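
Beyond the collapsed tail of the diff, a demo like this typically alternates sample_logits and model.forward to generate text; a hedged sketch of that pattern (not the file's actual code; note the model output can stay bf16 now because sample_logits casts with .float()):

all_tokens = []
out, state = init_out, init_state
for _ in range(100):                                   # arbitrary number of tokens to generate
    token = sample_logits(out, temperature=1.0, top_p=0.7)
    all_tokens.append(token)
    out, state = model.forward(token, state)
print(tokenizer.decode(all_tokens))                    # assumes the tokenizer's decode method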

