Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add InfiniteBench for long context benchmarking #2421

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions benchmark/infinitebench/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
## Run benchmark

### SGLang
#### Set up environment
```bash
conda create -n sglang python=3.10
conda activate sglang
pip install --upgrade pip
pip install "sglang[all]" --find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/
```

#### Launch server and run eval
```
python -m sglang.launch_server --model-path gradientai/Llama-3-8B-Instruct-Gradient-1048k --port 30000 --max-total-tokens 131072
```

In a separate terminal, run eval on the first 10 samples as follows:
```
python bench_sglang.py --task passkey --num-samples 10
```

### TensorRT
The following evaluation with TensorRT has been tested with 1xH100 (80 GB SXM5).

#### Set up environment
```bash
conda create -n tensorrt python=3.10
conda activate tensorrt
conda install -c conda-forge mpi4py openmpi
sudo apt-get -y install libopenmpi-dev && pip install tensorrt_llm==0.15.0
```

```bash
git clone https://github.com/NVIDIA/TensorRT-LLM.git
cd TensorRT-LLM
pip install --upgrade -r requirements-dev.txt
cd examples/llama
```

#### Download dataset
The following commands download the dataset files into the `./data` directory, which is hardcoded in the script.
```bash
URL="https://raw.githubusercontent.com/OpenBMB/InfiniteBench/51d9b37b0f1790ead936df2243abbf7f0420e439/scripts/download_dataset.sh"
wget $URL -O download_infinitebench.sh
bash download_infinitebench.sh
```

#### Prepare checkpoint
```bash
sudo apt install git-lfs
git-lfs clone https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-1048k/

python convert_checkpoint.py \
--model_dir ./Llama-3-8B-Instruct-Gradient-1048k/ \
--output_dir /tmp/llama-3-8B-1048k/trt_ckpts \
--dtype float16
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that the dtype specified in the model's config.json is bfloat16. Could you please explain why float16 is being specified here?


```

#### Build engine and run eval
```bash
python -m tensorrt_llm.commands.build \
--checkpoint_dir /tmp/llama-3-8B-1048k/trt_ckpts \
--output_dir /tmp/llama-3-8B-1048k/trt_engines \
--gemm_plugin float16 \
--max_num_tokens 4096 \
--max_input_len 131072 \
--max_seq_len 131082 \
--use_paged_context_fmha enable

python ../eval_long_context.py \
--task passkey \
--engine_dir /tmp/llama-3-8B-1048k/trt_engines \
--tokenizer_dir ./Llama-3-8B-Instruct-Gradient-1048k/ \
--stop_idx 10 \
--max_input_length 131072 \
--enable_chunked_context \
--max_tokens_in_paged_kv_cache 131136 \
--data_dir ./data \
--output_dir ./
```
122 changes: 122 additions & 0 deletions benchmark/infinitebench/bench_sglang.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# reference: https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/eval_long_context.py
import argparse
import json
import os
import time

from compute_scores import compute_scores
from eval_utils import DATA_NAME_TO_MAX_NEW_TOKENS, create_prompt, get_answer

import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import download_and_cache_file, read_jsonl


def validate_args(args):
    """Check that the parsed CLI arguments name a supported task.

    Args:
        args: Parsed argument namespace; only ``args.task`` is read.

    Raises:
        ValueError: If ``args.task`` is neither "passkey" nor "kv_retrieval".
    """
    # Raise explicitly instead of using `assert`: assertions are removed when
    # Python runs with -O, which would silently disable this check.
    if args.task not in ("passkey", "kv_retrieval"):
        raise ValueError(f"Invalid task: {args.task}")


def write_answers(filename, model_id, results):
    """Write one JSON object per result to ``filename``, one per line.

    Any existing file at that path is overwritten.

    Args:
        filename: Output path; a leading ``~`` is expanded to the home directory.
        model_id: Value stored under "model_id" in every record.
        results: Iterable of mappings providing the "question_id",
            "prediction" and "ground_truth" keys. Other keys are ignored.

    Raises:
        KeyError: If a result lacks one of the required keys.
    """
    with open(os.path.expanduser(filename), "w", encoding="utf-8") as fout:
        # Iterate the results directly; no index is needed.
        for result in results:
            record = {
                "question_id": result["question_id"],
                "model_id": model_id,
                "prediction": result["prediction"],
                "ground_truth": result["ground_truth"],
            }
            fout.write(json.dumps(record) + "\n")


@sgl.function
def infinitebench(s, question, max_tokens):
    """SGLang program: append the prompt, then generate the "answer" variable.

    Args:
        s: Program state supplied by the ``sgl.function`` decorator.
        question: Full prompt text appended to the state before generation.
        max_tokens: Forwarded to ``sgl.gen`` as its ``max_tokens`` argument.
    """
    s += question
    # The generated text is stored under the name "answer".
    s += sgl.gen("answer", max_tokens=max_tokens)


def main(args):
    """Run one InfiniteBench task through SGLang and record the outcome.

    Steps: validate the arguments, select the backend, download the task's
    JSONL file, build one prompt per evaluated sample, run all prompts as a
    batch, score the predictions, write the per-sample answers to
    ``tmp_output_{task}_{backend}.txt`` and append a one-line JSON summary to
    ``args.result_file``.

    Args:
        args: Parsed CLI namespace. This function reads ``task``,
            ``num_samples``, ``parallel``, ``backend`` and ``result_file``,
            and also hands the namespace to ``select_sglang_backend``.
            ``num_samples`` is overwritten with the dataset size when it is
            ``None``.
    """
    validate_args(args)

    # Select backend
    backend = select_sglang_backend(args)
    sgl.set_default_backend(backend)

    # Download and load data
    data_name = args.task
    data_url = f"https://huggingface.co/datasets/xinrongzhang2022/InfiniteBench/resolve/main/{data_name}.jsonl"
    max_tokens = DATA_NAME_TO_MAX_NEW_TOKENS[data_name]  # max output length

    filename = download_and_cache_file(data_url)
    lines = list(read_jsonl(filename))
    if args.num_samples is None:
        args.num_samples = len(lines)

    # Slice once so every later step works on exactly the evaluated samples.
    samples = lines[: args.num_samples]

    # Construct prompts
    data_dir = os.path.dirname(filename)  # loop-invariant, computed once
    questions = []
    labels = []
    for sample in samples:
        questions.append(create_prompt(sample, data_name, data_dir))
        labels.append(get_answer(sample, data_name))
    arguments = [{"question": q, "max_tokens": max_tokens} for q in questions]

    # Run requests
    tic = time.time()
    results = infinitebench.run_batch(
        arguments,
        temperature=0,
        num_threads=args.parallel,
        progress_bar=True,
    )
    latency = time.time() - tic

    # Compute scores
    # Pair each evaluated sample with its label and generated answer.
    results = [
        {
            "ground_truth": label,
            "prediction": result["answer"],
            "question_id": sample["id"],
        }
        for sample, label, result in zip(samples, labels, results)
    ]
    acc = compute_scores(results, args.task)
    print(f"#questions: {len(questions)}, Latency: {latency:.2f}, Accuracy: {acc:.3f}")

    # Write results to file
    model_id = backend.model_info["model_path"]
    answer_file = f"tmp_output_{data_name}_{args.backend}.txt"
    write_answers(answer_file, model_id, results)

    # Append (not overwrite) so summaries from successive runs accumulate.
    with open(args.result_file, "a") as fout:
        value = {
            "task": args.task,
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": len(questions),
            "other": {
                "num_questions": len(questions),
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    # Only the task-specific options are declared here. The remaining
    # attributes that main() reads (backend, parallel, result_file) are not
    # defined in this file; presumably add_common_sglang_args_and_parse
    # registers them before parsing -- confirm against sglang.test.test_utils.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--task",
        required=True,
        type=str,
        choices=["passkey", "kv_retrieval"],
    )
    cli.add_argument(
        "--num-samples",
        default=None,
        type=int,
        help="Number of samples from the beginning of dataset to use for eval.",
    )
    parsed_args = add_common_sglang_args_and_parse(cli)
    main(parsed_args)
Loading