0.8.26
www committed Apr 26, 2024
1 parent 661a1c4 commit 8c79567
Showing 2 changed files with 13 additions and 6 deletions.
2 changes: 1 addition & 1 deletion rwkv_pip_package/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "rwkv"
-version = "0.8.25"
+version = "0.8.26"
 authors = [
   { name="Bo PENG" },
 ]
17 changes: 12 additions & 5 deletions rwkv_pip_package/src/rwkv/model.py
@@ -326,8 +326,10 @@ def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit =
         print_need_newline = False
 
         REAL_TIME_FIRST = False
+        args.time_state = False
         for x in list(w.keys()):
             if '.time_faaaa' in x: REAL_TIME_FIRST = True
+            if '.time_state' in x: args.time_state = True
         if REAL_TIME_FIRST:
             w = {k.replace('.time_faaaa','.time_first') if '.time_faaaa' in k else k: v for k, v in w.items()}
         self.w = w
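
Note on this hunk: the loader now scans the checkpoint keys once and records whether any '.time_state' tensors are present, so later code can seed the recurrent state from them. A minimal sketch of the same check done outside the class, assuming a checkpoint file named rwkv-model.pth:

    import torch

    # "rwkv-model.pth" is an assumed path; any RWKV checkpoint works here.
    w = torch.load("rwkv-model.pth", map_location="cpu")
    has_time_state = any('.time_state' in k for k in w.keys())
    print("checkpoint carries trained initial states:", has_time_state)
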
@@ -436,10 +438,12 @@ def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit =
                     torch.cuda.empty_cache()
 
             shape = [i for i in w[x].shape if i != 1]
-            if len(shape) > 1:
-                shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}"
+            if len(shape) > 2:
+                shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} {str(shape[2]).rjust(5)}"
+            elif len(shape) > 1:
+                shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}      "
             else:
-                shape = f" {str(shape[0]).rjust(5)}      "
+                shape = f" {str(shape[0]).rjust(5)}            "
             if layer_id == 0 or layer_id >= args.n_layer-1:
                 if print_need_newline:
                     prxxx('\n', end = '')
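
For context: the verbose weight listing previously printed at most two dimensions per tensor; trained time_state tensors are three-dimensional, so a third right-justified column is added. A small sketch of the new formatting, with an assumed shape of (32, 64, 64):

    shape = [32, 64, 64]  # assumed shape of a blocks.*.att.time_state tensor
    print(f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} {str(shape[2]).rjust(5)}")
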
@@ -498,7 +502,7 @@ def forward(ctx, B, T, C, H, state, r, k, v, w, u):
         if self.version == 6.0 and os.environ["RWKV_CUDA_ON"] == '1':
             HEAD_SIZE = args.n_att // args.n_head
             rwkv6 = load(name="rwkv6", sources=[f"{current_path}/cuda/rwkv6_op.cpp", f"{current_path}/cuda/rwkv6.cu"],
-                            verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}", f"-D_T_={4096}"])
+                            verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3" if os.name != "nt" else "", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}", f"-D_T_={4096}"])
 
             class RWKV_6(torch.autograd.Function):
                 @staticmethod
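
Note on the build line: on Windows (os.name == "nt") the "-Xptxas -O3" flag is replaced by an empty string when compiling the rwkv6 kernel; the rest of the flag list is unchanged. A minimal sketch of the same conditional flag list, with an assumed HEAD_SIZE of 64 standing in for args.n_att // args.n_head:

    import os

    HEAD_SIZE = 64  # assumed value of args.n_att // args.n_head
    extra_cuda_cflags = [
        "-res-usage", "--use_fast_math", "-O3",
        "-Xptxas -O3" if os.name != "nt" else "",   # dropped on Windows
        "--extra-device-vectorization",
        f"-D_N_={HEAD_SIZE}", f"-D_T_={4096}",
    ]
    print(extra_cuda_cflags)
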
@@ -1024,7 +1028,10 @@ def forward(self, tokens, state, full_output=False):
                 dev = dd.device
                 atype = dd.atype
                 state[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
-                state[i*3+1] = torch.zeros((args.n_head, args.n_att//args.n_head, args.n_att//args.n_head), dtype=torch.float, requires_grad=False, device=dev).contiguous()
+                if args.time_state:
+                    state[i*3+1] = w[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
+                else:
+                    state[i*3+1] = torch.zeros((args.n_head, args.n_att//args.n_head, args.n_att//args.n_head), dtype=torch.float, requires_grad=False, device=dev).contiguous()
                 state[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
 
         seq_mode = len(tokens) > 1
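
Note on this hunk: when the checkpoint ships blocks.{i}.att.time_state tensors, the per-block WKV state is initialized from them (last two dimensions transposed, cast to float) instead of zeros. A minimal sketch of that seeding for one block, with assumed sizes n_head=32 and head_size=64:

    import torch

    n_head, head_size = 32, 64                          # assumed sizes
    stored = torch.randn(n_head, head_size, head_size)  # stands in for w[f'blocks.{i}.att.time_state']
    state_i = stored.transpose(1, 2).to(dtype=torch.float).contiguous()
    assert state_i.shape == (n_head, head_size, head_size)
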
