diff --git a/python/sglang/llama3_eval.py b/python/sglang/llama3_eval.py
new file mode 100644
index 00000000000..cacad0c390e
--- /dev/null
+++ b/python/sglang/llama3_eval.py
@@ -0,0 +1,316 @@
+# Adapted from https://github.com/fw-ai/llm_eval_meta
+
+import argparse
+import asyncio
+import os
+import pickle
+import re
+import shutil
+from collections import defaultdict
+from dataclasses import dataclass
+
+import httpx
+import numpy as np
+import openai
+import transformers
+from datasets import load_dataset
+from openai import AsyncOpenAI
+from tqdm import tqdm
+
+# Mapping from provider and model size to the served model name
+provider_to_models = {
+    "b10": {
+        "8b": "meta-llama/Llama-3.1-8B-Instruct",
+        "70b": "meta-llama/Llama-3.1-70B-Instruct",
+        "405b": "meta-llama/Llama-3.1-405B-Instruct",
+    },
+    "oai": {
+        "8b": "meta-llama/Llama-3.1-8B-Instruct",
+        "70b": "meta-llama/Llama-3.1-70B-Instruct",
+        "405b": "meta-llama/Llama-3.1-405B-Instruct",
+    },
+    "sgl": {
+        "8b": "meta-llama/Llama-3.1-8B-Instruct",
+        "70b": "meta-llama/Llama-3.1-70B-Instruct",
+        "405b": "meta-llama/Llama-3.1-405B-Instruct",
+    },
+}
+
+
+async def fetch_responses(
+    client, prompt, semaphore, index, provider, model_size, output_dir, max_tokens
+):
+    output_file = os.path.join(output_dir, f"response_{index}.pkl")
+    if os.path.exists(output_file):
+        print(f"File {output_file} already exists, skipping.")
+        return
+
+    async with semaphore:
+        response = await client.completions.create(
+            model=provider_to_models[provider][model_size],
+            prompt=prompt,
+            temperature=0.0,
+            max_tokens=max_tokens,
+        )
+        if isinstance(response, openai.BadRequestError):
+            with open(output_file, "wb") as f:
+                pickle.dump("bad_response", f)
+            return
+        assert isinstance(response, openai.types.completion.Completion)
+        # Save the response to a file
+        with open(output_file, "wb") as f:
+            pickle.dump(response, f)
+
+
+TASK_TO_MAX_TOKENS = {
+    "evals__mmlu__details": 1,
+    "evals__mmlu__0_shot__cot__details": 1024,
+    # The official Meta setting is 1024, but a small percentage (0.05) of questions
+    # are only answered correctly after relaxing the limit
+    "evals__mmlu_pro__details": 2048,
+    "evals__gsm8k__details": 1024,
+}
+
+TASK_TO_EVAL_SET = {
+    "mmlu": "evals__mmlu__details",
+    "mmlu_0_shot": "evals__mmlu__0_shot__cot__details",
+    "mmlu_pro": "evals__mmlu_pro__details",
+    "gsm8k": "evals__gsm8k__details",
+}
+
+
+class CustomAsyncHTTPXClient(httpx.AsyncClient):
+    async def send(self, request: httpx.Request, *args, **kwargs) -> httpx.Response:
+        request.url = httpx.URL(
+            f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict"
+        )
+        return await super().send(request, *args, **kwargs)
+
+
+def get_client(provider):
+    if provider != "b10":
+        if os.getenv("OPENAI_API_KEY") is None:
+            os.environ["OPENAI_API_KEY"] = "EMPTY"
+    return {
+        "oai": AsyncOpenAI(base_url="http://127.0.0.1:8000/v1/"),
+        "b10": AsyncOpenAI(
+            api_key=f"Api-Key {os.getenv('OPENAI_API_KEY')}",
+            base_url=f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict",
+            http_client=CustomAsyncHTTPXClient(),
+        ),
+        "sgl": AsyncOpenAI(base_url="http://127.0.0.1:30000/v1/"),
+    }[provider]
+
+
+# Define the benchmark function
+async def benchmark(args):
+    ds = load_dataset(
+        "meta-llama/Llama-3.1-405B-Instruct-evals",
+        f"Llama-3.1-405B-Instruct-{TASK_TO_EVAL_SET[args.task]}",
+    )
+    semaphore = asyncio.Semaphore(args.concurrency)  # Limit concurrent requests to args.concurrency
+
+    if args.num_examples is None:
+        args.num_examples = len(ds["latest"]["input_final_prompts"])
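+    # Slice to the requested number of examples (all prompts when --num-examples is unset)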
+    prompts = ds["latest"]["input_final_prompts"][: args.num_examples]
+
+    # Create the output directory if it does not exist
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    tasks = []
+    # Create the tasks with a tqdm progress bar
+    max_tokens = TASK_TO_MAX_TOKENS[TASK_TO_EVAL_SET[args.task]]
+    client = get_client(args.provider)
+    for idx, prompt in enumerate(tqdm(prompts, desc="Creating tasks")):
+        tasks.append(
+            asyncio.create_task(
+                fetch_responses(
+                    client,
+                    f"<|begin_of_text|>{prompt[0]}",
+                    semaphore,
+                    idx,
+                    args.provider,
+                    args.model_size,
+                    args.output_dir,
+                    max_tokens=max_tokens,
+                )
+            )
+        )
+
+    # Run the tasks with a tqdm progress bar
+    for future in tqdm(
+        asyncio.as_completed(tasks), total=len(tasks), desc="Processing tasks"
+    ):
+        await future
+
+
+def get_mmlu_answer(response):
+    if response is not None:
+        return response.choices[0].text.strip().upper().replace(".", "")
+    return None
+
+
+def get_mmlu_cot_answer(response):
+    pattern = r"The best answer is (.+)\.?"
+    match = re.search(pattern, response.choices[0].text)
+    if match:
+        return match.group(1).replace(".", "").replace("*", "")
+
+    pattern = r"the best answer is (.+)\.?"
+    match = re.search(pattern, response.choices[0].text)
+    if match:
+        return match.group(1).replace(".", "")
+
+    pattern = r"The correct answer is (.+)\.?"
+    match = re.search(pattern, response.choices[0].text)
+    if match:
+        return match.group(1).replace(".", "")
+
+    pattern = r"the correct answer is (.+)\.?"
+    match = re.search(pattern, response.choices[0].text)
+    if match:
+        return match.group(1).replace(".", "")
+
+
+def get_answer_gsm8k(response):
+    pattern = r"The final answer is (.+)\.?"
+    match = re.search(pattern, response.choices[0].text)
+    if match:
+        s = match.group(1)
+        for ok_symbol in ["%", "$"]:
+            s = s.replace(ok_symbol, "")
+        return s
+
+
+TASK_TO_ANSWER_EXTRACTOR = {
+    "evals__mmlu__details": get_mmlu_answer,
+    "evals__mmlu__0_shot__cot__details": get_mmlu_cot_answer,
+    "evals__gsm8k__details": get_answer_gsm8k,
+    "evals__mmlu_pro__details": get_mmlu_cot_answer,
+}
+
+
+def get_dataset_from_task(task, response_path, model_size):
+    ds_405b = load_dataset(
+        "meta-llama/Llama-3.1-405B-Instruct-evals",
+        f"Llama-3.1-405B-Instruct-{task}",
+    )
+    ds_405b_hash_order = [x[0] for x in ds_405b["latest"]["input_final_prompts_hash"]]
+
+    if "70b" in model_size or "8b" in model_size:
+        if "70" in model_size:
+            ref_model_ds = load_dataset(
+                "meta-llama/Llama-3.1-70B-Instruct-evals",
+                f"Llama-3.1-70B-Instruct-{task}",
+            )
+        else:
+            ref_model_ds = load_dataset(
+                "meta-llama/Llama-3.1-8B-Instruct-evals",
+                f"Llama-3.1-8B-Instruct-{task}",
+            )
+
+        # Reorder the reference model's rows to match the 405B prompt-hash order
+        hash_to_row = {}
+        for row in ref_model_ds["latest"]:
+            hash_to_row[row["input_final_prompts_hash"][0]] = row
+        reordered_rows = []
+        for prompt_hash in ds_405b_hash_order:
+            reordered_rows.append(hash_to_row[prompt_hash])
+        ref_model_ds["latest"] = reordered_rows
+        return ref_model_ds
+
+    return ds_405b
+
+
+def analyze(task, response_path, model_size):
+    ds = get_dataset_from_task(task, response_path, model_size)
+
+    responses = []
+    total = len(ds["latest"])
+
+    for i in range(0, total):
+        response = pickle.load(
+            open(os.path.join(response_path, f"response_{i}.pkl"), "rb")
+        )
+        responses.append(response)
+
+    @dataclass
+    class Stats:
+        correct: int = 0
+        total: int = 0
+        meta_correct: int = 0
+
+        average: float = None
+        meta_average: float = None
+
+    subtask_name_to_stats = defaultdict(lambda: Stats())
+
+    for response, ds_row in zip(responses, ds["latest"]):
+        model_answer = TASK_TO_ANSWER_EXTRACTOR[task](response)
+
+        subtask = ds_row["subtask_name"]
+
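+        # A response counts as correct when the extracted answer matches one of the reference answers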
+        is_eval_correct = model_answer in ds_row["input_correct_responses"]
+        if is_eval_correct:
+            subtask_name_to_stats[subtask].correct += 1
+
+        if ds_row["is_correct"]:
+            subtask_name_to_stats[subtask].meta_correct += 1
+
+        subtask_name_to_stats[subtask].total += 1
+
+    micro_stats = Stats()
+    for subtask, stats in subtask_name_to_stats.items():
+        stats.average = stats.correct / stats.total
+        stats.meta_average = stats.meta_correct / stats.total
+
+        micro_stats.correct += stats.correct
+        micro_stats.total += stats.total
+        micro_stats.meta_correct += stats.meta_correct
+
+    micro_stats.average = micro_stats.correct / micro_stats.total
+    micro_stats.meta_average = micro_stats.meta_correct / micro_stats.total
+
+    print("Macro average", np.mean([x.average for x in subtask_name_to_stats.values()]))
+    print(
+        "Meta Macro average",
+        np.mean([x.meta_average for x in subtask_name_to_stats.values()]),
+    )
+    print("Micro average", micro_stats.average)
+    print("Meta Micro average", micro_stats.meta_average)
+
+
+# Entry point for the script
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Run Llama 3.1 evaluation tasks against an OpenAI-compatible server."
+    )
+    parser.add_argument(
+        "--model-size",
+        type=str,
+        default="8b",
+        help="Size of the model (e.g., 8b or 70b)",
+    )
+    parser.add_argument(
+        "--provider",
+        type=str,
+        default="sgl",
+        help="Provider name (e.g., sgl, oai, b10)",
+    )
+    parser.add_argument(
+        "--task",
+        type=str,
+        required=True,
+        help="Task (e.g., mmlu, mmlu_0_shot, mmlu_pro, gsm8k)",
+    )
+    parser.add_argument(
+        "--num-examples", type=int, default=None, help="Number of examples to process"
+    )
+    parser.add_argument(
+        "--concurrency", type=int, default=16, help="Maximum number of concurrent requests"
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default="tmp-output-dir",
+        help="Directory to save responses",
+    )
+
+    args = parser.parse_args()
+    asyncio.run(benchmark(args))
+    analyze(TASK_TO_EVAL_SET[args.task], args.output_dir, args.model_size)
+    shutil.rmtree("tmp-output-dir", ignore_errors=True)
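For reference, the same flow can be driven programmatically rather than through the CLI. The sketch below mirrors the __main__ block added in this diff; it assumes the file is importable as sglang.llama3_eval (per the added path) and that an SGLang server is already running locally on port 30000 serving meta-llama/Llama-3.1-8B-Instruct. The argument values are illustrative, not part of the patch.

    import asyncio
    from argparse import Namespace

    from sglang.llama3_eval import TASK_TO_EVAL_SET, analyze, benchmark

    # Hypothetical arguments mirroring the CLI defaults, plus the required task
    args = Namespace(
        model_size="8b",
        provider="sgl",
        task="gsm8k",
        num_examples=32,  # evaluate a small subset for a quick check
        concurrency=16,
        output_dir="tmp-output-dir",
    )

    asyncio.run(benchmark(args))  # fetch completions for each prompt
    analyze(TASK_TO_EVAL_SET[args.task], args.output_dir, args.model_size)  # print accuracy stats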