Commit fb54f8c: faster

www committed May 10, 2024
1 parent 3d65c52 commit fb54f8c
Showing 1 changed file with 44 additions and 35 deletions.

RWKV_v6_demo_cuda_bf16.py: 79 changes (44 additions & 35 deletions)
@@ -85,30 +85,31 @@ def encode(self, src: str):
     def decode(self, tokens):
         return self.decodeBytes(tokens).decode('utf-8')
 
     def printTokens(self, tokens):
         for i in tokens:
             s = self.idx2token[i]
             try:
                 s = s.decode('utf-8')
             except:
                 pass
             print(f'{repr(s)}{i}', end=' ')
             # print(repr(s), i)
         print()
 
 ########################################################################################################
 
-def sample_logits(out, temperature=1.0, top_p=0.8):
-    probs = F.softmax(out.float(), dim=-1).cpu().numpy()
-    sorted_probs = np.sort(probs)[::-1]
-    cumulative_probs = np.cumsum(sorted_probs)
-    cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
-    probs[probs < cutoff] = 0
+@torch.jit.script
+def sample_logits(logits, temperature:float=1.0, top_p:float=1.0, top_k:int=0):
+    probs = F.softmax(logits.float(), dim=-1)
+    sorted_probs, sorted_ids = torch.sort(probs, descending=True)
+
+    if top_k > 0:
+        probs[sorted_ids[top_k:]] = 0
+
+    if top_p < 1:
+        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+        cutoff_index = torch.searchsorted(cumulative_probs, top_p)
+        cutoff = sorted_probs[cutoff_index]
+        probs[probs < cutoff] = 0
+
+        if top_p > 0:
+            idx = torch.where(probs == cutoff)[0]
+            probs[idx] = cutoff + (top_p - torch.sum(probs).item()) / len(idx)
+            # assert abs(torch.sum(probs).item() - top_p) < 1e-6
 
     if temperature != 1.0:
         probs = probs ** (1.0 / temperature)
-    probs = probs / np.sum(probs)
-    out = np.random.choice(a=len(probs), p=probs)
-    return out
+
+    return torch.multinomial(probs, num_samples=1).item()
 
 ########################################################################################################
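Note (not part of the commit): the rewritten sampler keeps everything on the GPU and adds top-k. After the top-p cut it redistributes mass across the tokens sitting exactly at the cutoff so the surviving probabilities sum to top_p (the commented assert checks this), then applies temperature and draws with torch.multinomial instead of round-tripping through numpy. A minimal standalone sketch of the new path, assuming a toy logits vector rather than real model output:

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([2.0, 1.0, 0.5, -1.0])              # toy values
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_ids = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    cutoff_index = torch.searchsorted(cumulative_probs, 0.8)  # top_p = 0.8
    cutoff = sorted_probs[cutoff_index]
    probs[probs < cutoff] = 0                                 # drop the tail
    token = torch.multinomial(probs, num_samples=1).item()    # GPU-side draw; multinomial accepts unnormalized weights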

@@ -139,11 +140,20 @@ def __init__(self, args):
         w = torch.load(args.MODEL_NAME + '.pth', map_location='cuda')
         w['emb.weight'] = F.layer_norm(w['emb.weight'], (args.n_embd,), weight=w['blocks.0.ln0.weight'], bias=w['blocks.0.ln0.bias'])
 
-        for k in w.keys():
-            if '.time_' in k: w[k] = w[k].squeeze()
+        keys = list(w.keys())
+        for k in keys:
+            if '.time_' in k: w[k] = w[k].squeeze()
             if k.endswith('.time_decay'): w[k] = w[k].float()
             if k.endswith('.time_faaaa'): w[k] = w[k].unsqueeze(-1).float()
 
+        for k in keys:
+            if k.endswith('maa_w'):
+                w[k.replace('maa_w','maa_wkvrg')] = torch.concat([w[k],w[k.replace('maa_w','maa_k')],w[k.replace('maa_w','maa_v')],w[k.replace('maa_w','maa_r')],w[k.replace('maa_w','maa_g')]]).clone().reshape(5, -1)
+                del w[k]
+                del w[k.replace('maa_w','maa_k')]
+                del w[k.replace('maa_w','maa_v')]
+                del w[k.replace('maa_w','maa_r')]
+                del w[k.replace('maa_w','maa_g')]
+
         self.n_head = w['blocks.0.att.time_faaaa'].shape[0]
         self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head
         assert self.head_size == args.head_size
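Note (not part of the commit): keys is snapshotted with list(w.keys()) because the second loop deletes entries, which would break iteration over a live dict view. The fusion itself packs the five per-channel mixing coefficients (w/k/v/r/g) into a single (5, C) tensor so time_mixing can apply them in one batched op. A small sketch of the equivalence, with a hypothetical embedding size C and tensors already squeezed to (C,) as the first loop leaves them:

    import torch

    C = 8  # hypothetical n_embd
    maa_w, maa_k, maa_v, maa_r, maa_g = (torch.randn(C) for _ in range(5))

    # five (C,) tensors -> one (5, C) tensor, row i = coefficient i
    maa_wkvrg = torch.concat([maa_w, maa_k, maa_v, maa_r, maa_g]).reshape(5, -1)

    assert maa_wkvrg.shape == (5, C)
    assert torch.equal(maa_wkvrg[0], maa_w)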
@@ -179,22 +189,18 @@ def channel_mixing(self, x, x_prev, time_maa_k, time_maa_r, kw, vw, rw):
         return r * (vw @ k), x
 
     @MyFunction
-    def time_mixing(self, x, x_prev, state, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2, time_faaaa, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
+    def time_mixing(self, x, x_prev, state, maa_x, maa_wkvrg, tm_w1, tm_w2, td_w1, td_w2, time_faaaa, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
         H = self.n_head
         N = self.head_size
 
         sx = x_prev - x
-        xxx = x + sx * x_maa
-        xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1)
-        xxx = torch.bmm(xxx, tm_w2).view(5, -1)
+        xxx = x + sx * maa_x # C
+        xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1) # C @ C*5L => 5L => 5*1*L
+        xxx = torch.bmm(xxx, tm_w2).view(5, -1) # 5*1*L @ 5*L*C => 5*1*C => 5*C
+        xxx = xxx + maa_wkvrg
+        xxx = xxx * sx.expand(5, -1) + x.expand(5, -1)
         w, k, v, r, g = xxx.unbind(dim=0)
-
-        w = x + sx * (w_maa + w)
-        k = x + sx * (k_maa + k)
-        v = x + sx * (v_maa + v)
-        r = x + sx * (r_maa + r)
-        g = x + sx * (g_maa + g)
 
         w = torch.tanh(w @ td_w1) @ td_w2
         w = w.float() + time_decay
         # assert w.dtype == torch.float
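Note (not part of the commit): the restructured block computes all five token-shift interpolations in one broadcast pass; x + sx * (maa + m) is folded into (m + maa_wkvrg) * sx + x over the five rows, replacing the five separate w/k/v/r/g assignments. A shape-only sketch with hypothetical sizes (C = channels, L = low-rank width):

    import torch

    C, L = 16, 4  # hypothetical sizes
    x, sx, maa_x = torch.randn(C), torch.randn(C), torch.randn(C)
    tm_w1 = torch.randn(C, 5 * L)   # C -> 5L
    tm_w2 = torch.randn(5, L, C)    # per-branch L -> C
    maa_wkvrg = torch.randn(5, C)

    xxx = x + sx * maa_x                          # (C,)
    xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1)  # (5, 1, L)
    xxx = torch.bmm(xxx, tm_w2).view(5, -1)       # (5, C)
    xxx = (xxx + maa_wkvrg) * sx.expand(5, -1) + x.expand(5, -1)
    w, k, v, r, g = xxx.unbind(dim=0)             # five (C,) tensors from one op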
@@ -220,7 +226,7 @@ def forward(self, token, state):
 
             att = self.w.blocks[i].att
             xx, x_out, s_out = self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state[i*3+0], state[i*3+1],
-                att.time_maa_x, att.time_maa_w, att.time_maa_k, att.time_maa_v, att.time_maa_r, att.time_maa_g, att.time_maa_w1, att.time_maa_w2,
+                att.time_maa_x, att.time_maa_wkvrg, att.time_maa_w1, att.time_maa_w2,
                 att.time_decay_w1, att.time_decay_w2, att.time_faaaa, att.time_decay,
                 att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,
                 att.ln_x.weight, att.ln_x.bias)
@@ -259,8 +265,10 @@ def forward(self, token, state):
 out, state = init_out.clone(), copy.deepcopy(init_state)
 
 min_time = 1e10
+min_time_all = 1e10
 
 for i in range(LENGTH_PER_TRIAL):
+    t00 = time.time()
     token = sample_logits(out, TEMPERATURE, TOP_P)
     all_tokens += [token]
     try:
@@ -276,7 +284,8 @@ def forward(self, token, state):
 
     t1 = time.time()
     min_time = min(min_time, t1 - t0)
+    min_time_all = min(min_time_all, t1 - t00)
 
-print(f'\n[{round(1/min_time,2)} token/s (ignore tokenizer & sampling)]', end='')
+print(f'\n[{round(1/min_time_all,2)} (real) / {round(1/min_time,2)} token/s (ignore tokenizer & sampling)]', end='')
 
 print('\n')
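Note (not part of the commit): the second timer measures end-to-end cost. t00 starts before sampling and detokenization while t0 starts only before the model forward, so taking the per-token minimum of each delta yields a best-case "real" tokens/s next to the model-only figure. A runnable sketch of the pattern, with sleeps standing in for the actual work:

    import time

    min_time, min_time_all = 1e10, 1e10
    for _ in range(10):
        t00 = time.time()   # includes sampling + detokenization
        time.sleep(0.002)   # stand-in for sample_logits + tokenizer.decode
        t0 = time.time()    # model forward only
        time.sleep(0.005)   # stand-in for model.forward + cuda sync
        t1 = time.time()
        min_time = min(min_time, t1 - t0)
        min_time_all = min(min_time_all, t1 - t00)

    print(f'{1/min_time_all:.2f} (real) / {1/min_time:.2f} token/s')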
