diff --git a/API_DEMO_CHAT.py b/API_DEMO_CHAT.py
index 33baadd..dad4a45 100644
--- a/API_DEMO_CHAT.py
+++ b/API_DEMO_CHAT.py
@@ -13,7 +13,7 @@
 torch.backends.cudnn.allow_tf32 = True
 torch.backends.cuda.matmul.allow_tf32 = True
 os.environ["RWKV_JIT_ON"] = "1"
-os.environ["RWKV_CUDA_ON"] = "0" # '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries
+os.environ["RWKV_CUDA_ON"] = "1" # '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries
 
 from rwkv.model import RWKV
 from rwkv.utils import PIPELINE
@@ -24,7 +24,7 @@
 args.strategy = "cuda fp16" # use CUDA, fp16
 
-args.MODEL_NAME = "E://RWKV-Runner//models//RWKV-5-World-1B5-v2-20231025-ctx4096"
+args.MODEL_NAME = "E://RWKV-Runner//models//rwkv-final-v6-2.1-7b"
 
 GEN_TEMP = 1.0
 GEN_TOP_P = 0.3
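
For context, a minimal sketch of how the two changed settings are consumed at load time, assuming the standard rwkv pip package API that the script already imports; the prompt text, token_count, and variable names below are illustrative and not taken from the script:

# Sketch (assumption): RWKV_CUDA_ON must be set before importing rwkv.model,
# otherwise the CUDA kernel is not compiled.
import os
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_CUDA_ON"] = "1"

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

# The model path is passed without the .pth extension, matching MODEL_NAME in the diff.
model = RWKV(model="E://RWKV-Runner//models//rwkv-final-v6-2.1-7b", strategy="cuda fp16")
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")  # vocab used by RWKV World models

out = pipeline.generate(
    "User: Hello\n\nAssistant:",
    token_count=64,
    args=PIPELINE_ARGS(temperature=1.0, top_p=0.3),  # GEN_TEMP / GEN_TOP_P from the script
)
print(out)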