Skip to content

Commit

Permalink
faster
Browse files Browse the repository at this point in the history
  • Loading branch information
www committed May 10, 2024
1 parent 854d1e3 commit 251e94c
Showing 1 changed file with 153 additions and 144 deletions.
297 changes: 153 additions & 144 deletions RWKV_v6_demo_cuda_bf16.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,116 +21,23 @@

########################################################################################################

class RWKV_TOKENIZER():
table: list[list[list[bytes]]]
good: list[set[int]]
wlen: list[int]
def __init__(self, file_name):
self.idx2token = {}
sorted = [] # must be already sorted
lines = open(file_name, "r", encoding="utf-8").readlines()
for l in lines:
idx = int(l[:l.index(' ')])
x = eval(l[l.index(' '):l.rindex(' ')])
x = x.encode("utf-8") if isinstance(x, str) else x
assert isinstance(x, bytes)
assert len(x) == int(l[l.rindex(' '):])
sorted += [x]
self.idx2token[idx] = x

self.token2idx = {}
for k, v in self.idx2token.items():
self.token2idx[v] = int(k)

# precompute some tables for fast matching
self.table = [[[] for j in range(256)] for i in range(256)]
self.good = [set() for i in range(256)]
self.wlen = [0 for i in range(256)]

for i in reversed(range(len(sorted))): # reverse order - match longer tokens first
s = sorted[i]
if len(s) >= 2:
s0 = int(s[0])
s1 = int(s[1])
self.table[s0][s1] += [s]
self.wlen[s0] = max(self.wlen[s0], len(s))
self.good[s0].add(s1)

def encodeBytes(self, src: bytes) -> list[int]:
src_len: int = len(src)
tokens: list[int] = []
i: int = 0
while i < src_len:
s: bytes = src[i : i + 1]

if i < src_len - 1:
s1: int = int(src[i + 1])
s0: int = int(src[i])
if s1 in self.good[s0]:
sss: bytes = src[i : i + self.wlen[s0]]
try:
s = next(filter(sss.startswith, self.table[s0][s1]))
except:
pass
tokens.append(self.token2idx[s])
i += len(s)

return tokens

def decodeBytes(self, tokens):
return b''.join(map(lambda i: self.idx2token[i], tokens))

def encode(self, src: str):
return self.encodeBytes(src.encode("utf-8"))

def decode(self, tokens):
return self.decodeBytes(tokens).decode('utf-8')

########################################################################################################

@torch.jit.script
def sample_logits(logits, temperature:float=1.0, top_p:float=1.0, top_k:int=0):
    """Sample one token id (returned as a Python int) from raw logits.

    Filtering: top_k keeps only the k most probable tokens; top_p is nucleus
    sampling. With top_p == 0 the cutoff below leaves only the most probable
    token(s), i.e. sampling degenerates to (near-)greedy decoding.
    """
    probs = F.softmax(logits.float(), dim=-1)
    sorted_probs, sorted_ids = torch.sort(probs, descending=True)

    if top_k > 0:
        # zero out everything beyond the k most probable tokens
        probs[sorted_ids[top_k:]] = 0

    if top_p < 1:
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        cutoff_index = torch.searchsorted(cumulative_probs, top_p)
        cutoff = sorted_probs[cutoff_index]
        # drop everything strictly below the nucleus cutoff probability
        probs[probs < cutoff] = 0

        if top_p > 0:
            # spread the leftover mass (top_p - current sum) evenly over the
            # tokens sitting exactly at the cutoff value
            idx = torch.where(probs == cutoff)[0]
            probs[idx] = cutoff + (top_p - torch.sum(probs).item()) / len(idx)
            # assert abs(torch.sum(probs).item() - top_p) < 1e-6

    if temperature != 1.0:
        # applied after filtering; multinomial renormalizes, so probs need
        # not sum to 1 at this point
        probs = probs ** (1.0 / temperature)

    return torch.multinomial(probs, num_samples=1).item()

########################################################################################################

# cuda bf16 inference

tokenizer = RWKV_TOKENIZER("tokenizer/rwkv_vocab_v20230424.txt")

args = types.SimpleNamespace()
args.tokenizer = "tokenizer/rwkv_vocab_v20230424.txt"
args.MODEL_NAME = '/home/rwkv/rwkv-final-v6-2.1-3b'
# args.MODEL_NAME = '/mnt/program/rwkv-final-v6-2.1-3b'
args.n_layer = 32
args.n_embd = 2560
args.vocab_size = 65536
args.head_size = 64

context = "A few light taps upon the pane made her turn to the window. It had begun to snow again."
context = "\nA few light taps upon the pane made her turn to the window. It had begun to snow again."
# context = "\n北京"
NUM_TRIALS = 3
LENGTH_PER_TRIAL = 100
TEMPERATURE = 1.0
TOP_P = 0.3
TOP_P = 0

########################################################################################################

class RWKV_RNN(MyModule):
def __init__(self, args):
Expand Down Expand Up @@ -162,48 +69,6 @@ def __init__(self, args):
self.head_size = z['blocks.0.ln1.weight'].shape[0] // self.n_head
assert self.head_size == args.head_size

@MyFunction
def channel_mixing(self, x, x_prev, time_maa_k, time_maa_r, kw, vw, rw):
    # RWKV-6 channel mixing (token-shifted FFN) for a single token.
    # Returns (output, new_shift_state); the new shift state is simply x.
    sx = x_prev - x
    k = x + sx * time_maa_k  # key input: lerp between current and previous token
    r = x + sx * time_maa_r  # receptance input: separate lerp coefficient

    r = torch.sigmoid(rw @ r)
    k = torch.relu(kw @ k) ** 2  # squared-ReLU activation

    return r * (vw @ k), x

@MyFunction
def time_mixing(self, x, x_prev, state, maa_x, maa_wkvrg, tm_w1, tm_w2, td_w1, td_w2, time_faaaa, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
    # RWKV-6 time mixing (attention-like linear recurrence) for one token.
    # state is the per-head (H, N, N) wkv state kept in float32;
    # returns (output, x, new_state) -- x becomes next token's shift state.
    H = self.n_head
    N = self.head_size

    sx = x_prev - x
    xxx = x + sx * maa_x # C
    xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1) # C @ C*5L => 5L => 5*1*L
    xxx = torch.bmm(xxx, tm_w2).view(5, -1) # 5*1*L @ 5*L*C => 5*1*C => 5*C
    xxx = xxx + maa_wkvrg
    # data-dependent token-shift mix, one row each for w, k, v, r, g
    xxx = xxx * sx.expand(5, -1) + x.expand(5, -1)
    w, k, v, r, g = xxx.unbind(dim=0)

    # low-rank data-dependent adjustment on top of the learned decay
    w = torch.tanh(w @ td_w1) @ td_w2
    w = w.float() + time_decay  # computed in float32 for stability
    # assert w.dtype == torch.float
    w = torch.exp(-torch.exp(w))  # maps to a decay factor in (0, 1)

    k = (kw @ k).view(H, N, 1)
    v = (vw @ v).view(H, 1, N)
    r = (rw @ r).view(H, 1, N)
    g = F.silu(gw @ g)

    kv = (k @ v).float()  # per-head outer product, accumulated in float32
    # current-token bonus (time_faaaa) plus the decayed history
    out = r @ (time_faaaa * kv + state).to(torch.bfloat16)

    state = kv + w.view(H, N, 1) * state  # decay old state, add current kv

    out = F.group_norm(out.view(1, H*N), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).view(H*N) # same as gn(x/8, eps=1e-5)
    return ow @ (out * g), x, state

@MyFunction
def forward(self, token:int, state:List[torch.Tensor]):
with torch.no_grad():
Expand All @@ -215,15 +80,15 @@ def forward(self, token:int, state:List[torch.Tensor]):
ffn = f'blocks.{i}.ffn.'

xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln1.weight'], bias=z[bbb+'ln1.bias'])
xx, state[i*3+0], state[i*3+1] = self.time_mixing(xx, state[i*3+0], state[i*3+1],
xx, state[i*3+0], state[i*3+1] = time_mixing(self.n_head, self.head_size, xx, state[i*3+0], state[i*3+1],
z[att+'time_maa_x'], z[att+'time_maa_wkvrg'], z[att+'time_maa_w1'], z[att+'time_maa_w2'],
z[att+'time_decay_w1'], z[att+'time_decay_w2'], z[att+'time_faaaa'], z[att+'time_decay'],
z[att+'key.weight'], z[att+'value.weight'], z[att+'receptance.weight'], z[att+'gate.weight'], z[att+'output.weight'],
z[att+'ln_x.weight'], z[att+'ln_x.bias'])
x = x + xx

xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])
xx, state[i*3+2] = self.channel_mixing(xx, state[i*3+2],
xx, state[i*3+2] = channel_mixing(xx, state[i*3+2],
z[ffn+'time_maa_k'], z[ffn+'time_maa_r'],
z[ffn+'key.weight'], z[ffn+'value.weight'], z[ffn+'receptance.weight'])
x = x + xx
Expand All @@ -232,6 +97,146 @@ def forward(self, token:int, state:List[torch.Tensor]):
x = z['head.weight'] @ x
return x, state

########################################################################################################

def time_mixing__(H:int, N:int, x, x_prev, state, maa_x, maa_wkvrg, tm_w1, tm_w2, td_w1, td_w2, time_faaaa, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
    # One step of RWKV-6 time mixing, written as a free function (H = number of
    # heads, N = head size) so it can be handed to torch.compile below.
    # Returns (output, x, new_state); x becomes the next token's shift state.
    dx = x_prev - x
    mix = x + dx * maa_x                            # (C,)
    mix = torch.tanh(mix @ tm_w1).view(5, 1, -1)    # (C,) @ (C,5L) -> (5,1,L)
    mix = torch.bmm(mix, tm_w2).view(5, -1)         # (5,1,L) @ (5,L,C) -> (5,C)
    mix = mix + maa_wkvrg
    mix = mix * dx.expand(5, -1) + x.expand(5, -1)  # token-shift mix, one row per signal
    w, k, v, r, g = mix.unbind(dim=0)

    # data-dependent decay; the float() keeps the exp chain in float32
    w = torch.tanh(w @ td_w1) @ td_w2
    w = w.float() + time_decay
    w = torch.exp(-torch.exp(w))                    # decay factor in (0, 1)

    k = (kw @ k).view(H, N, 1)
    v = (vw @ v).view(H, 1, N)
    r = (rw @ r).view(H, 1, N)
    g = torch.nn.functional.silu(gw @ g)

    kv = (k @ v).float()                            # per-head outer product in float32
    out = r @ (time_faaaa * kv + state).to(torch.bfloat16)

    state = kv + w.view(H, N, 1) * state            # decay history, accumulate current kv

    # eps = 64e-5 on out is equivalent to group_norm(out/8, eps=1e-5)
    out = torch.nn.functional.group_norm(out.view(1, H*N), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).view(H*N)
    return ow @ (out * g), x, state
time_mixing = torch.compile(time_mixing__, mode="max-autotune", fullgraph=True, dynamic=False)

########################################################################################################

def channel_mixing__(x, x_prev, time_maa_k, time_maa_r, kw, vw, rw):
    # RWKV-6 channel mixing (token-shifted FFN) as a free function so it can
    # be handed to torch.compile below.
    # Returns (output, new_shift_state); the new shift state is simply x.
    shift = x_prev - x
    key_in = x + shift * time_maa_k   # lerp toward the previous token for the key
    rec_in = x + shift * time_maa_r   # separate lerp coefficient for receptance

    gate = torch.sigmoid(rw @ rec_in)
    hidden = torch.relu(kw @ key_in) ** 2   # squared-ReLU activation

    return gate * (vw @ hidden), x
channel_mixing = torch.compile(channel_mixing__, mode="max-autotune", fullgraph=True, dynamic=False)

########################################################################################################

@torch.jit.script
def sample_logits(logits, temperature:float=1.0, top_p:float=1.0, top_k:int=0):
    """Sample one token id (returned as a Python int) from raw logits.

    Filtering: top_k keeps only the k most probable tokens; top_p is nucleus
    sampling. With top_p == 0 the cutoff below leaves only the most probable
    token(s), i.e. sampling degenerates to (near-)greedy decoding.
    """
    probs = F.softmax(logits.float(), dim=-1)
    sorted_probs, sorted_ids = torch.sort(probs, descending=True)

    if top_k > 0:
        # zero out everything beyond the k most probable tokens
        probs[sorted_ids[top_k:]] = 0

    if top_p < 1:
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        cutoff_index = torch.searchsorted(cumulative_probs, top_p)
        cutoff = sorted_probs[cutoff_index]
        # drop everything strictly below the nucleus cutoff probability
        probs[probs < cutoff] = 0

        if top_p > 0:
            # spread the leftover mass (top_p - current sum) evenly over the
            # tokens sitting exactly at the cutoff value
            idx = torch.where(probs == cutoff)[0]
            probs[idx] = cutoff + (top_p - torch.sum(probs).item()) / len(idx)
            # assert abs(torch.sum(probs).item() - top_p) < 1e-6

    if temperature != 1.0:
        # applied after filtering; multinomial renormalizes, so probs need
        # not sum to 1 at this point
        probs = probs ** (1.0 / temperature)

    return torch.multinomial(probs, num_samples=1).item()

########################################################################################################

class RWKV_TOKENIZER():
table: list[list[list[bytes]]]
good: list[set[int]]
wlen: list[int]
def __init__(self, file_name):
self.idx2token = {}
sorted = [] # must be already sorted
lines = open(file_name, "r", encoding="utf-8").readlines()
for l in lines:
idx = int(l[:l.index(' ')])
x = eval(l[l.index(' '):l.rindex(' ')])
x = x.encode("utf-8") if isinstance(x, str) else x
assert isinstance(x, bytes)
assert len(x) == int(l[l.rindex(' '):])
sorted += [x]
self.idx2token[idx] = x

self.token2idx = {}
for k, v in self.idx2token.items():
self.token2idx[v] = int(k)

# precompute some tables for fast matching
self.table = [[[] for j in range(256)] for i in range(256)]
self.good = [set() for i in range(256)]
self.wlen = [0 for i in range(256)]

for i in reversed(range(len(sorted))): # reverse order - match longer tokens first
s = sorted[i]
if len(s) >= 2:
s0 = int(s[0])
s1 = int(s[1])
self.table[s0][s1] += [s]
self.wlen[s0] = max(self.wlen[s0], len(s))
self.good[s0].add(s1)

def encodeBytes(self, src: bytes) -> list[int]:
src_len: int = len(src)
tokens: list[int] = []
i: int = 0
while i < src_len:
s: bytes = src[i : i + 1]

if i < src_len - 1:
s1: int = int(src[i + 1])
s0: int = int(src[i])
if s1 in self.good[s0]:
sss: bytes = src[i : i + self.wlen[s0]]
try:
s = next(filter(sss.startswith, self.table[s0][s1]))
except:
pass
tokens.append(self.token2idx[s])
i += len(s)

return tokens

def decodeBytes(self, tokens):
return b''.join(map(lambda i: self.idx2token[i], tokens))

def encode(self, src: str):
return self.encodeBytes(src.encode("utf-8"))

def decode(self, tokens):
return self.decodeBytes(tokens).decode('utf-8')

########################################################################################################

tokenizer = RWKV_TOKENIZER(args.tokenizer)

print(f'\nUsing CUDA bf16. Loading {args.MODEL_NAME} ...')
model = RWKV_RNN(args)

Expand All @@ -246,6 +251,8 @@ def forward(self, token:int, state:List[torch.Tensor]):
for token in tokenizer.encode(context):
init_out, init_state = model.forward(token, init_state)

########################################################################################################

for TRIAL in range(NUM_TRIALS):
print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
all_tokens = []
Expand All @@ -255,6 +262,8 @@ def forward(self, token:int, state:List[torch.Tensor]):
min_time = 1e10
min_time_all = 1e10

t000 = time.time()

for i in range(LENGTH_PER_TRIAL):
t00 = time.time()
token = sample_logits(out, TEMPERATURE, TOP_P)
Expand All @@ -274,6 +283,6 @@ def forward(self, token:int, state:List[torch.Tensor]):
min_time = min(min_time, t1 - t0)
min_time_all = min(min_time_all, t1 - t00)

print(f'\n[{round(1/min_time_all,2)} (real) / {round(1/min_time,2)} token/s (ignore tokenizer & sampling)]', end='')
print(f'\n[ {round(1/min_time_all,2)} (real) / {round(1/min_time,2)} (ignore sampling & tokenizer) token/s = {round(time.time()-t000,3)}s ]', end='')

print('\n')

0 comments on commit 251e94c

Please sign in to comment.