state loading demo
www committed Jul 1, 2024
1 parent 7b23d3b commit a27ceb3
Showing 1 changed file with 30 additions and 4 deletions.
API_DEMO_CHAT.py (34 changes: 30 additions & 4 deletions)
@@ -13,7 +13,7 @@
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_CUDA_ON"] = "1" # '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries
os.environ["RWKV_CUDA_ON"] = "0" # !!! '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries !!!

from rwkv.model import RWKV
from rwkv.utils import PIPELINE
@@ -24,14 +24,28 @@

args.strategy = "cuda fp16" # use CUDA, fp16

-args.MODEL_NAME = "E://RWKV-Runner//models//rwkv-final-v6-2.1-7b"
+args.MODEL_NAME = "E://RWKV-Runner//models//rwkv-final-v6-2.1-1b6"


+########################################################################################################
+# STATE_NAME = None # use vanilla zero initial state?
+
+# use custom state? much better chat results (download from https://huggingface.co/BlinkDL/temp-latest-training-models/tree/main)
+# note: this is English Single-round QA state (will forget what you previously say)
+STATE_NAME = "E://RWKV-Runner//models//rwkv-x060-eng_single_round_qa-1B6-20240516-ctx2048"
+########################################################################################################

GEN_TEMP = 1.0
GEN_TOP_P = 0.3
-GEN_alpha_presence = 0.0
-GEN_alpha_frequency = 1.0
+GEN_alpha_presence = 0.5
+GEN_alpha_frequency = 0.5
GEN_penalty_decay = 0.996

+if STATE_NAME != None:
+    GEN_TOP_P = 0.2
+    GEN_alpha_presence = 0.3
+    GEN_alpha_frequency = 0.3

CHUNK_LEN = 256 # split input into chunks to save VRAM (shorter -> slower, but saves VRAM)

########################################################################################################
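As an aside (not part of this commit): a minimal sketch of how these sampling knobs are typically applied in the RWKV chat demos. Tokens that have already been generated get their logits pushed down by a presence term plus a count-weighted frequency term, and the counts decay each step so old repetitions fade. The helper name penalize_and_sample is hypothetical, and the pipeline object is assumed to be the PIPELINE instance created elsewhere in the script.

def penalize_and_sample(out, occurrence):
    # illustration only: penalize logits of tokens we have already sampled
    for t in occurrence:
        out[t] -= GEN_alpha_presence + occurrence[t] * GEN_alpha_frequency
    token = pipeline.sample_logits(out, temperature=GEN_TEMP, top_p=GEN_TOP_P)
    # decay old counts so the penalty fades over time, then record the new token
    for t in occurrence:
        occurrence[t] *= GEN_penalty_decay
    occurrence[token] = occurrence.get(token, 0) + 1
    return token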
@@ -43,6 +57,18 @@
model_tokens = []
model_state = None

+if STATE_NAME != None: # load custom state
+    args = model.args
+    state_raw = torch.load(STATE_NAME + '.pth')
+    state_init = [None for i in range(args.n_layer * 3)]
+    for i in range(args.n_layer):
+        dd = model.strategy[i]
+        dev = dd.device
+        atype = dd.atype
+        state_init[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
+        state_init[i*3+1] = state_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
+        state_init[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
+    model_state = copy.deepcopy(state_init)
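For orientation (not part of the commit): in the rwkv pip package the recurrent state holds three tensors per layer, roughly the attention token-shift, the wkv ("time") state, and the FFN token-shift, so the loop above fills only the middle slot of each layer from the .pth and zero-initializes the other two. Since this is a single-round QA state, a natural usage pattern (an illustrative sketch, assuming state_init is kept around unmodified) is to restore the freshly built state before every new question:

# illustrative reset between independent questions, not part of this commit
model_state = copy.deepcopy(state_init)
model_tokens = []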

def run_rnn(ctx):
    global model_tokens, model_state
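The diff is cut off here, so the rest of run_rnn is not shown. For orientation only, a sketch of the chunked forward pass that CHUNK_LEN configures, assuming the rwkv pip package's model.forward(tokens, state) API and the pipeline object for tokenization; the function name run_rnn_sketch is hypothetical and this is not necessarily the file's exact body.

def run_rnn_sketch(ctx):
    # illustrative stand-in for run_rnn, not the file's actual implementation
    global model_tokens, model_state
    tokens = pipeline.encode(ctx)          # tokenize the new context
    model_tokens += tokens
    out = None
    while len(tokens) > 0:
        # feed at most CHUNK_LEN tokens per call to bound peak VRAM
        out, model_state = model.forward(tokens[:CHUNK_LEN], model_state)
        tokens = tokens[CHUNK_LEN:]
    return out                             # logits for the next token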
