From 0206312aca1eeb3461862eaacbaa910411376b64 Mon Sep 17 00:00:00 2001
From: Charles Packer
Date: Fri, 9 Feb 2024 15:01:57 -0800
Subject: [PATCH] fix: various patches for Azure support + strip `Box` (#982)

---
 memgpt/cli/cli_config.py                  | 15 +++++++++++----
 memgpt/credentials.py                     | 14 +++++++++++---
 memgpt/llm_api_tools.py                   | 16 +++++++++++-----
 memgpt/local_llm/chat_completion_proxy.py |  3 ---
 memgpt/models/embedding_response.py       | 10 ++++++++++
 memgpt/utils.py                           | 13 ++++++++++---
 6 files changed, 53 insertions(+), 18 deletions(-)
 create mode 100644 memgpt/models/embedding_response.py

diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py
index 22b3f81ed8..6e8aaf898f 100644
--- a/memgpt/cli/cli_config.py
+++ b/memgpt/cli/cli_config.py
@@ -117,9 +117,11 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
             )
         else:
             credentials.azure_key = azure_creds["azure_key"]
-            credentials.azure_endpoint = azure_creds["azure_endpoint"]
-            credentials.azure_version = azure_creds["azure_version"]
-            config.save()
+            credentials.azure_embedding_version = azure_creds["azure_embedding_version"]
+            credentials.azure_embedding_endpoint = azure_creds["azure_embedding_endpoint"]
+            if "azure_embedding_deployment" in azure_creds:
+                credentials.azure_embedding_deployment = azure_creds["azure_embedding_deployment"]
+            credentials.save()
 
         model_endpoint_type = "azure"
         model_endpoint = azure_creds["azure_endpoint"]
@@ -417,7 +419,12 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden
             raise ValueError(
                 "Missing environment variables for Azure (see https://memgpt.readme.io/docs/endpoints#azure-openai). Please set then run `memgpt configure` again."
             )
-        # TODO we need to write these out to the config once we use them if we plan to ping for embedding lists with them
+        credentials.azure_key = azure_creds["azure_key"]
+        credentials.azure_version = azure_creds["azure_version"]
+        credentials.azure_embedding_endpoint = azure_creds["azure_embedding_endpoint"]
+        if "azure_deployment" in azure_creds:
+            credentials.azure_deployment = azure_creds["azure_deployment"]
+        credentials.save()
 
         embedding_endpoint_type = "azure"
         embedding_endpoint = azure_creds["azure_embedding_endpoint"]
diff --git a/memgpt/credentials.py b/memgpt/credentials.py
index 746b0a705d..853ea38e7c 100644
--- a/memgpt/credentials.py
+++ b/memgpt/credentials.py
@@ -34,9 +34,13 @@ class MemGPTCredentials:
     # azure config
     azure_auth_type: str = "api_key"
     azure_key: Optional[str] = None
-    azure_endpoint: Optional[str] = None
+    # base llm / model
     azure_version: Optional[str] = None
+    azure_endpoint: Optional[str] = None
     azure_deployment: Optional[str] = None
+    # embeddings
+    azure_embedding_version: Optional[str] = None
+    azure_embedding_endpoint: Optional[str] = None
     azure_embedding_deployment: Optional[str] = None
 
     # custom llm API config
@@ -63,9 +67,11 @@ def load(cls) -> "MemGPTCredentials":
             # azure
             "azure_auth_type": get_field(config, "azure", "auth_type"),
             "azure_key": get_field(config, "azure", "key"),
-            "azure_endpoint": get_field(config, "azure", "endpoint"),
             "azure_version": get_field(config, "azure", "version"),
+            "azure_endpoint": get_field(config, "azure", "endpoint"),
             "azure_deployment": get_field(config, "azure", "deployment"),
+            "azure_embedding_version": get_field(config, "azure", "embedding_version"),
+            "azure_embedding_endpoint": get_field(config, "azure", "embedding_endpoint"),
             "azure_embedding_deployment": get_field(config, "azure", "embedding_deployment"),
             # open llm
             "openllm_auth_type": get_field(config, "openllm", "auth_type"),
@@ -92,9 +98,11 @@ def save(self):
         # azure config
         set_field(config, "azure", "auth_type", self.azure_auth_type)
         set_field(config, "azure", "key", self.azure_key)
-        set_field(config, "azure", "endpoint", self.azure_endpoint)
         set_field(config, "azure", "version", self.azure_version)
+        set_field(config, "azure", "endpoint", self.azure_endpoint)
         set_field(config, "azure", "deployment", self.azure_deployment)
+        set_field(config, "azure", "embedding_version", self.azure_embedding_version)
+        set_field(config, "azure", "embedding_endpoint", self.azure_embedding_endpoint)
         set_field(config, "azure", "embedding_deployment", self.azure_embedding_deployment)
 
         # openai config
diff --git a/memgpt/llm_api_tools.py b/memgpt/llm_api_tools.py
index e507e38d81..3fa8d7ff1b 100644
--- a/memgpt/llm_api_tools.py
+++ b/memgpt/llm_api_tools.py
@@ -5,12 +5,11 @@
 from typing import Callable, TypeVar, Union
 import urllib
 
-from box import Box
-
 from memgpt.credentials import MemGPTCredentials
 from memgpt.local_llm.chat_completion_proxy import get_chat_completion
 from memgpt.constants import CLI_WARNING_PREFIX
 from memgpt.models.chat_completion_response import ChatCompletionResponse
+from memgpt.models.embedding_response import EmbeddingResponse
 from memgpt.data_types import AgentState
@@ -74,6 +73,8 @@ def smart_urljoin(base_url, relative_url):
 
 def clean_azure_endpoint(raw_endpoint_name):
     """Make sure the endpoint is of format 'https://YOUR_RESOURCE_NAME.openai.azure.com'"""
+    if raw_endpoint_name is None:
+        raise ValueError(raw_endpoint_name)
     endpoint_address = raw_endpoint_name.strip("/").replace(".openai.azure.com", "")
     endpoint_address = endpoint_address.replace("http://", "")
     endpoint_address = endpoint_address.replace("https://", "")
@@ -231,7 +232,7 @@ def openai_embeddings_request(url, api_key, data):
         response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
         response = response.json()  # convert to dict from string
         printd(f"response.json = {response}")
-        response = Box(response)  # convert to 'dot-dict' style which is the openai python client default
+        response = EmbeddingResponse(**response)  # convert to 'dot-dict' style which is the openai python client default
         return response
     except requests.exceptions.HTTPError as http_err:
         # Handle HTTP errors (e.g., response 4XX, 5XX)
@@ -251,6 +252,11 @@ def azure_openai_chat_completions_request(resource_name, deployment_id, api_vers
     """https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions"""
     from memgpt.utils import printd
 
+    assert resource_name is not None, "Missing required field when calling Azure OpenAI"
+    assert deployment_id is not None, "Missing required field when calling Azure OpenAI"
+    assert api_version is not None, "Missing required field when calling Azure OpenAI"
+    assert api_key is not None, "Missing required field when calling Azure OpenAI"
+
     resource_name = clean_azure_endpoint(resource_name)
     url = f"https://{resource_name}.openai.azure.com/openai/deployments/{deployment_id}/chat/completions?api-version={api_version}"
     headers = {"Content-Type": "application/json", "api-key": f"{api_key}"}
@@ -274,7 +280,7 @@ def azure_openai_chat_completions_request(resource_name, deployment_id, api_vers
         # NOTE: azure openai does not include "content" in the response when it is None, so we need to add it
         if "content" not in response["choices"][0].get("message"):
             response["choices"][0]["message"]["content"] = None
-        response = Box(response)  # convert to 'dot-dict' style which is the openai python client default
+        response = ChatCompletionResponse(**response)  # convert to 'dot-dict' style which is the openai python client default
         return response
     except requests.exceptions.HTTPError as http_err:
         # Handle HTTP errors (e.g., response 4XX, 5XX)
@@ -305,7 +311,7 @@ def azure_openai_embeddings_request(resource_name, deployment_id, api_version, a
         response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
         response = response.json()  # convert to dict from string
         printd(f"response.json = {response}")
-        response = Box(response)  # convert to 'dot-dict' style which is the openai python client default
+        response = EmbeddingResponse(**response)  # convert to 'dot-dict' style which is the openai python client default
         return response
     except requests.exceptions.HTTPError as http_err:
         # Handle HTTP errors (e.g., response 4XX, 5XX)
diff --git a/memgpt/local_llm/chat_completion_proxy.py b/memgpt/local_llm/chat_completion_proxy.py
index fc08fddf38..1659b6ee55 100644
--- a/memgpt/local_llm/chat_completion_proxy.py
+++ b/memgpt/local_llm/chat_completion_proxy.py
@@ -1,13 +1,10 @@
 """Key idea: create drop-in replacement for agent's ChatCompletion call that runs on an OpenLLM backend"""
 
-import os
 from datetime import datetime
 import requests
 import json
 import uuid
 
-from box import Box
-
 from memgpt.local_llm.grammars.gbnf_grammar_generator import create_dynamic_model_from_function, generate_gbnf_grammar_and_documentation
 from memgpt.local_llm.webui.api import get_webui_completion
 from memgpt.local_llm.webui.legacy_api import get_webui_completion as get_webui_completion_legacy
diff --git a/memgpt/models/embedding_response.py b/memgpt/models/embedding_response.py
new file mode 100644
index 0000000000..fb4664872d
--- /dev/null
+++ b/memgpt/models/embedding_response.py
@@ -0,0 +1,10 @@
+from typing import List, Literal
+from pydantic import BaseModel
+
+
+class EmbeddingResponse(BaseModel):
+    """OpenAI embedding response model: https://platform.openai.com/docs/api-reference/embeddings/object"""
+
+    index: int  # the index of the embedding in the list of embeddings
+    embedding: List[float]
+    object: Literal["embedding"] = "embedding"
diff --git a/memgpt/utils.py b/memgpt/utils.py
index 5931147cbd..78dd7f91f3 100644
--- a/memgpt/utils.py
+++ b/memgpt/utils.py
@@ -668,12 +668,19 @@ def verify_first_message_correctness(
     response_message = response.choices[0].message
 
     # First message should be a call to send_message with a non-empty content
-    if require_send_message and not (response_message.function_call or response_message.tool_calls):
+    if ("function_call" in response_message and response_message.function_call is not None) and (
+        "tool_calls" in response_message and response_message.tool_calls is not None
+    ):
+        printd(f"First message includes both function call AND tool call: {response_message}")
+        return False
+    elif "function_call" in response_message and response_message.function_call is not None:
+        function_call = response_message.function_call
+    elif "tool_calls" in response_message and response_message.tool_calls is not None:
+        function_call = response_message.tool_calls[0].function
+    else:
         printd(f"First message didn't include function call: {response_message}")
         return False
 
-    assert not (response_message.function_call and response_message.tool_calls), response_message
-    function_call = response_message.function_call if response_message.function_call else response_message.tool_calls[0].function
     function_name = function_call.name if function_call is not None else ""
     if require_send_message and function_name != "send_message" and function_name != "archival_memory_search":
         printd(f"First message function call wasn't send_message or archival_memory_search: {response_message}")
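The credentials.py hunks add embedding_version and embedding_endpoint to the [azure] section handled by load/save, so `memgpt configure` can persist separate LLM and embedding endpoints instead of dropping them. A round-trip sketch with placeholder values (note this writes to the local credentials file):

    from memgpt.credentials import MemGPTCredentials

    creds = MemGPTCredentials.load()
    creds.azure_embedding_version = "2023-05-15"  # placeholder API version
    creds.azure_embedding_endpoint = "https://my-embed-resource.openai.azure.com"  # placeholder resource
    creds.save()
    assert MemGPTCredentials.load().azure_embedding_version == "2023-05-15"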
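On the `Box` removal: the pydantic models keep the attribute-style ("dot-dict") access that `Box` previously provided, so downstream callers are unaffected. A minimal sketch, with an illustrative payload rather than a real API response:

    from memgpt.models.embedding_response import EmbeddingResponse

    payload = {"index": 0, "embedding": [0.1, 0.2, 0.3], "object": "embedding"}
    response = EmbeddingResponse(**payload)
    assert response.embedding[1] == 0.2  # same dot access that Box(payload) gave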
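The new None guard in clean_azure_endpoint turns an obscure AttributeError (from calling .strip() on None) into an explicit ValueError when an Azure endpoint was never configured. Expected behavior, sketched under the assumption the module imports cleanly on its own:

    from memgpt.llm_api_tools import clean_azure_endpoint

    assert clean_azure_endpoint("https://my-resource.openai.azure.com/") == "my-resource"
    try:
        clean_azure_endpoint(None)  # e.g. credential missing from the config
    except ValueError:
        pass  # previously: AttributeError from None.strip("/")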
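The reworked verify_first_message_correctness replaces an assert with explicit branching: a message carrying both `function_call` and `tool_calls` is rejected, and otherwise whichever field is present supplies the call, taking the first entry of `tool_calls`. A hedged sketch of that resolution logic as a standalone helper (`resolve_function_call` is a hypothetical name; the patch inlines this logic and uses `in` checks on the message rather than getattr):

    def resolve_function_call(message):
        """Return the single function call on `message`, or None if absent or ambiguous."""
        fc = getattr(message, "function_call", None)
        tc = getattr(message, "tool_calls", None)
        if fc is not None and tc is not None:
            return None  # both populated: ambiguous, reject as the patch does
        if fc is not None:
            return fc
        if tc is not None:
            return tc[0].function  # mirror the patch: first tool call wins
        return None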