From 327acf1d1b4b41976f595eeb030b092a3578fcb0 Mon Sep 17 00:00:00 2001
From: Charles Packer
Date: Thu, 11 Apr 2024 15:13:18 -0700
Subject: [PATCH] feat: Anthropic Claude API support (#1239)

---
 memgpt/cli/cli_config.py        | 101 ++++++++-
 memgpt/credentials.py           |   8 +
 memgpt/data_types.py            |  72 ++++++
 memgpt/llm_api/anthropic.py     | 376 ++++++++++++++++++++++++++++++++
 memgpt/llm_api/llm_api_tools.py |  30 ++-
 memgpt/local_llm/constants.py   |   1 +
 6 files changed, 581 insertions(+), 7 deletions(-)
 create mode 100644 memgpt/llm_api/anthropic.py

diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py
index 905ca244da..f8deb5715a 100644
--- a/memgpt/cli/cli_config.py
+++ b/memgpt/cli/cli_config.py
@@ -21,6 +21,8 @@
 from memgpt.llm_api.openai import openai_get_model_list
 from memgpt.llm_api.azure_openai import azure_openai_get_model_list
 from memgpt.llm_api.google_ai import google_ai_get_model_list, google_ai_get_model_context_window
+from memgpt.llm_api.anthropic import anthropic_get_model_list, anthropic_get_model_context_window
+from memgpt.llm_api.llm_api_tools import LLM_API_PROVIDER_OPTIONS
 from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME
 from memgpt.local_llm.utils import get_available_wrappers
 from memgpt.server.utils import shorten_key_middle
@@ -64,14 +66,14 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
     # get default
     default_model_endpoint_type = config.default_llm_config.model_endpoint_type
     if config.default_llm_config.model_endpoint_type is not None and config.default_llm_config.model_endpoint_type not in [
-        "openai",
-        "azure",
-        "google_ai",
+        provider for provider in LLM_API_PROVIDER_OPTIONS if provider != "local"
     ]:  # local model
         default_model_endpoint_type = "local"

     provider = questionary.select(
-        "Select LLM inference provider:", choices=["openai", "azure", "google_ai", "local"], default=default_model_endpoint_type
+        "Select LLM inference provider:",
+        choices=LLM_API_PROVIDER_OPTIONS,
+        default=default_model_endpoint_type,
     ).ask()
     if provider is None:
         raise KeyboardInterrupt
@@ -184,6 +186,46 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)

         model_endpoint_type = "google_ai"

+    elif provider == "anthropic":
+        # check for key
+        if credentials.anthropic_key is None:
+            # allow key to get pulled from env vars
+            anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", None)
+            # if we still can't find it, ask for it as input
+            if anthropic_api_key is None:
+                while anthropic_api_key is None or len(anthropic_api_key) == 0:
+                    # Ask for API key as input
+                    anthropic_api_key = questionary.password(
+                        "Enter your Anthropic API key (starts with 'sk-ant-', see https://console.anthropic.com/settings/keys):"
+                    ).ask()
+                    if anthropic_api_key is None:
+                        raise KeyboardInterrupt
+            credentials.anthropic_key = anthropic_api_key
+            credentials.save()
+        else:
+            # Give the user an opportunity to overwrite the key
+            default_input = (
+                shorten_key_middle(credentials.anthropic_key) if credentials.anthropic_key.startswith("sk-") else credentials.anthropic_key
+            )
+            anthropic_api_key = questionary.password(
+                "Enter your Anthropic API key (starts with 'sk-ant-', see https://console.anthropic.com/settings/keys):",
+                default=default_input,
+            ).ask()
+            if anthropic_api_key is None:
+                raise KeyboardInterrupt
+            # If the user modified it, use the new one
+            if anthropic_api_key != default_input:
+                credentials.anthropic_key = anthropic_api_key
+                credentials.save()
+
+        model_endpoint_type = "anthropic"
+        model_endpoint = "https://api.anthropic.com/v1"
+        model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask()
+        if model_endpoint is None:
+            raise KeyboardInterrupt
+        provider = "anthropic"
+
     else:  # local models
         # backend_options_old = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"]
         backend_options = builtins.list(DEFAULT_ENDPOINTS.keys())
@@ -291,6 +333,12 @@ def get_model_options(
         model_options = [mo for mo in model_options if str(mo).startswith("gemini") and "-pro" in str(mo)]
         # model_options = ["gemini-pro"]

+    elif model_endpoint_type == "anthropic":
+        if credentials.anthropic_key is None:
+            raise ValueError("Missing Anthropic API key")
+        fetched_model_options = anthropic_get_model_list(url=model_endpoint, api_key=credentials.anthropic_key)
+        model_options = [obj["name"] for obj in fetched_model_options]
+
     else:
         # Attempt to do OpenAI endpoint style model fetching
         # TODO support local auth with api-key header
@@ -382,6 +430,26 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
         if model is None:
             raise KeyboardInterrupt

+    elif model_endpoint_type == "anthropic":
+        try:
+            fetched_model_options = get_model_options(
+                credentials=credentials, model_endpoint_type=model_endpoint_type, model_endpoint=model_endpoint
+            )
+        except Exception as e:
+            # NOTE: if this fails, it means the user's key is probably bad
+            typer.secho(
+                f"Failed to get model list from {model_endpoint} - make sure your API key and endpoints are correct!", fg=typer.colors.RED
+            )
+            raise e
+
+        model = questionary.select(
+            "Select default model:",
+            choices=fetched_model_options,
+            default=fetched_model_options[0],
+        ).ask()
+        if model is None:
+            raise KeyboardInterrupt
+
     else:  # local models
         # ask about local auth
@@ -522,8 +590,8 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
                     fetched_context_window,
                     "custom",
                 ]
-            except:
-                print(f"Failed to get model details for model '{model}' on Google AI API")
+            except Exception as e:
+                print(f"Failed to get model details for model '{model}' on Google AI API ({str(e)})")

             context_window_input = questionary.select(
                 "Select your model's context window (see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versioning#gemini-model-versions):",
                 choices=context_length_options,
                 default=context_length_options[0],
             ).ask()
             if context_window_input is None:
                 raise KeyboardInterrupt
@@ -533,6 +601,27 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_

+        elif model_endpoint_type == "anthropic":
+            try:
+                fetched_context_window = str(
+                    anthropic_get_model_context_window(url=model_endpoint, api_key=credentials.anthropic_key, model=model)
+                )
+                print(f"Got context window {fetched_context_window} for model {model}")
+                context_length_options = [
+                    fetched_context_window,
+                    "custom",
+                ]
+            except Exception as e:
+                print(f"Failed to get model details for model '{model}' ({str(e)})")
+                # NOTE: fall back to 200k (the advertised window for all current Claude 3 models)
+                # so the select below still has valid choices
+                context_length_options = [
+                    "200000",
+                    "custom",
+                ]
+
+            context_window_input = questionary.select(
+                "Select your model's context window (see https://docs.anthropic.com/claude/docs/models-overview):",
+                choices=context_length_options,
+                default=context_length_options[0],
+            ).ask()
+            if context_window_input is None:
+                raise KeyboardInterrupt
+
         else:
             # Ask the user to specify the context length

diff --git a/memgpt/credentials.py b/memgpt/credentials.py
index af998369f7..2f7637cbbe 100644
--- a/memgpt/credentials.py
+++ b/memgpt/credentials.py
@@ -32,6 +32,9 @@ class MemGPTCredentials:
     google_ai_key: Optional[str] = None
     google_ai_service_endpoint: Optional[str] = None

+    # anthropic config
+    anthropic_key: Optional[str] = None
+
     # azure config
     azure_auth_type: str = "api_key"
     azure_key: Optional[str] = None
@@ -77,6 +80,8 @@ def load(cls) -> "MemGPTCredentials":
             # gemini
             "google_ai_key": get_field(config, "google_ai", "key"),
             "google_ai_service_endpoint": get_field(config, "google_ai", "service_endpoint"),
+            # anthropic
+            "anthropic_key": get_field(config, "anthropic", "key"),
             # open llm
             "openllm_auth_type": get_field(config, "openllm", "auth_type"),
             "openllm_key": get_field(config, "openllm", "key"),
@@ -113,6 +118,9 @@ def save(self):
         set_field(config, "google_ai", "key", self.google_ai_key)
         set_field(config, "google_ai", "service_endpoint", self.google_ai_service_endpoint)

+        # anthropic
+        set_field(config, "anthropic", "key", self.anthropic_key)
+
         # openllm config
         set_field(config, "openllm", "auth_type", self.openllm_auth_type)
         set_field(config, "openllm", "key", self.openllm_key)

diff --git a/memgpt/data_types.py b/memgpt/data_types.py
index dba9b7f969..0c336cdc7c 100644
--- a/memgpt/data_types.py
+++ b/memgpt/data_types.py
@@ -294,6 +294,78 @@ def to_openai_dict(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict:

         return openai_message

+    def to_anthropic_dict(self, inner_thoughts_xml_tag="thinking") -> dict:
+        # raise NotImplementedError
+
+        def add_xml_tag(string: str, xml_tag: Optional[str]):
+            # NOTE: Anthropic docs recommend using a <thinking> tag when using CoT + tool use
+            return f"<{xml_tag}>{string}</{xml_tag}>" if xml_tag else string
+
     def to_google_ai_dict(self) -> dict:
         """Go from Message class to Google AI REST message object

diff --git a/memgpt/llm_api/anthropic.py b/memgpt/llm_api/anthropic.py
new file mode 100644
index 0000000000..cbe70206d8
--- /dev/null
+++ b/memgpt/llm_api/anthropic.py
@@ -0,0 +1,376 @@
+import requests
+import uuid
+import json
+import re
+from typing import Union, Optional, List
+
+from memgpt.data_types import Message
+from memgpt.models.chat_completion_response import (
+    ChatCompletionResponse,
+    UsageStatistics,
+    Choice,
+    Message as ChoiceMessage,  # NOTE: avoid conflict with our own MemGPT Message datatype
+    ToolCall,
+    FunctionCall,
+)
+from memgpt.models.chat_completion_request import ChatCompletionRequest, Tool
+from memgpt.utils import smart_urljoin, get_utc_time
+
+BASE_URL = "https://api.anthropic.com/v1"
+
+
+# https://docs.anthropic.com/claude/docs/models-overview
+# Sadly hardcoded
+MODEL_LIST = [
+    {
+        "name": "claude-3-opus-20240229",
+        "context_window": 200000,
+    },
+    {
+        "name": "claude-3-sonnet-20240229",
+        "context_window": 200000,
+    },
+    {
+        "name": "claude-3-haiku-20240307",
+        "context_window": 200000,
+    },
+]
+
+DUMMY_FIRST_USER_MESSAGE = "User initializing bootup sequence."
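+# NOTE: Anthropic's Messages API requires the first message to come from the "user" role,
+# so anthropic_chat_completions_request() below prepends this dummy user message whenever
+# the converted message list would otherwise start with an assistant turn.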
+
+
+def anthropic_get_model_context_window(url: str, api_key: Union[str, None], model: str) -> int:
+    for model_dict in anthropic_get_model_list(url=url, api_key=api_key):
+        if model_dict["name"] == model:
+            return model_dict["context_window"]
+    raise ValueError(f"Can't find model '{model}' in Anthropic model list")
+
+
+def anthropic_get_model_list(url: str, api_key: Union[str, None]) -> List[dict]:
+    """https://docs.anthropic.com/claude/docs/models-overview"""
+
+    # NOTE: currently there is no GET /models, so we need to hardcode
+    return MODEL_LIST
+
+
+def convert_tools_to_anthropic_format(tools: List[Tool], inner_thoughts_in_kwargs: Optional[bool] = True) -> List[dict]:
+    """See: https://docs.anthropic.com/claude/docs/tool-use
+
+    OpenAI style:
+      "tools": [{
+          "type": "function",
+          "function": {
+              "name": "find_movies",
+              "description": "find ....",
+              "parameters": {
+                  "type": "object",
+                  "properties": {
+                      PARAM: {
+                          "type": PARAM_TYPE,  # eg "string"
+                          "description": PARAM_DESCRIPTION,
+                      },
+                      ...
+                  },
+                  "required": List[str],
+              }
+          }
+      }]
+
+    Anthropic style:
+      "tools": [{
+          "name": "find_movies",
+          "description": "find ....",
+          "input_schema": {
+              "type": "object",
+              "properties": {
+                  PARAM: {
+                      "type": PARAM_TYPE,  # eg "string"
+                      "description": PARAM_DESCRIPTION,
+                  },
+                  ...
+              },
+              "required": List[str],
+          }
+      }]
+
+    Two small differences:
+    - one level less of nesting
+    - "parameters" -> "input_schema"
+    """
+    tools_dict_list = []
+    for tool in tools:
+        tools_dict_list.append(
+            {
+                "name": tool.function.name,
+                "description": tool.function.description,
+                "input_schema": tool.function.parameters,
+            }
+        )
+    return tools_dict_list
+
+
+def merge_tool_results_into_user_messages(messages: List[dict]):
+    """Anthropic API doesn't allow role 'tool'->'user' sequences
+
+    Example HTTP error:
+    messages: roles must alternate between "user" and "assistant", but found multiple "user" roles in a row
+
+    From: https://docs.anthropic.com/claude/docs/tool-use
+    You may be familiar with other APIs that return tool use as separate from the model's primary output,
+    or which use a special-purpose tool or function message role.
+    In contrast, Anthropic's models and API are built around alternating user and assistant messages,
+    where each message is an array of rich content blocks: text, image, tool_use, and tool_result.
+ """ + + # TODO walk through the messages list + # When a dict (dict_A) with 'role' == 'user' is followed by a dict with 'role' == 'user' (dict B), do the following + # dict_A["content"] = dict_A["content"] + dict_B["content"] + + # The result should be a new merged_messages list that doesn't have any back-to-back dicts with 'role' == 'user' + merged_messages = [] + if not messages: + return merged_messages + + # Start with the first message in the list + current_message = messages[0] + + for next_message in messages[1:]: + if current_message["role"] == "user" and next_message["role"] == "user": + # Merge contents of the next user message into current one + current_content = ( + current_message["content"] + if isinstance(current_message["content"], list) + else [{"type": "text", "text": current_message["content"]}] + ) + next_content = ( + next_message["content"] + if isinstance(next_message["content"], list) + else [{"type": "text", "text": next_message["content"]}] + ) + merged_content = current_content + next_content + current_message["content"] = merged_content + else: + # Append the current message to result as it's complete + merged_messages.append(current_message) + # Move on to the next message + current_message = next_message + + # Append the last processed message to the result + merged_messages.append(current_message) + + return merged_messages + + +def remap_finish_reason(stop_reason: str) -> str: + """Remap Anthropic's 'stop_reason' to OpenAI 'finish_reason' + + OpenAI: 'stop', 'length', 'function_call', 'content_filter', null + see: https://platform.openai.com/docs/guides/text-generation/chat-completions-api + + From: https://docs.anthropic.com/claude/reference/migrating-from-text-completions-to-messages#stop-reason + + Messages have a stop_reason of one of the following values: + "end_turn": The conversational turn ended naturally. + "stop_sequence": One of your specified custom stop sequences was generated. + "max_tokens": (unchanged) + + """ + if stop_reason == "end_turn": + return "stop" + elif stop_reason == "stop_sequence": + return "stop" + elif stop_reason == "max_tokens": + return "length" + elif stop_reason == "tool_use": + return "function_call" + else: + raise ValueError(f"Unexpected stop_reason: {stop_reason}") + + +def strip_xml_tags(string: str, tag: Optional[str]) -> str: + if tag is None: + return string + # Construct the regular expression pattern to find the start and end tags + tag_pattern = f"<{tag}.*?>|" + # Use the regular expression to replace the tags with an empty string + return re.sub(tag_pattern, "", string) + + +def convert_anthropic_response_to_chatcompletion( + response_json: dict, # REST response from Google AI API + inner_thoughts_xml_tag: Optional[str] = None, +) -> ChatCompletionResponse: + """ + Example response from Claude 3: + response.json = { + 'id': 'msg_01W1xg9hdRzbeN2CfZM7zD2w', + 'type': 'message', + 'role': 'assistant', + 'content': [ + { + 'type': 'text', + 'text': "Analyzing user login event. This is Chad's first + interaction with me. I will adjust my personality and rapport accordingly." + }, + { + 'type': + 'tool_use', + 'id': 'toolu_01Ka4AuCmfvxiidnBZuNfP1u', + 'name': 'core_memory_append', + 'input': { + 'name': 'human', + 'content': 'Chad is logging in for the first time. 
+                        I will aim to build a warm and welcoming rapport.',
+                    'request_heartbeat': True
+                }
+            }
+        ],
+        'model': 'claude-3-haiku-20240307',
+        'stop_reason': 'tool_use',
+        'stop_sequence': None,
+        'usage': {
+            'input_tokens': 3305,
+            'output_tokens': 141
+        }
+    }
+    """
+    prompt_tokens = response_json["usage"]["input_tokens"]
+    completion_tokens = response_json["usage"]["output_tokens"]
+
+    finish_reason = remap_finish_reason(response_json["stop_reason"])
+
+    if isinstance(response_json["content"], list):
+        # inner mono + function call
+        # TODO relax asserts
+        assert len(response_json["content"]) == 2, response_json
+        assert response_json["content"][0]["type"] == "text", response_json
+        assert response_json["content"][1]["type"] == "tool_use", response_json
+        content = strip_xml_tags(string=response_json["content"][0]["text"], tag=inner_thoughts_xml_tag)
+        tool_calls = [
+            ToolCall(
+                id=response_json["content"][1]["id"],
+                type="function",
+                function=FunctionCall(
+                    name=response_json["content"][1]["name"],
+                    arguments=json.dumps(response_json["content"][1]["input"]),
+                ),
+            )
+        ]
+    else:
+        # just inner mono
+        content = strip_xml_tags(string=response_json["content"], tag=inner_thoughts_xml_tag)
+        tool_calls = None
+
+    assert response_json["role"] == "assistant", response_json
+    choice = Choice(
+        index=0,
+        finish_reason=finish_reason,
+        message=ChoiceMessage(
+            role=response_json["role"],
+            content=content,
+            tool_calls=tool_calls,
+        ),
+    )
+
+    return ChatCompletionResponse(
+        id=response_json["id"],
+        choices=[choice],
+        created=get_utc_time(),
+        model=response_json["model"],
+        usage=UsageStatistics(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        ),
+    )
+
+
+def anthropic_chat_completions_request(
+    url: str,
+    api_key: str,
+    data: ChatCompletionRequest,
+    inner_thoughts_xml_tag: Optional[str] = "thinking",
+) -> ChatCompletionResponse:
+    """https://docs.anthropic.com/claude/docs/tool-use"""
+    from memgpt.utils import printd
+
+    url = smart_urljoin(url, "messages")
+    headers = {
+        "Content-Type": "application/json",
+        "x-api-key": api_key,
+        # NOTE: beta headers for tool calling
+        "anthropic-version": "2023-06-01",
+        "anthropic-beta": "tools-2024-04-04",
+    }
+
+    # convert the tools
+    anthropic_tools = None if data.tools is None else convert_tools_to_anthropic_format(data.tools)
+
+    # pydantic -> dict
+    data = data.model_dump(exclude_none=True)
+
+    if "functions" in data:
+        raise ValueError("'functions' unexpected in Anthropic API payload")
+
+    # If tools == None, strip from the payload
+    if "tools" in data and data["tools"] is None:
+        data.pop("tools")
+        data.pop("tool_choice", None)  # extra safe, should exist always (default="auto")
+    # Remap to our converted tools
+    if anthropic_tools is not None:
+        data["tools"] = anthropic_tools
+
+    # Move 'system' to the top level
+    # 'messages: Unexpected role "system". The Messages API accepts a top-level `system` parameter, not "system" as an input message role.'
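+    # For example, an OpenAI-style payload such as
+    #   {"messages": [{"role": "system", "content": SYS}, {"role": "user", "content": MSG}]}
+    # is reshaped below into
+    #   {"system": SYS, "messages": [{"role": "user", "content": MSG}]}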
+    assert data["messages"][0]["role"] == "system", f"Expected 'system' role in messages[0]:\n{data['messages'][0]}"
+    data["system"] = data["messages"][0]["content"]
+    data["messages"] = data["messages"][1:]
+
+    # Convert to Anthropic format
+    msg_objs = [Message.dict_to_message(user_id=uuid.uuid4(), agent_id=uuid.uuid4(), openai_message_dict=m) for m in data["messages"]]
+    data["messages"] = [m.to_anthropic_dict(inner_thoughts_xml_tag=inner_thoughts_xml_tag) for m in msg_objs]
+
+    # Handle Anthropic's special requirement for a 'user' message in front
+    # 'messages: first message must use the "user" role'
+    if data["messages"][0]["role"] != "user":
+        data["messages"] = [{"role": "user", "content": DUMMY_FIRST_USER_MESSAGE}] + data["messages"]
+
+    # Handle Anthropic's restriction on alternating user/assistant messages
+    data["messages"] = merge_tool_results_into_user_messages(data["messages"])
+
+    # Anthropic also wants max_tokens in the input
+    # It's also part of ChatCompletions
+    assert "max_tokens" in data, data
+
+    # Remove extra fields used by OpenAI but not Anthropic
+    data.pop("frequency_penalty", None)
+    data.pop("logprobs", None)
+    data.pop("n", None)
+    data.pop("top_p", None)
+    data.pop("presence_penalty", None)
+    data.pop("user", None)
+    data.pop("tool_choice", None)
+
+    printd(f"Sending request to {url}")
+    try:
+        response = requests.post(url, headers=headers, json=data)
+        printd(f"response = {response}")
+        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
+        response = response.json()  # convert to dict from string
+        printd(f"response.json = {response}")
+        response = convert_anthropic_response_to_chatcompletion(response_json=response, inner_thoughts_xml_tag=inner_thoughts_xml_tag)
+        return response
+    except requests.exceptions.HTTPError as http_err:
+        # Handle HTTP errors (e.g., response 4XX, 5XX)
+        printd(f"Got HTTPError, exception={http_err}, payload={data}")
+        raise http_err
+    except requests.exceptions.RequestException as req_err:
+        # Handle other requests-related errors (e.g., connection error)
+        printd(f"Got RequestException, exception={req_err}")
+        raise req_err
+    except Exception as e:
+        # Handle other potential errors
+        printd(f"Got unknown Exception, exception={e}")
+        raise e

diff --git a/memgpt/llm_api/llm_api_tools.py b/memgpt/llm_api/llm_api_tools.py
index c7824590cd..12ac6b0245 100644
--- a/memgpt/llm_api/llm_api_tools.py
+++ b/memgpt/llm_api/llm_api_tools.py
@@ -19,6 +19,10 @@
     google_ai_chat_completions_request,
     convert_tools_to_google_ai_format,
 )
+from memgpt.llm_api.anthropic import anthropic_chat_completions_request
+
+
+LLM_API_PROVIDER_OPTIONS = ["openai", "azure", "anthropic", "google_ai", "local"]


 def is_context_overflow_error(exception: requests.exceptions.RequestException) -> bool:
@@ -214,7 +218,7 @@ def create(
         if functions is not None:
             tools = [{"type": "function", "function": f} for f in functions]
             tools = [Tool(**t) for t in tools]
-            tools = (convert_tools_to_google_ai_format(tools, inner_thoughts_in_kwargs=google_ai_inner_thoughts_in_kwarg),)
+            tools = convert_tools_to_google_ai_format(tools, inner_thoughts_in_kwargs=google_ai_inner_thoughts_in_kwarg)
         else:
             tools = None

@@ -230,6 +234,30 @@ def create(
             ),
         )

+    elif agent_state.llm_config.model_endpoint_type == "anthropic":
+        if not use_tool_naming:
+            raise NotImplementedError("Only tool calling supported on Anthropic API requests")
+
+        if functions is not None:
+            tools = [{"type": "function", "function": f} for f in functions]
+            tools = [Tool(**t) for t in tools]
+        else:
+            tools = None
+
+        return anthropic_chat_completions_request(
+            url=agent_state.llm_config.model_endpoint,
+            api_key=credentials.anthropic_key,
+            data=ChatCompletionRequest(
+                model=agent_state.llm_config.model,
+                messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
+                tools=tools,
+                # tool_choice=function_call,
+                # user=str(agent_state.user_id),
+                # NOTE: max_tokens is required for Anthropic API
+                max_tokens=1024,  # TODO make dynamic
+            ),
+        )
+
     # local model
     else:
         return get_chat_completion(

diff --git a/memgpt/local_llm/constants.py b/memgpt/local_llm/constants.py
index d4d4f81f35..b47c0a3be9 100644
--- a/memgpt/local_llm/constants.py
+++ b/memgpt/local_llm/constants.py
@@ -13,6 +13,7 @@
     "vllm": "http://localhost:8000",
     # APIs
     "openai": "https://api.openai.com",
+    "anthropic": "https://api.anthropic.com",
     "groq": "https://api.groq.com/openai",
 }
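
Usage sketch (editor's illustration, assuming the patch above is applied): the
convert_tools_to_anthropic_format() docstring documents the OpenAI -> Anthropic
tool-schema mapping; a minimal round-trip looks like this. The schema below is a
simplified stand-in for MemGPT's real core_memory_append schema, not a copy of it.

    from memgpt.llm_api.anthropic import convert_tools_to_anthropic_format
    from memgpt.models.chat_completion_request import Tool

    # OpenAI-style tool, shaped the same way llm_api_tools.create() builds them
    openai_tool = Tool(
        type="function",
        function={
            "name": "core_memory_append",
            "description": "Append to the contents of core memory.",
            "parameters": {
                "type": "object",
                "properties": {
                    "name": {"type": "string", "description": "Section of memory to edit."},
                    "content": {"type": "string", "description": "Content to append."},
                },
                "required": ["name", "content"],
            },
        },
    )

    print(convert_tools_to_anthropic_format([openai_tool]))
    # -> [{'name': 'core_memory_append', 'description': ..., 'input_schema': {...}}]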