Skip to content

Commit

Permalink
use o1
Browse files Browse the repository at this point in the history
  • Loading branch information
Hui Kang Tong committed Sep 16, 2024
1 parent 289f35b commit 156d22b
Show file tree
Hide file tree
Showing 12 changed files with 6,420 additions and 6,242 deletions.
156 changes: 97 additions & 59 deletions classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,34 +19,44 @@
"- the user prompt suffix\n",
"\n",
"You can simply run this notebook with just\n",
"- an Anthropic API key\n",
"- an Anthropic API key and an OpenAI API key\n",
"\n",
"If you want to change the classification task, you will need to\n",
"- provide a dataset (text -> label)\n",
"- define the function bot_response -> label\n",
"- description for Opus on what instructions Haiku should follow\n",
"- description for o1 on what instructions Haiku should follow\n",
"\n",
"This is how prompt tuning is done\n",
"- Sample from the full dataset.\n",
"- Haiku takes in (system prompt, user prompt prefix + text + user prompt suffix) and produces bot_response (the final layer values).\n",
"- The function takes in bot_response and produces the label. The (text -> label) process is analogous to the forward pass.\n",
"- Sample from the mistakes.\n",
"- Opus takes in the mistakes and summarizes the mistakes (gradient calculation).\n",
"- Opus takes in the mistake summary (gradient) and the current prompts (model parameters) update the prompts.\n",
"- o1-mini takes in the mistakes and summarizes the mistakes (gradient calculation).\n",
"- o1-mini takes in the mistake summary (gradient) and the current prompts (model parameters) update the prompts.\n",
"- Repeat.\n",
"\n",
"You will need to have these Python modules installed\n",
"- pandas\n",
"- scikit-learn\n",
"- anthropic"
"- anthropic\n",
"- openai"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "3c074e16",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/Caskroom/miniforge/base/lib/python3.9/site-packages/scipy/__init__.py:155: UserWarning: A NumPy version >=1.18.5 and <1.25.0 is required for this version of SciPy (detected version 1.26.4\n",
" warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
]
}
],
"source": [
"import os\n",
"import re\n",
Expand All @@ -59,7 +69,8 @@
"from IPython.display import display, HTML\n",
"from sklearn.metrics import precision_score, recall_score\n",
"\n",
"import anthropic"
"import anthropic\n",
"from openai import OpenAI"
]
},
{
Expand All @@ -79,7 +90,10 @@
"source": [
"anthropic_api_key = os.environ.get(\"ANTHROPIC_API_KEY\")\n",
"# anthropic_api_key = \"sk-ant-\"\n",
"client = anthropic.Anthropic(api_key=anthropic_api_key)"
"client = anthropic.Anthropic(api_key=anthropic_api_key)\n",
"\n",
"openai_api_key = os.environ.get(\"OPENAI_API_KEY\")\n",
"openai_client = OpenAI(api_key=openai_api_key)"
]
},
{
Expand All @@ -103,6 +117,24 @@
{
"cell_type": "code",
"execution_count": 4,
"id": "2651019b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sk-T0uvUYsHoQOag\n"
]
}
],
"source": [
"print(openai_api_key[:16])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6130e282",
"metadata": {},
"outputs": [],
Expand All @@ -125,7 +157,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "5ba53aef",
"metadata": {},
"outputs": [
Expand All @@ -137,7 +169,7 @@
"Name: target, dtype: int64"
]
},
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -150,7 +182,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "e1df0784",
"metadata": {},
"outputs": [],
Expand All @@ -166,7 +198,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "53e4e290",
"metadata": {},
"outputs": [
Expand All @@ -176,7 +208,7 @@
"Counter({'insincere': 100, 'sincere': 100})"
]
},
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -189,7 +221,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"id": "b9380f08",
"metadata": {},
"outputs": [
Expand All @@ -200,7 +232,7 @@
" 'insincere')"
]
},
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -220,7 +252,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"id": "fdccd181",
"metadata": {},
"outputs": [],
Expand All @@ -233,7 +265,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"id": "334c7221",
"metadata": {},
"outputs": [],
Expand All @@ -259,7 +291,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"id": "89fdc763",
"metadata": {},
"outputs": [],
Expand All @@ -278,6 +310,8 @@
" \"haiku_output\": 0,\n",
" \"sonnet_output\": 0,\n",
" \"opus_output\": 0,\n",
" \"o1_input\": 0,\n",
" \"o1_output\": 0,\n",
"} # ideally this should have been tracked in anthropic.Anthropic"
]
},
Expand All @@ -293,7 +327,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"id": "94a927c8",
"metadata": {},
"outputs": [],
Expand All @@ -314,7 +348,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"id": "b4cbd985",
"metadata": {},
"outputs": [],
Expand All @@ -339,7 +373,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"id": "576c5225",
"metadata": {},
"outputs": [],
Expand All @@ -361,7 +395,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 16,
"id": "4f701285",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -405,31 +439,29 @@
" \"\"\"\n",
" )\n",
" \n",
" message = client.messages.create(\n",
" model=\"claude-3-opus-20240229\",\n",
" max_tokens=2000,\n",
" temperature=0,\n",
" system=system_message,\n",
" messages=[{\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": user_message}]}]\n",
" )\n",
" token_counts[\"opus_input\"] += message.usage.input_tokens\n",
" token_counts[\"opus_output\"] += message.usage.output_tokens\n",
" response = openai_client.chat.completions.create(\n",
" model=\"o1-mini\",\n",
" messages=[{\"role\": \"user\", \"content\": user_message}]\n",
" ) \n",
" \n",
" token_counts[\"o1_input\"] += response.usage.prompt_tokens\n",
" token_counts[\"o1_output\"] += response.usage.completion_tokens\n",
"\n",
" return message.content[0].text"
" return response.choices[0].message.content"
]
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"id": "07a50c74",
"metadata": {},
"outputs": [],
"source": [
"def update_model_parameters(gradient, model_parameters, metrics):\n",
"\n",
" system_message = PROMPT_UPDATE_SYSTEM_PROMPT\n",
"\n",
" user_message = textwrap.dedent(f\"\"\" \n",
" user_message = textwrap.dedent(f\"\"\"\n",
" {PROMPT_UPDATE_SYSTEM_PROMPT}\n",
" \n",
" The current metrics is {str(metrics)}\n",
"\n",
" This the current set of prompts\n",
Expand Down Expand Up @@ -465,17 +497,15 @@
" </user_prompt_suffix>\n",
" \"\"\")\n",
" \n",
" message = client.messages.create(\n",
" model=\"claude-3-opus-20240229\",\n",
" max_tokens=2000,\n",
" temperature=0,\n",
" system=system_message,\n",
" messages=[{\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": user_message}]}],\n",
" response = openai_client.chat.completions.create(\n",
" model=\"o1-mini\",\n",
" messages=[{\"role\": \"user\", \"content\": user_message}]\n",
" )\n",
" token_counts[\"opus_input\"] += message.usage.input_tokens\n",
" token_counts[\"opus_output\"] += message.usage.output_tokens\n",
"\n",
" bot_message = message.content[0].text\n",
" token_counts[\"o1_input\"] += response.usage.prompt_tokens\n",
" token_counts[\"o1_output\"] += response.usage.completion_tokens\n",
"\n",
" bot_message = response.choices[0].message.content\n",
"\n",
" match_system_prompt = re.search(r'<system_prompt>(.*?)</system_prompt>', bot_message, re.DOTALL)\n",
" match_user_prompt_prefix = re.search(r'<user_prompt_prefix>(.*?)</user_prompt_prefix>', bot_message, re.DOTALL)\n",
Expand All @@ -500,7 +530,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 18,
"id": "040e056a",
"metadata": {},
"outputs": [],
Expand All @@ -522,7 +552,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 19,
"id": "04eaf8a7",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -623,7 +653,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 20,
"id": "b1647bb3",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -839,22 +869,24 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 21,
"id": "9154d3ab",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'haiku_input': 335600,\n",
"{'haiku_input': 287901,\n",
" 'sonnet_input': 0,\n",
" 'opus_input': 49812,\n",
" 'haiku_output': 211791,\n",
" 'opus_input': 0,\n",
" 'haiku_output': 186151,\n",
" 'sonnet_output': 0,\n",
" 'opus_output': 5612}"
" 'opus_output': 0,\n",
" 'o1_input': 48603,\n",
" 'o1_output': 24599}"
]
},
"execution_count": 20,
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -865,25 +897,31 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 22,
"id": "e509034e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.51671875"
"0.745661"
]
},
"execution_count": 21,
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cost_in_dollar = (\n",
" token_counts[\"haiku_input\"] * 0.25 + token_counts[\"sonnet_input\"] * 3 + token_counts[\"opus_input\"] * 15\n",
" + token_counts[\"haiku_output\"] * 1.25 + token_counts[\"sonnet_output\"] * 15 + token_counts[\"opus_output\"] * 75\n",
" token_counts[\"haiku_input\"] * 0.25 \n",
" + token_counts[\"sonnet_input\"] * 3 \n",
" + token_counts[\"opus_input\"] * 15\n",
" + token_counts[\"o1_input\"] * 3\n",
" + token_counts[\"haiku_output\"] * 1.25 \n",
" + token_counts[\"sonnet_output\"] * 15 \n",
" + token_counts[\"opus_output\"] * 7\n",
" + token_counts[\"o1_output\"] * 12\n",
") / 1_000_000\n",
"cost_in_dollar"
]
Expand Down
Loading

0 comments on commit 156d22b

Please sign in to comment.