diff --git a/RWKV_v6_demo_cuda_bf16.py b/RWKV_v6_demo_cuda_bf16.py
index fd8c5c8..0793a60 100644
--- a/RWKV_v6_demo_cuda_bf16.py
+++ b/RWKV_v6_demo_cuda_bf16.py
@@ -85,30 +85,31 @@ def encode(self, src: str):
     def decode(self, tokens):
         return self.decodeBytes(tokens).decode('utf-8')
 
-    def printTokens(self, tokens):
-        for i in tokens:
-            s = self.idx2token[i]
-            try:
-                s = s.decode('utf-8')
-            except:
-                pass
-            print(f'{repr(s)}{i}', end=' ')
-            # print(repr(s), i)
-        print()
-
 ########################################################################################################
 
-def sample_logits(out, temperature=1.0, top_p=0.8):
-    probs = F.softmax(out.float(), dim=-1).cpu().numpy()
-    sorted_probs = np.sort(probs)[::-1]
-    cumulative_probs = np.cumsum(sorted_probs)
-    cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
-    probs[probs < cutoff] = 0
+@torch.jit.script
+def sample_logits(logits, temperature:float=1.0, top_p:float=1.0, top_k:int=0):
+    probs = F.softmax(logits.float(), dim=-1)
+    sorted_probs, sorted_ids = torch.sort(probs, descending=True)
+
+    if top_k > 0:
+        probs[sorted_ids[top_k:]] = 0
+
+    if top_p < 1:
+        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+        cutoff_index = torch.searchsorted(cumulative_probs, top_p)
+        cutoff = sorted_probs[cutoff_index]
+        probs[probs < cutoff] = 0
+
+        if top_p > 0:
+            idx = torch.where(probs == cutoff)[0]
+            probs[idx] = cutoff + (top_p - torch.sum(probs).item()) / len(idx)
+            # assert abs(torch.sum(probs).item() - top_p) < 1e-6
+
     if temperature != 1.0:
         probs = probs ** (1.0 / temperature)
-    probs = probs / np.sum(probs)
-    out = np.random.choice(a=len(probs), p=probs)
-    return out
+
+    return torch.multinomial(probs, num_samples=1).item()
 
 ########################################################################################################
 
@@ -139,11 +140,20 @@ def __init__(self, args):
         w = torch.load(args.MODEL_NAME + '.pth', map_location='cuda')
         w['emb.weight'] = F.layer_norm(w['emb.weight'], (args.n_embd,), weight=w['blocks.0.ln0.weight'], bias=w['blocks.0.ln0.bias'])
 
-        for k in w.keys():
-            if '.time_' in k: w[k] = w[k].squeeze()
+        keys = list(w.keys())
+        for k in keys:
+            if '.time_' in k: w[k] = w[k].squeeze()
             if k.endswith('.time_decay'): w[k] = w[k].float()
             if k.endswith('.time_faaaa'): w[k] = w[k].unsqueeze(-1).float()
-
+        for k in keys:
+            if k.endswith('maa_w'):
+                w[k.replace('maa_w','maa_wkvrg')] = torch.concat([w[k],w[k.replace('maa_w','maa_k')],w[k.replace('maa_w','maa_v')],w[k.replace('maa_w','maa_r')],w[k.replace('maa_w','maa_g')]]).clone().reshape(5, -1)
+                del w[k]
+                del w[k.replace('maa_w','maa_k')]
+                del w[k.replace('maa_w','maa_v')]
+                del w[k.replace('maa_w','maa_r')]
+                del w[k.replace('maa_w','maa_g')]
+
         self.n_head = w['blocks.0.att.time_faaaa'].shape[0]
         self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head
         assert self.head_size == args.head_size
@@ -179,22 +189,18 @@ def channel_mixing(self, x, x_prev, time_maa_k, time_maa_r, kw, vw, rw):
         return r * (vw @ k), x
 
     @MyFunction
-    def time_mixing(self, x, x_prev, state, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2, time_faaaa, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
+    def time_mixing(self, x, x_prev, state, maa_x, maa_wkvrg, tm_w1, tm_w2, td_w1, td_w2, time_faaaa, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
         H = self.n_head
         N = self.head_size
 
         sx = x_prev - x
-        xxx = x + sx * x_maa
-        xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1)
-        xxx = torch.bmm(xxx, tm_w2).view(5, -1)
+        xxx = x + sx * maa_x # C
+        xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1) # C @ C*5L => 5L => 5*1*L
+        xxx = torch.bmm(xxx, tm_w2).view(5, -1) # 5*1*L @ 5*L*C => 5*1*C => 5*C
+        xxx = xxx + maa_wkvrg
+        xxx = xxx * sx.expand(5, -1) + x.expand(5, -1)
         w, k, v, r, g = xxx.unbind(dim=0)
 
-        w = x + sx * (w_maa + w)
-        k = x + sx * (k_maa + k)
-        v = x + sx * (v_maa + v)
-        r = x + sx * (r_maa + r)
-        g = x + sx * (g_maa + g)
-
         w = torch.tanh(w @ td_w1) @ td_w2
         w = w.float() + time_decay
         # assert w.dtype == torch.float
@@ -220,7 +226,7 @@ def forward(self, token, state):
             att = self.w.blocks[i].att
             xx, x_out, s_out = self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state[i*3+0], state[i*3+1],
-                att.time_maa_x, att.time_maa_w, att.time_maa_k, att.time_maa_v, att.time_maa_r, att.time_maa_g, att.time_maa_w1, att.time_maa_w2,
+                att.time_maa_x, att.time_maa_wkvrg, att.time_maa_w1, att.time_maa_w2,
                 att.time_decay_w1, att.time_decay_w2, att.time_faaaa, att.time_decay,
                 att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,
                 att.ln_x.weight, att.ln_x.bias)
             x = x + xx
@@ -259,8 +265,10 @@ def forward(self, token, state):
     out, state = init_out.clone(), copy.deepcopy(init_state)
 
     min_time = 1e10
+    min_time_all = 1e10
 
     for i in range(LENGTH_PER_TRIAL):
+        t00 = time.time()
         token = sample_logits(out, TEMPERATURE, TOP_P)
         all_tokens += [token]
         try:
@@ -276,7 +284,8 @@ def forward(self, token, state):
         t1 = time.time()
         min_time = min(min_time, t1 - t0)
+        min_time_all = min(min_time_all, t1 - t00)
 
-    print(f'\n[{round(1/min_time,2)} token/s (ignore tokenizer & sampling)]', end='')
+    print(f'\n[{round(1/min_time_all,2)} (real) / {round(1/min_time,2)} token/s (ignore tokenizer & sampling)]', end='')
     print('\n')
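
A note on the rewritten `sample_logits`: after zeroing every probability strictly below `cutoff`, the `if top_p > 0:` branch splits the signed difference between `top_p` and the surviving mass evenly across the entries still equal to `cutoff`, so the kept probabilities sum to exactly `top_p` (that is what the commented-out assert checks) before `torch.multinomial` draws from them. Below is a minimal sketch of that step on a toy distribution; the tensor values are made up for illustration and are not model outputs:

    import torch

    probs = torch.tensor([0.40, 0.30, 0.15, 0.10, 0.05])  # toy distribution
    top_p = 0.8

    sorted_probs, _ = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)        # 0.40, 0.70, 0.85, 0.95, 1.00
    cutoff_index = torch.searchsorted(cumulative, top_p)   # first index with cumsum >= top_p -> 2
    cutoff = sorted_probs[cutoff_index]                    # 0.15
    probs[probs < cutoff] = 0                              # drop the tail; kept mass is now 0.85

    # Spread the overshoot (top_p - kept mass) across the cutoff entries.
    idx = torch.where(probs == cutoff)[0]
    probs[idx] = cutoff + (top_p - torch.sum(probs).item()) / len(idx)
    print(probs)        # ~[0.40, 0.30, 0.10, 0.00, 0.00]
    print(probs.sum())  # ~0.80

Dropping the old explicit `probs / np.sum(probs)` normalization is safe because `torch.multinomial` treats its input as unnormalized weights.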
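
The switch from `for k in w.keys():` to a snapshot (`keys = list(w.keys())`) in the loader is load-bearing: the new second pass inserts the fused `maa_wkvrg` entries and deletes the five originals while walking the state dict, and mutating a dict during iteration over its live view raises a `RuntimeError`. A toy illustration, with placeholder keys:

    # Deleting while iterating the live view fails:
    w = {'a.maa_w': 1, 'a.maa_k': 2}
    try:
        for k in w.keys():
            del w[k]
    except RuntimeError as e:
        print(e)  # dictionary changed size during iteration

    # Snapshotting the keys first, as the patch does, is safe:
    w = {'a.maa_w': 1, 'a.maa_k': 2}
    for k in list(w.keys()):
        del w[k]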
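
The `time_maa_wkvrg` fusion itself is behavior-preserving: concatenating the five `time_maa_{w,k,v,r,g}` vectors into one `(5, C)` tensor lets the five separate interpolations `x + sx * (maa + lora_out)` collapse into the single broadcast that the new `time_mixing` computes (`xxx * sx.expand(5, -1) + x.expand(5, -1)`). A quick equivalence check on random stand-ins; the shapes and names are illustrative, not the real weights:

    import torch

    C = 8                                      # channel count, illustrative
    x, sx = torch.randn(C), torch.randn(C)
    maas = [torch.randn(C) for _ in range(5)]  # stand-ins for time_maa_w/k/v/r/g
    lora = torch.randn(5, C)                   # stand-in for the tm_w1/tm_w2 output

    # Old path: five separate interpolations, one per branch.
    old = torch.stack([x + sx * (m + l) for m, l in zip(maas, lora)])

    # New path: one fused (5, C) tensor and a single broadcast, as in the patch.
    maa_wkvrg = torch.cat(maas).reshape(5, -1)
    new = (lora + maa_wkvrg) * sx.expand(5, -1) + x.expand(5, -1)

    assert torch.allclose(old, new)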
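
Finally, the benchmark now reports two rates: `t00` starts at the top of the generation loop, so `min_time_all` includes sampling and tokenizer work (the "real" token/s), while the pre-existing `t0`/`min_time` pair still times only `model.forward`, matching the "(ignore tokenizer & sampling)" label.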