Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add llama3 eval #2515

Merged
merged 2 commits into from
Dec 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
316 changes: 316 additions & 0 deletions python/sglang/llama3_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,316 @@
# Adapt from https://github.com/fw-ai/llm_eval_meta

import argparse
import asyncio
import os
import pickle
import re
import shutil
from collections import defaultdict
from dataclasses import dataclass

import httpx
import numpy as np
import openai
import transformers
from datasets import load_dataset
from openai import AsyncOpenAI
from tqdm import tqdm

# Mapping providers to their clients and models
provider_to_models = {
"b10": {
"8b": "meta-llama/Llama-3.1-8B-Instruct",
"70b": "meta-llama/Llama-3.1-70B-Instruct",
"405b": "meta-llama/Llama-3.1-405B-Instruct",
},
"oai": {
"8b": "meta-llama/Llama-3.1-8B-Instruct",
"70b": "meta-llama/Llama-3.1-70B-Instruct",
"405b": "meta-llama/Llama-3.1-405B-Instruct",
},
"sgl": {
"8b": "meta-llama/Llama-3.1-8B-Instruct",
"70b": "meta-llama/Llama-3.1-70B-Instruct",
"405b": "meta-llama/Llama-3.1-405B-Instruct",
},
}


async def fetch_responses(
client, prompt, semaphore, index, provider, model_size, output_dir, max_tokens
):
output_file = os.path.join(output_dir, f"response_{index}.pkl")
if os.path.exists(output_file):
print(f"File {output_file} already exists, skipping.")
return

async with semaphore:
response = await client.completions.create(
model=provider_to_models[provider][model_size],
prompt=prompt,
temperature=0.0,
max_tokens=max_tokens,
)
if isinstance(response, openai.BadRequestError):
with open(output_file, "wb") as f:
pickle.dump("bad_response", f)
assert isinstance(response, openai.types.completion.Completion)
# Save response to a file
with open(output_file, "wb") as f:
pickle.dump(response, f)


TASK_TO_MAX_TOKENS = {
"evals__mmlu__details": 1,
"evals__mmlu__0_shot__cot__details": 1024,
# Official meta uses 1024, but a small % (.05) of questions are answered correctly after relaxing
"evals__mmlu_pro__details": 2048,
"evals__gsm8k__details": 1024,
}

TASK_TO_EVAL_SET = {
"mmlu": "evals__mmlu__details",
"mmlu_0_shot": "evals__mmlu__0_shot__cot__details",
"mmlu_pro": "evals__mmlu_pro__details",
"gsm8k": "evals__gsm8k__details",
}


class CustomAsyncHTTPXClient(httpx.AsyncClient):
async def send(self, request: httpx.Request, *args, **kwargs) -> httpx.Response:
request.url = httpx.URL(
f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict"
)
return await super().send(request, *args, **kwargs)


def get_client(provider):
if provider not in "b10":
if os.getenv("OPENAI_API_KEY") == None:
os.environ["OPENAI_API_KEY"] = "EMPTY"
return {
"oai": AsyncOpenAI(base_url="http://127.0.0.1:8000/v1/"),
"b10": AsyncOpenAI(
api_key=f"Api-Key {os.getenv('OPENAI_API_KEY')}",
base_url=f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict",
http_client=CustomAsyncHTTPXClient(),
),
"sgl": AsyncOpenAI(base_url="http://127.0.0.1:30000/v1/"),
}[provider]


# Define the benchmark function
async def benchmark(args):
ds = load_dataset(
"meta-llama/Llama-3.1-405B-Instruct-evals",
f"Llama-3.1-405B-Instruct-{TASK_TO_EVAL_SET[args.task]}",
)
semaphore = asyncio.Semaphore(args.concurrency) # Limit to 16 concurrent tasks

if args.num_examples is None:
args.num_examples = len(ds["latest"]["input_final_prompts"])
prompts = ds["latest"]["input_final_prompts"][: args.num_examples]

# Create the output directory if it does not exist
os.makedirs(args.output_dir, exist_ok=True)

tasks = []
# Create the tasks with tqdm progress bar
max_tokens = TASK_TO_MAX_TOKENS[TASK_TO_EVAL_SET[args.task]]
client = get_client(args.provider)
for idx, prompt in enumerate(tqdm(prompts, desc="Creating tasks")):
tasks.append(
asyncio.create_task(
fetch_responses(
client,
f"<|begin_of_text|>{prompt[0]}",
semaphore,
idx,
args.provider,
args.model_size,
args.output_dir,
max_tokens=max_tokens,
)
)
)

# Run the tasks with tqdm progress bar
for future in tqdm(
asyncio.as_completed(tasks), total=len(tasks), desc="Processing tasks"
):
await future


def get_mmlu_answer(response):
if response is not None:
return response.choices[0].text.lstrip().rstrip().upper().replace(".", "")
return None


def get_mmlu_cot_answer(response):
pattern = r"The best answer is (.+)\.?"
match = re.search(pattern, response.choices[0].text)
if match:
return match.group(1).replace(".", "").replace("*", "")

pattern = r"the best answer is (.+)\.?"
match = re.search(pattern, response.choices[0].text)
if match:
return match.group(1).replace(".", "")

pattern = r"The correct answer is (.+)\.?"
match = re.search(pattern, response.choices[0].text)
if match:
return match.group(1).replace(".", "")

pattern = r"the correct answer is (.+)\.?"
match = re.search(pattern, response.choices[0].text)
if match:
return match.group(1).replace(".", "")


def get_answer_gsm8k(response):
pattern = r"The final answer is (.+)\.?"
match = re.search(pattern, response.choices[0].text)
if match:
s = match.group(1)
for ok_symbol in ["%", "$"]:
s = s.replace(ok_symbol, "")
return s


TASK_TO_ANSWER_EXTRACTOR = {
"evals__mmlu__details": get_mmlu_answer,
"evals__mmlu__0_shot__cot__details": get_mmlu_cot_answer,
"evals__gsm8k__details": get_answer_gsm8k,
"evals__mmlu_pro__details": get_mmlu_cot_answer,
}


def get_dataset_from_task(task, response_path, model_size):
ds_405b = load_dataset(
f"meta-llama/Llama-3.1-405B-Instruct-evals",
f"Llama-3.1-405B-Instruct-{task}",
)
ds_405b_hash_order = [x[0] for x in ds_405b["latest"]["input_final_prompts_hash"]]

if "70b" in model_size or "8b" in model_size:
if "70" in model_size:
ref_model_ds = load_dataset(
f"meta-llama/Llama-3.1-70B-Instruct-evals",
f"Llama-3.1-70B-Instruct-{task}",
)
else:
ref_model_ds = load_dataset(
f"meta-llama/Llama-3.1-8B-Instruct-evals",
f"Llama-3.1-8B-Instruct-{task}",
)

hash_to_row = {}
for row in ref_model_ds["latest"]:
hash_to_row[row["input_final_prompts_hash"][0]] = row
reordered_rows = []
for prompt_hash in ds_405b_hash_order:
reordered_rows.append(hash_to_row[prompt_hash])
ref_model_ds["latest"] = reordered_rows
return ref_model_ds

return ds_405b


def analyze(task, response_path, model_size):
ds = get_dataset_from_task(task, response_path, model_size)

responses = []
total = len(ds["latest"])

for i in range(0, total):
response = pickle.load(
open(os.path.join(response_path, f"response_{i}.pkl"), "rb")
)
responses.append(response)

@dataclass
class Stats:
correct: int = 0
total: int = 0
meta_correct: int = 0

average: float = None

subtask_name_to_stats = defaultdict(lambda: Stats())

for response, ds_row in zip(responses, ds["latest"]):
model_answer = TASK_TO_ANSWER_EXTRACTOR[task](response)

subtask = ds_row["subtask_name"]

is_eval_correct = model_answer in ds_row["input_correct_responses"]
if is_eval_correct:
subtask_name_to_stats[subtask].correct += 1

if ds_row["is_correct"]:
subtask_name_to_stats[subtask].meta_correct += 1

subtask_name_to_stats[subtask].total += 1

micro_stats = Stats()
for subtask, stats in subtask_name_to_stats.items():
stats.average = stats.correct / stats.total
stats.meta_average = stats.meta_correct / stats.total

micro_stats.correct += stats.correct
micro_stats.total += stats.total
micro_stats.meta_correct += stats.meta_correct

micro_stats.average = micro_stats.correct / micro_stats.total
micro_stats.meta_average = micro_stats.meta_correct / micro_stats.total

print("Macro average", np.mean([x.average for x in subtask_name_to_stats.values()]))
print(
"Meta Macro average",
np.mean([x.meta_average for x in subtask_name_to_stats.values()]),
)
print("Micro average", micro_stats.average)
print("Meta Micro average", micro_stats.meta_average)


# Entry point for the script
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Script to run model with specified parameters."
)
parser.add_argument(
"--model-size",
type=str,
default="8b",
help="Size of the model (e.g., 8b or 70b)",
)
parser.add_argument(
"--provider",
type=str,
default="sgl",
help="Provider name (e.g., sgl, oai, b10)",
)
parser.add_argument(
"--task",
type=str,
required=True,
help="Task (e.g., mmlu, mmlu_0_shot, mmlu_pro, gsm8k)",
)
parser.add_argument(
"--num-examples", type=int, default=None, help="Number of examples to process"
)
parser.add_argument("--concurrency", type=int, default=16)
parser.add_argument(
"--output-dir",
type=str,
default="tmp-output-dir",
help="Directory to save responses",
)

args = parser.parse_args()
asyncio.run(benchmark(args))
analyze(TASK_TO_EVAL_SET[args.task], args.output_dir, args.model_size)
shutil.rmtree("tmp-output-dir", ignore_errors=True)
Loading