Skip to content

Commit

Permalink
feat: Anthropic Claude API support (#1239)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpacker authored Apr 11, 2024
1 parent 9ffa003 commit 327acf1
Show file tree
Hide file tree
Showing 6 changed files with 581 additions and 7 deletions.
101 changes: 95 additions & 6 deletions memgpt/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from memgpt.llm_api.openai import openai_get_model_list
from memgpt.llm_api.azure_openai import azure_openai_get_model_list
from memgpt.llm_api.google_ai import google_ai_get_model_list, google_ai_get_model_context_window
from memgpt.llm_api.anthropic import anthropic_get_model_list, antropic_get_model_context_window
from memgpt.llm_api.llm_api_tools import LLM_API_PROVIDER_OPTIONS
from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME
from memgpt.local_llm.utils import get_available_wrappers
from memgpt.server.utils import shorten_key_middle
Expand Down Expand Up @@ -64,14 +66,14 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
# get default
default_model_endpoint_type = config.default_llm_config.model_endpoint_type
if config.default_llm_config.model_endpoint_type is not None and config.default_llm_config.model_endpoint_type not in [
"openai",
"azure",
"google_ai",
provider for provider in LLM_API_PROVIDER_OPTIONS if provider != "local"
]: # local model
default_model_endpoint_type = "local"

provider = questionary.select(
"Select LLM inference provider:", choices=["openai", "azure", "google_ai", "local"], default=default_model_endpoint_type
"Select LLM inference provider:",
choices=LLM_API_PROVIDER_OPTIONS,
default=default_model_endpoint_type,
).ask()
if provider is None:
raise KeyboardInterrupt
Expand Down Expand Up @@ -184,6 +186,46 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)

model_endpoint_type = "google_ai"

elif provider == "anthropic":
# check for key
if credentials.anthropic_key is None:
# allow key to get pulled from env vars
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", None)
# if we still can't find it, ask for it as input
if anthropic_api_key is None:
while anthropic_api_key is None or len(anthropic_api_key) == 0:
# Ask for API key as input
anthropic_api_key = questionary.password(
"Enter your Anthropic API key (starts with 'sk-', see https://console.anthropic.com/settings/keys):"
).ask()
if anthropic_api_key is None:
raise KeyboardInterrupt
credentials.anthropic_key = anthropic_api_key
credentials.save()
else:
# Give the user an opportunity to overwrite the key
anthropic_api_key = None
default_input = (
shorten_key_middle(credentials.anthropic_key) if credentials.anthropic_key.startswith("sk-") else credentials.anthropic_key
)
anthropic_api_key = questionary.password(
"Enter your Anthropic API key (starts with 'sk-', see https://console.anthropic.com/settings/keys):",
default=default_input,
).ask()
if anthropic_api_key is None:
raise KeyboardInterrupt
# If the user modified it, use the new one
if anthropic_api_key != default_input:
credentials.anthropic_key = anthropic_api_key
credentials.save()

model_endpoint_type = "anthropic"
model_endpoint = "https://api.anthropic.com/v1"
model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask()
if model_endpoint is None:
raise KeyboardInterrupt
provider = "anthropic"

else: # local models
# backend_options_old = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"]
backend_options = builtins.list(DEFAULT_ENDPOINTS.keys())
Expand Down Expand Up @@ -291,6 +333,12 @@ def get_model_options(
model_options = [mo for mo in model_options if str(mo).startswith("gemini") and "-pro" in str(mo)]
# model_options = ["gemini-pro"]

elif model_endpoint_type == "anthropic":
if credentials.anthropic_key is None:
raise ValueError("Missing Anthropic API key")
fetched_model_options = anthropic_get_model_list(url=model_endpoint, api_key=credentials.anthropic_key)
model_options = [obj["name"] for obj in fetched_model_options]

else:
# Attempt to do OpenAI endpoint style model fetching
# TODO support local auth with api-key header
Expand Down Expand Up @@ -382,6 +430,26 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
if model is None:
raise KeyboardInterrupt

elif model_endpoint_type == "anthropic":
try:
fetched_model_options = get_model_options(
credentials=credentials, model_endpoint_type=model_endpoint_type, model_endpoint=model_endpoint
)
except Exception as e:
# NOTE: if this fails, it means the user's key is probably bad
typer.secho(
f"Failed to get model list from {model_endpoint} - make sure your API key and endpoints are correct!", fg=typer.colors.RED
)
raise e

model = questionary.select(
"Select default model:",
choices=fetched_model_options,
default=fetched_model_options[0],
).ask()
if model is None:
raise KeyboardInterrupt

else: # local models

# ask about local auth
Expand Down Expand Up @@ -522,8 +590,8 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
fetched_context_window,
"custom",
]
except:
print(f"Failed to get model details for model '{model}' on Google AI API")
except Exception as e:
print(f"Failed to get model details for model '{model}' on Google AI API ({str(e)})")

context_window_input = questionary.select(
"Select your model's context window (see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versioning#gemini-model-versions):",
Expand All @@ -533,6 +601,27 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
if context_window_input is None:
raise KeyboardInterrupt

elif model_endpoint_type == "anthropic":
try:
fetched_context_window = str(
antropic_get_model_context_window(url=model_endpoint, api_key=credentials.anthropic_key, model=model)
)
print(f"Got context window {fetched_context_window} for model {model}")
context_length_options = [
fetched_context_window,
"custom",
]
except Exception as e:
print(f"Failed to get model details for model '{model}' ({str(e)})")

context_window_input = questionary.select(
"Select your model's context window (see https://docs.anthropic.com/claude/docs/models-overview):",
choices=context_length_options,
default=context_length_options[0],
).ask()
if context_window_input is None:
raise KeyboardInterrupt

else:

# Ask the user to specify the context length
Expand Down
8 changes: 8 additions & 0 deletions memgpt/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class MemGPTCredentials:
google_ai_key: Optional[str] = None
google_ai_service_endpoint: Optional[str] = None

# anthropic config
anthropic_key: Optional[str] = None

# azure config
azure_auth_type: str = "api_key"
azure_key: Optional[str] = None
Expand Down Expand Up @@ -77,6 +80,8 @@ def load(cls) -> "MemGPTCredentials":
# gemini
"google_ai_key": get_field(config, "google_ai", "key"),
"google_ai_service_endpoint": get_field(config, "google_ai", "service_endpoint"),
# anthropic
"anthropic_key": get_field(config, "anthropic", "key"),
# open llm
"openllm_auth_type": get_field(config, "openllm", "auth_type"),
"openllm_key": get_field(config, "openllm", "key"),
Expand Down Expand Up @@ -113,6 +118,9 @@ def save(self):
set_field(config, "google_ai", "key", self.google_ai_key)
set_field(config, "google_ai", "service_endpoint", self.google_ai_service_endpoint)

# anthropic
set_field(config, "anthropic", "key", self.anthropic_key)

# openllm config
set_field(config, "openllm", "auth_type", self.openllm_auth_type)
set_field(config, "openllm", "key", self.openllm_key)
Expand Down
72 changes: 72 additions & 0 deletions memgpt/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,78 @@ def to_openai_dict(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict:

return openai_message

def to_anthropic_dict(self, inner_thoughts_xml_tag="thinking") -> dict:
# raise NotImplementedError

def add_xml_tag(string: str, xml_tag: Optional[str]):
# NOTE: Anthropic docs recommends using <thinking> tag when using CoT + tool use
return f"<{xml_tag}>{string}</{xml_tag}" if xml_tag else string

if self.role == "system":
raise ValueError(f"Anthropic 'system' role not supported")

elif self.role == "user":
assert all([v is not None for v in [self.text, self.role]]), vars(self)
anthropic_message = {
"content": self.text,
"role": self.role,
}
# Optional field, do not include if null
if self.name is not None:
anthropic_message["name"] = self.name

elif self.role == "assistant":
assert self.tool_calls is not None or self.text is not None
anthropic_message = {
"role": self.role,
}
content = []
if self.text is not None:
content.append(
{
"type": "text",
"text": add_xml_tag(string=self.text, xml_tag=inner_thoughts_xml_tag),
}
)
if self.tool_calls is not None:
for tool_call in self.tool_calls:
content.append(
{
"type": "tool_use",
"id": tool_call.id,
"name": tool_call.function["name"],
"input": json.loads(tool_call.function["arguments"]),
}
)

# If the only content was text, unpack it back into a singleton
# TODO
anthropic_message["content"] = content

# Optional fields, do not include if null
if self.name is not None:
anthropic_message["name"] = self.name

elif self.role == "tool":
# NOTE: Anthropic uses role "user" for "tool" responses
assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self)
anthropic_message = {
"role": "user", # NOTE: diff
"content": [
# TODO support error types etc
{
"type": "tool_result",
"tool_use_id": self.tool_call_id,
"content": self.text,
}
],
}

else:
raise ValueError(self.role)

return anthropic_message

def to_google_ai_dict(self, put_inner_thoughts_in_kwargs: bool = True) -> dict:
"""Go from Message class to Google AI REST message object
Expand Down
Loading

0 comments on commit 327acf1

Please sign in to comment.