
Commit

more instructions
Hui Kang Tong committed Sep 16, 2024
1 parent c885d7e commit c9f9e3d
Showing 13 changed files with 6,977 additions and 7,129 deletions.
26 changes: 10 additions & 16 deletions README.md
@@ -1,6 +1,6 @@
# Automatic Prompt Engineering for classification

Given only (text -> label), this notebook generates and optimizes system and user prompts.
Given only (text:str, label:str), this [notebook](https://nbviewer.org/github/tonghuikang/automatic-prompt-engineer/blob/master/classification.ipynb) generates and optimizes system and user prompts.

This is how classification is intended to be done.
- (system prompt, user prompt prefix + text + user prompt suffix) -Haiku-> bot response -extract-> label
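
A minimal sketch of that flow, using the Anthropic Python SDK as the notebook does. The `Label:` extraction is a placeholder, not the notebook's actual extraction logic:

```python
import anthropic

anthropic_client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment

def classify(text: str, system_prompt: str, user_prompt_prefix: str, user_prompt_suffix: str) -> str:
    # (system prompt, user prompt prefix + text + user prompt suffix) -Haiku-> bot response
    message = anthropic_client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=2000,
        temperature=0,
        system=system_prompt,
        messages=[{"role": "user", "content": user_prompt_prefix + text + user_prompt_suffix}],
    )
    bot_response = message.content[0].text
    # bot response -extract-> label (placeholder extraction: take the text after the last "Label:")
    return bot_response.rsplit("Label:", 1)[-1].strip()
```
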
@@ -13,9 +13,6 @@ The notebook will produce
You can simply run this notebook with just
- an Anthropic API key and an OpenAI API key

If you want to change the classification task, you will need to
- provide a dataset (text -> label)

This is how prompt tuning is done
- Sample from the full dataset.
- Haiku takes in (system prompt, user prompt prefix + text + user prompt suffix) and produces model_response
@@ -31,7 +28,7 @@ You will need to have these Python modules installed
- openai

This notebook will also produce
- The [classification](https://tonghuikang.github.io/automatic-prompt-engineer/html_output/iteration-classification-002.html) (or just the [mistakes](https://tonghuikang.github.io/automatic-prompt-engineer/html_output/iteration-classification-002-diff.html)) at each iteration of the prompt.
- The [classification](https://tonghuikang.github.io/automatic-prompt-engineer/html_output/iteration-classification-003.html) (or just the [mistakes](https://tonghuikang.github.io/automatic-prompt-engineer/html_output/iteration-classification-003-diff.html)) at each iteration of the prompt.
- The [history](https://tonghuikang.github.io/automatic-prompt-engineer/html_output/prompt-history-classification.html) of the prompt and relevant metrics.
- (These will be saved locally as HTML files)

@@ -46,6 +43,9 @@ I took inspiration from these works.

# Design Decisions

- The only input is dataset: list[tuple[str, str]] of text and labels.
Nothing else.
It is up to o1 to understand the assignment.
- I require the LLM to produce the reasoning.
Having the reasoning provides visibility to the thought process, which helps with improving the prompt.
- I minimized the amount of required packages.
@@ -60,11 +60,13 @@ I took inspiration from these works.

### Run this notebook

Just add your with your Anthropic API key and you should be able to run `classification.ipynb`.
Just add your Anthropic API key and OpenAI API key and you should be able to run `classification.ipynb`.

You will need to clone the repository for a copy of `qiqc_truncated.csv`.

It will cost approximately 40 cents per iteration, and `NUM_ITERATIONS` is 5.
It will cost approximately 50 cents per run with `NUM_ITERATIONS` of 5.

Try printing what exactly is sent to o1-mini.

You may need to change `NUM_PARALLEL_FORWARD_PASS_API_CALLS` depending on the rate limit your API key can afford.
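
The notebook's parallelization code is not shown in this diff; conceptually it amounts to something like the following hypothetical sketch, where `compute_model_response` is the function defined in the notebook:

```python
from concurrent.futures import ThreadPoolExecutor

NUM_PARALLEL_FORWARD_PASS_API_CALLS = 10  # lower this if your API key has a tight rate limit

def forward_pass(samples, model_parameters):
    # hypothetical helper; samples is a list[tuple[str, str]] of (text, label)
    with ThreadPoolExecutor(max_workers=NUM_PARALLEL_FORWARD_PASS_API_CALLS) as executor:
        responses = list(executor.map(
            lambda text: compute_model_response(text, model_parameters),
            [text for text, _ in samples],
        ))
    return responses
```
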

@@ -81,7 +83,7 @@ I recommend initializing `dataset` in the notebook with probably a copy of 50 sa
### Change the sampling method

For the forward pass, we currently sample 100 positive and negative samples.
For the mistake summary, we sample 10 false positives and false negatives, 5 true positives, and 5 true negatives.
For the feedback, we sample 10 mistakes of each label, and 5 correct classifications of each label.
Otherwise, the sample is chosen totally at random.
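
A hypothetical sketch of that feedback sampling scheme (the function and variable names are illustrative, not the notebook's):

```python
import random

def sample_for_feedback(results, labels, mistakes_per_label=10, correct_per_label=5):
    # illustrative only; results is a list of (text, expected_label, predicted_label)
    samples = []
    for label in labels:
        mistakes = [r for r in results if r[1] == label and r[2] != label]
        correct = [r for r in results if r[1] == label and r[2] == label]
        samples += random.sample(mistakes, min(mistakes_per_label, len(mistakes)))
        samples += random.sample(correct, min(correct_per_label, len(correct)))
    return samples
```
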

We can improve the sampling method so that the model learns better.
@@ -107,21 +109,13 @@ Claude allows you to specify the prefix of the response (see [Putting words in C

Besides `system_prompt`, `user_prompt_prefix`, and `user_prompt_suffix`, you can also add `bot_reply_prefix` as a field that o1 should produce.

You will need to update the `user_message` in `update_model_parameters`.
You also need to describe in `PROMPT_UPDATE_SYSTEM_PROMPT` what `bot_reply_prefix` is for.
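
For illustration only (not the notebook's actual code), Claude's prefill works by ending the message list with a partial assistant turn, along these lines:

```python
def compute_model_response(text, model_parameters):
    # illustrative sketch of using bot_reply_prefix; the notebook's real function differs
    user_message = model_parameters["user_prompt_prefix"] + text

    message = anthropic_client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=2000,
        temperature=0,
        messages=[
            {"role": "user", "content": user_message},
            # prefill: Claude continues the reply from this partial assistant turn
            {"role": "assistant", "content": model_parameters["bot_reply_prefix"]},
        ],
        system=model_parameters["system_prompt"],
        timeout=10,
    )
    # the prefix itself is not echoed back, so prepend it to reconstruct the full reply
    return model_parameters["bot_reply_prefix"] + message.content[0].text
```
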


### Try a different classification task

You can try out a different dataset.

Another dataset I recommend is the [Quora Question Pairs dataset](https://www.kaggle.com/c/quora-question-pairs/data).

You will need to
- Change how the dataset is being loaded (since each sample is now a pair of questions, you need to change the delimiters)
- Change `predict_from_final_layer` to match a different string
- Change `PROMPT_UPDATE_SYSTEM_PROMPT` to describe what is being matched
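
A hypothetical loading sketch for that dataset, assuming the Kaggle file's `question1`, `question2`, and `is_duplicate` columns and an arbitrary delimiter:

```python
import pandas as pd

df = pd.read_csv("train.csv")  # from the Quora Question Pairs dataset; column names assumed
dataset: list[tuple[str, str]] = [
    (f"Question 1: {q1}\nQuestion 2: {q2}", str(label))
    for q1, q2, label in zip(df["question1"], df["question2"], df["is_duplicate"])
]
```
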


### Try a non-classification task

90 changes: 52 additions & 38 deletions classification.ipynb
@@ -78,7 +78,7 @@
"id": "d4c8e450",
"metadata": {},
"source": [
"# Use your Anthropic API key here"
"# Use your Anthropic API key and OpenAI API key here"
]
},
{
@@ -89,7 +89,6 @@
"outputs": [],
"source": [
"anthropic_api_key = os.environ.get(\"ANTHROPIC_API_KEY\")\n",
"# anthropic_api_key = \"sk-ant-\"\n",
"anthropic_client = anthropic.Anthropic(api_key=anthropic_api_key)\n",
"\n",
"openai_api_key = os.environ.get(\"OPENAI_API_KEY\")\n",
@@ -117,7 +116,7 @@
{
"cell_type": "code",
"execution_count": 4,
"id": "147ef4bb",
"id": "39725119",
"metadata": {},
"outputs": [
{
@@ -152,7 +151,8 @@
"metadata": {},
"source": [
"# Define the dataset here\n",
"You will need to edit this if your task is different."
"\n",
"The outcome of this section should be the variable `dataset: list[tuple[str, str]]`"
]
},
{
@@ -212,17 +212,30 @@
{
"cell_type": "code",
"execution_count": 8,
"id": "a1909080",
"id": "ebf69bf8",
"metadata": {},
"outputs": [],
"source": [
"# you can also just define the dataset with code\n",
"dataset = list(zip(df[\"question_text\"], df[\"target\"]))"
"# the other parts of the notebook will only take this variable as input\n",
"dataset: list[tuple[str, str]] = list(zip(df[\"question_text\"], df[\"target\"]))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "0c7c1744",
"metadata": {},
"outputs": [],
"source": [
"for text, target in dataset:\n",
" # note that you need to cast the target into a string\n",
" assert type(text) == str\n",
" assert type(target) == str"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "53e4e290",
"metadata": {},
"outputs": [
@@ -232,20 +245,19 @@
"Counter({'0': 100, '1': 100})"
]
},
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# make sure the number of types of labels is small\n",
"# prefer descriptive labels to avoid giving the model mental gymnastics\n",
"collections.Counter(label for _, label in dataset)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"id": "b9380f08",
"metadata": {},
"outputs": [
@@ -256,13 +268,13 @@
" '0')"
]
},
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset[0] # should be tuple[string, label]"
"dataset[0]"
]
},
{
@@ -276,7 +288,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"id": "fdccd181",
"metadata": {},
"outputs": [],
@@ -294,7 +306,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"id": "89fdc763",
"metadata": {},
"outputs": [],
@@ -303,7 +315,6 @@
"model_parameters = {\n",
" \"system_prompt\": \"\",\n",
" \"user_prompt_prefix\": \"\",\n",
" \"user_prompt_suffix\": \"\",\n",
"}\n",
"\n",
"token_counts = defaultdict(int)\n",
@@ -312,7 +323,7 @@
},
{
"cell_type": "markdown",
"id": "653db4c2",
"id": "4dc4cd7a",
"metadata": {},
"source": [
"# Model definition\n",
@@ -321,21 +332,23 @@
},
{
"cell_type": "code",
"execution_count": 13,
"id": "3a099cae",
"execution_count": 14,
"id": "c037ad70",
"metadata": {},
"outputs": [],
"source": [
"\n",
"def compute_model_response(text, model_parameters):\n",
" \n",
" user_message = model_parameters[\"user_prompt_prefix\"] + text + model_parameters[\"user_prompt_suffix\"]\n",
" user_message = model_parameters[\"user_prompt_prefix\"] + text\n",
" \n",
" message = anthropic_client.messages.create(\n",
" model=\"claude-3-haiku-20240307\",\n",
" max_tokens=2000,\n",
" temperature=0,\n",
" messages=[{\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": user_message}]}],\n",
" messages=[\n",
" {\"role\": \"user\", \"content\": user_message},\n",
" ],\n",
" system=model_parameters[\"system_prompt\"],\n",
" timeout=10\n",
" )\n",
@@ -357,7 +370,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"id": "94a927c8",
"metadata": {},
"outputs": [],
@@ -378,7 +391,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 16,
"id": "576c5225",
"metadata": {},
"outputs": [],
@@ -400,7 +413,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"id": "07a50c74",
"metadata": {},
"outputs": [],
@@ -422,13 +435,13 @@
" You will improve the prompts so that the LLM will predict the expected label.\n",
" You might need to guess what each label means.\n",
"\n",
" The LLM input has the following parameters. Make use of all the parameters.\n",
" The LLM input has the following parameters.\n",
" - {list(model_parameters.keys())}\n",
" \n",
" This is how the parameters are used by the LLM\n",
" {inspect.getsource(compute_model_response)}\n",
"\n",
" Please ensure that the prompt contains these following instructions\n",
" Please ensure that the prompts contains these following instructions\n",
" - The label should have the same exact text as the actual labels\n",
" - The information on what the set of labels are\n",
" - A small number of example inputs and outputs\n",
@@ -499,7 +512,8 @@
" \n",
" What are the proposed changes the prompt\n",
" \n",
" The prompts in the following format\n",
" The prompt parameters in the following format\n",
" (please make sure each prompt parameter has some meaningful content)\n",
" \"\"\") + \"\\n\\n\"\n",
"\n",
" for model_parameter_key in model_parameters.keys():\n",
@@ -541,7 +555,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 18,
"id": "040e056a",
"metadata": {},
"outputs": [],
@@ -567,7 +581,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 19,
"id": "04eaf8a7",
"metadata": {},
"outputs": [],
@@ -710,7 +724,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 20,
"id": "b1647bb3",
"metadata": {
"scrolled": false
@@ -929,21 +943,21 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 21,
"id": "9154d3ab",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"defaultdict(int,\n",
" {'haiku_input': 462669,\n",
" 'haiku_output': 76423,\n",
" 'o1_input': 22164,\n",
" 'o1_output': 9344})"
" {'haiku_input': 601309,\n",
" 'haiku_output': 90350,\n",
" 'o1_input': 24116,\n",
" 'o1_output': 9897})"
]
},
"execution_count": 20,
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
@@ -954,15 +968,15 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 22,
"id": "e509034e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cost: $0.42\n"
"Cost: $0.48\n"
]
}
],
@@ -973,7 +987,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "ede29994",
"id": "8f4fd1ef",
"metadata": {},
"outputs": [],
"source": []
