From 95d3671ed1ec6919eb52a748d7f3de128f3d5897 Mon Sep 17 00:00:00 2001 From: Bo Date: Wed, 3 Jul 2024 02:07:54 +0800 Subject: [PATCH] accurate benchmark --- RWKV_v6_demo_cuda_bf16.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/RWKV_v6_demo_cuda_bf16.py b/RWKV_v6_demo_cuda_bf16.py index 75c3d5c..09e7d5d 100644 --- a/RWKV_v6_demo_cuda_bf16.py +++ b/RWKV_v6_demo_cuda_bf16.py @@ -270,10 +270,10 @@ def decode(self, tokens): min_time = 1e10 min_time_all = 1e10 - t000 = time.time() + t000 = time.perf_counter() for i in range(LENGTH_PER_TRIAL): - t00 = time.time() + t00 = time.perf_counter() token = sample_logits(out, TEMPERATURE, TOP_P) all_tokens += [token] try: @@ -283,14 +283,14 @@ def decode(self, tokens): out_last = i + 1 except: pass - t0 = time.time() + t0 = time.perf_counter() out, state = model.forward(token, state) - t1 = time.time() + t1 = time.perf_counter() min_time = min(min_time, t1 - t0) min_time_all = min(min_time_all, t1 - t00) - print(f'\n[ {round(1/min_time_all,2)} (real) / {round(1/min_time,2)} (ignore sampling & tokenizer) token/s = {round(time.time()-t000,3)}s ]', end='') + print(f'\n[ {round(1/min_time_all,2)} (real) / {round(1/min_time,2)} (ignore sampling & tokenizer) token/s = {round(time.perf_counter()-t000,3)}s ]', end='') print('\n')