Commit: v7 pip
www committed Dec 6, 2024
1 parent d67bbb3 commit f431ada
Showing 5 changed files with 372 additions and 12 deletions.
23 changes: 13 additions & 10 deletions rwkv_pip_package/README.md
@@ -1,21 +1,21 @@
The RWKV Language Model

-https://github.com/BlinkDL/ChatRWKV

https://github.com/BlinkDL/RWKV-LM

+https://github.com/BlinkDL/ChatRWKV

```python
-# set these before import RWKV
-os.environ['RWKV_JIT_ON'] = '1'
+# !!! set these before import RWKV !!!
+os.environ['RWKV_JIT_ON'] = '1' # '1' for better speed
os.environ["RWKV_CUDA_ON"] = '0' # '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries

########################################################################################################
#
# Use '/' in model path, instead of '\'. Use ctx4096 models if you need long ctx.
#
-# fp16 = good for GPU (!!! DOES NOT support CPU !!!)
+# fp16 = good for GPU
# fp32 = good for CPU
-# bf16 = worse accuracy, supports CPU
+# bf16 = supports CPU
# xxxi8 (example: fp16i8, fp32i8) = xxx with int8 quantization to save 50% VRAM/RAM, slower, slightly less accuracy
#
# We consider [ln_out+head] to be an extra layer, so L12-D768 (169M) has "13" layers, L24-D2048 (1.5B) has "25" layers, etc.
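#
# Illustrative strategy strings (a sketch based on the documented strategy format,
# not part of this diff):
#   'cpu fp32'                      = all layers on CPU in fp32
#   'cuda fp16'                     = all layers on GPU in fp16
#   'cuda fp16i8 *10 -> cuda fp16'  = first 10 layers int8-quantized on GPU, rest fp16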
@@ -51,9 +51,10 @@ from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

# download models: https://huggingface.co/BlinkDL
-model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023', strategy='cpu fp32')
-pipeline = PIPELINE(model, "20B_tokenizer.json") # 20B_tokenizer.json is in https://github.com/BlinkDL/ChatRWKV
-# use pipeline = PIPELINE(model, "rwkv_vocab_v20230424") for rwkv "world" models
+model = RWKV(model='RWKV-x060-World-1B6-v2.1-20240328-ctx4096', strategy='cpu fp32')
+
+pipeline = PIPELINE(model, "rwkv_vocab_v20230424") # for "world" models
+# pipeline = PIPELINE(model, "20B_tokenizer.json") # for "pile" models, 20B_tokenizer.json is in https://github.com/BlinkDL/ChatRWKV

ctx = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
print(ctx, end='')
@@ -68,13 +69,15 @@ args = PIPELINE_ARGS(temperature = 1.0, top_p = 0.7, top_k = 100, # top_k = 0 th
                     alpha_frequency = 0.25,
                     alpha_presence = 0.25,
                     alpha_decay = 0.996, # gradually decay the penalty
-                    token_ban = [0], # ban the generation of some tokens
+                    token_ban = [], # ban the generation of some tokens
                     token_stop = [], # stop generation whenever you see any token here
                     chunk_len = 256) # split input into chunks to save VRAM (shorter -> slower)
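# (A sketch of the GPT-3-style repetition penalty these args drive; check
#  rwkv.utils.PIPELINE for the exact form:
#    logits[token] -= alpha_presence + count(token) * alpha_frequency,
#  with per-token counts multiplied by alpha_decay after every step.)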

pipeline.generate(ctx, token_count=200, args=args, callback=my_print)
print('\n')

+# !!! model.forward(tokens, state) will modify state in-place !!!
+
out, state = model.forward([187, 510, 1563, 310, 247], None)
print(out.detach().cpu().numpy()) # get logits
out, state = model.forward([187, 510], None)
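# A sketch of the continuation pattern (illustrative tokens): feeding the returned
# state back in should give the same logits as running the full sequence at once
out, state = model.forward([1563, 310, 247], state)
print(out.detach().cpu().numpy()) # same logits as the [187, 510, 1563, 310, 247] call above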
```
2 changes: 1 addition & 1 deletion rwkv_pip_package/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "rwkv"
version = "0.8.26"
version = "0.8.27"
authors = [
{ name="Bo PENG" },
]
77 changes: 77 additions & 0 deletions rwkv_pip_package/src/rwkv/cuda/wkv7s.cu
@@ -0,0 +1,77 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"

typedef at::Half fp16;
typedef at::BFloat16 bf16;
typedef float fp32;

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
                               float *__restrict__ _state, const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b,
                               F *__restrict__ const _y)
{
    const int e = blockIdx.x / H; // batch index (only valid for e == 0, see asserts below)
    const int h = blockIdx.x % H; // head index
    const int i = threadIdx.x;    // channel index within the head
    _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!

    float state[_N_]; // this thread's row of the per-head N x N state
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        state[j] = _state[j];

    __shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_];

    for (int _t = 0; _t < T; _t++)
    {
        const int t = e*T*C + h*_N_ + i + _t * C;
        __syncthreads();
        r[i] = float(_r[t]);
        w[i] = __expf(-__expf(float(_w[t]))); // decay in (0,1): w = exp(-exp(w_raw))
        k[i] = float(_k[t]);
        a[i] = float(_a[t]);
        b[i] = float(_b[t]);
        __syncthreads();

        float sa = 0; // dot(a, state) for this row
        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            sa += a[j] * state[j];
        }

        float vv = float(_v[t]);
        float y = 0;
        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = state[j];
            s = s * w[j] + k[j] * vv + sa * b[j]; // S_ij = S_ij*w_j + v_i*k_j + (a . S_i)*b_j
            y += s * r[j];                        // y_i = dot(S_i, r)
        }
        _y[t] = F(y);
    }
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        _state[j] = state[j]; // write the updated state row back to global memory
}

void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y)
{
    assert(H*_N_ == C);
    assert(B == 1); // only for B=1
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
}
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16* w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y)
{
    assert(H*_N_ == C);
    assert(B == 1); // only for B=1
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
}
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32* w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y)
{
    assert(H*_N_ == C);
    assert(B == 1); // only for B=1
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
}
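Reading the inner loops off the kernel above: each thread owns row $i$ of a per-head $N \times N$ state $S$, and the per-token update it implements appears to be (my transcription, with $w_t$ the decays after the `exp(-exp(.))` transform):

$$S_t = S_{t-1}\left(\operatorname{diag}(w_t) + a_t b_t^{\top}\right) + v_t k_t^{\top}, \qquad y_t = S_t\, r_t$$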
26 changes: 26 additions & 0 deletions rwkv_pip_package/src/rwkv/cuda/wkv7s_op.cpp
@@ -0,0 +1,26 @@
#include <torch/extension.h>
#include "ATen/ATen.h"

typedef at::Half fp16;
typedef at::BFloat16 bf16;
typedef float fp32;

void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y);
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y);
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y);

void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
    cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), w.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), a.data_ptr<bf16>(), b.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
    cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), w.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), a.data_ptr<fp16>(), b.data_ptr<fp16>(), y.data_ptr<fp16>());
}
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
    cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), w.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), a.data_ptr<fp32>(), b.data_ptr<fp32>(), y.data_ptr<fp32>());
}

TORCH_LIBRARY(wkv7s, m) {
    m.def("forward_bf16", forward_bf16);
    m.def("forward_fp16", forward_fp16);
    m.def("forward_fp32", forward_fp32);
}
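For orientation, a minimal sketch of building and calling the registered op from Python via torch.utils.cpp_extension; the source paths, the HEAD_SIZE value baked into the `_N_` macro, and the flat tensor layout are assumptions on my part, not something this commit's shown code pins down:

```python
# Hedged sketch: compile wkv7s with torch.utils.cpp_extension and call the op.
# Assumptions: _N_ (head size) = 64, file paths, flat [T*C] tensors; B must be 1.
import torch
from torch.utils.cpp_extension import load

HEAD_SIZE = 64  # assumed value of the _N_ compile-time macro

load(name="wkv7s",
     sources=["src/rwkv/cuda/wkv7s_op.cpp", "src/rwkv/cuda/wkv7s.cu"],
     is_python_module=False, verbose=True,
     extra_cuda_cflags=["-O3", f"-D_N_={HEAD_SIZE}"])

B, T, H = 1, 8, 16                      # the kernel asserts B == 1
C = H * HEAD_SIZE
state = torch.zeros(H * HEAD_SIZE * HEAD_SIZE, dtype=torch.float32, device="cuda")
r, w, k, v, a, b = (torch.randn(T * C, dtype=torch.half, device="cuda") for _ in range(6))
y = torch.empty(T * C, dtype=torch.half, device="cuda")

# Runs the recurrence over all T tokens: fills y and updates state in place.
torch.ops.wkv7s.forward_fp16(B, T, C, H, state, r, w, k, v, a, b, y)
print(y.view(T, C)[-1, :8])             # output channels for the last token
```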