Skip to content

Commit

Permalink
fix: various patches for Azure support + strip Box (#982)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpacker authored Feb 9, 2024
1 parent ff57703 commit 0206312
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 18 deletions.
15 changes: 11 additions & 4 deletions memgpt/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,11 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
)
else:
credentials.azure_key = azure_creds["azure_key"]
credentials.azure_endpoint = azure_creds["azure_endpoint"]
credentials.azure_version = azure_creds["azure_version"]
config.save()
credentials.azure_embedding_version = azure_creds["azure_embedding_version"]
credentials.azure_embedding_endpoint = azure_creds["azure_embedding_endpoint"]
if "azure_embedding_deployment" in azure_creds:
credentials.azure_embedding_deployment = azure_creds["azure_embedding_deployment"]
credentials.save()

model_endpoint_type = "azure"
model_endpoint = azure_creds["azure_endpoint"]
Expand Down Expand Up @@ -417,7 +419,12 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden
raise ValueError(
"Missing environment variables for Azure (see https://memgpt.readme.io/docs/endpoints#azure-openai). Please set then run `memgpt configure` again."
)
# TODO we need to write these out to the config once we use them if we plan to ping for embedding lists with them
credentials.azure_key = azure_creds["azure_key"]
credentials.azure_version = azure_creds["azure_version"]
credentials.azure_embedding_endpoint = azure_creds["azure_embedding_endpoint"]
if "azure_deployment" in azure_creds:
credentials.azure_deployment = azure_creds["azure_deployment"]
credentials.save()

embedding_endpoint_type = "azure"
embedding_endpoint = azure_creds["azure_embedding_endpoint"]
Expand Down
14 changes: 11 additions & 3 deletions memgpt/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,13 @@ class MemGPTCredentials:
# azure config
azure_auth_type: str = "api_key"
azure_key: Optional[str] = None
azure_endpoint: Optional[str] = None
# base llm / model
azure_version: Optional[str] = None
azure_endpoint: Optional[str] = None
azure_deployment: Optional[str] = None
# embeddings
azure_embedding_version: Optional[str] = None
azure_embedding_endpoint: Optional[str] = None
azure_embedding_deployment: Optional[str] = None

# custom llm API config
Expand All @@ -63,9 +67,11 @@ def load(cls) -> "MemGPTCredentials":
# azure
"azure_auth_type": get_field(config, "azure", "auth_type"),
"azure_key": get_field(config, "azure", "key"),
"azure_endpoint": get_field(config, "azure", "endpoint"),
"azure_version": get_field(config, "azure", "version"),
"azure_endpoint": get_field(config, "azure", "endpoint"),
"azure_deployment": get_field(config, "azure", "deployment"),
"azure_embedding_version": get_field(config, "azure", "embedding_version"),
"azure_embedding_endpoint": get_field(config, "azure", "embedding_endpoint"),
"azure_embedding_deployment": get_field(config, "azure", "embedding_deployment"),
# open llm
"openllm_auth_type": get_field(config, "openllm", "auth_type"),
Expand All @@ -92,9 +98,11 @@ def save(self):
# azure config
set_field(config, "azure", "auth_type", self.azure_auth_type)
set_field(config, "azure", "key", self.azure_key)
set_field(config, "azure", "endpoint", self.azure_endpoint)
set_field(config, "azure", "version", self.azure_version)
set_field(config, "azure", "endpoint", self.azure_endpoint)
set_field(config, "azure", "deployment", self.azure_deployment)
set_field(config, "azure", "embedding_version", self.azure_embedding_version)
set_field(config, "azure", "embedding_endpoint", self.azure_embedding_endpoint)
set_field(config, "azure", "embedding_deployment", self.azure_embedding_deployment)

# openai config
Expand Down
16 changes: 11 additions & 5 deletions memgpt/llm_api_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
from typing import Callable, TypeVar, Union
import urllib

from box import Box

from memgpt.credentials import MemGPTCredentials
from memgpt.local_llm.chat_completion_proxy import get_chat_completion
from memgpt.constants import CLI_WARNING_PREFIX
from memgpt.models.chat_completion_response import ChatCompletionResponse
from memgpt.models.embedding_response import EmbeddingResponse

from memgpt.data_types import AgentState

Expand Down Expand Up @@ -74,6 +73,8 @@ def smart_urljoin(base_url, relative_url):

def clean_azure_endpoint(raw_endpoint_name):
"""Make sure the endpoint is of format 'https://YOUR_RESOURCE_NAME.openai.azure.com'"""
if raw_endpoint_name is None:
raise ValueError(raw_endpoint_name)
endpoint_address = raw_endpoint_name.strip("/").replace(".openai.azure.com", "")
endpoint_address = endpoint_address.replace("http://", "")
endpoint_address = endpoint_address.replace("https://", "")
Expand Down Expand Up @@ -231,7 +232,7 @@ def openai_embeddings_request(url, api_key, data):
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
response = response.json() # convert to dict from string
printd(f"response.json = {response}")
response = Box(response) # convert to 'dot-dict' style which is the openai python client default
response = EmbeddingResponse(**response) # convert to 'dot-dict' style which is the openai python client default
return response
except requests.exceptions.HTTPError as http_err:
# Handle HTTP errors (e.g., response 4XX, 5XX)
Expand All @@ -251,6 +252,11 @@ def azure_openai_chat_completions_request(resource_name, deployment_id, api_vers
"""https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions"""
from memgpt.utils import printd

assert resource_name is not None, "Missing required field when calling Azure OpenAI"
assert deployment_id is not None, "Missing required field when calling Azure OpenAI"
assert api_version is not None, "Missing required field when calling Azure OpenAI"
assert api_key is not None, "Missing required field when calling Azure OpenAI"

resource_name = clean_azure_endpoint(resource_name)
url = f"https://{resource_name}.openai.azure.com/openai/deployments/{deployment_id}/chat/completions?api-version={api_version}"
headers = {"Content-Type": "application/json", "api-key": f"{api_key}"}
Expand All @@ -274,7 +280,7 @@ def azure_openai_chat_completions_request(resource_name, deployment_id, api_vers
# NOTE: azure openai does not include "content" in the response when it is None, so we need to add it
if "content" not in response["choices"][0].get("message"):
response["choices"][0]["message"]["content"] = None
response = Box(response) # convert to 'dot-dict' style which is the openai python client default
response = ChatCompletionResponse(**response) # convert to 'dot-dict' style which is the openai python client default
return response
except requests.exceptions.HTTPError as http_err:
# Handle HTTP errors (e.g., response 4XX, 5XX)
Expand Down Expand Up @@ -305,7 +311,7 @@ def azure_openai_embeddings_request(resource_name, deployment_id, api_version, a
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
response = response.json() # convert to dict from string
printd(f"response.json = {response}")
response = Box(response) # convert to 'dot-dict' style which is the openai python client default
response = EmbeddingResponse(**response) # convert to 'dot-dict' style which is the openai python client default
return response
except requests.exceptions.HTTPError as http_err:
# Handle HTTP errors (e.g., response 4XX, 5XX)
Expand Down
3 changes: 0 additions & 3 deletions memgpt/local_llm/chat_completion_proxy.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
"""Key idea: create drop-in replacement for agent's ChatCompletion call that runs on an OpenLLM backend"""

import os
from datetime import datetime
import requests
import json
import uuid

from box import Box

from memgpt.local_llm.grammars.gbnf_grammar_generator import create_dynamic_model_from_function, generate_gbnf_grammar_and_documentation
from memgpt.local_llm.webui.api import get_webui_completion
from memgpt.local_llm.webui.legacy_api import get_webui_completion as get_webui_completion_legacy
Expand Down
10 changes: 10 additions & 0 deletions memgpt/models/embedding_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import List, Literal
from pydantic import BaseModel


class EmbeddingResponse(BaseModel):
    """A single embedding object, as returned inside the `data` array of the
    OpenAI embeddings API response.

    Schema reference: https://platform.openai.com/docs/api-reference/embeddings/object

    NOTE(review): this models ONE embedding entry, not the full API response
    (which is shaped `{"object": "list", "data": [...], "model": ..., "usage": ...}`).
    Callers that do `EmbeddingResponse(**response)` on the whole response body
    presumably need a wrapper/list model instead — verify against the call sites.
    """

    # Field definition order is preserved deliberately: pydantic serializes
    # fields in declaration order, so reordering would change dict/JSON output.

    # Position of this embedding within the request's input list.
    index: int
    # The embedding vector itself.
    embedding: List[float]
    # Discriminator; the API always reports "embedding" for this object type.
    object: Literal["embedding"] = "embedding"
13 changes: 10 additions & 3 deletions memgpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,12 +668,19 @@ def verify_first_message_correctness(
response_message = response.choices[0].message

# First message should be a call to send_message with a non-empty content
if require_send_message and not (response_message.function_call or response_message.tool_calls):
if ("function_call" in response_message and response_message.function_call is not None) and (
"tool_calls" in response_message and response_message.tool_calls is not None
):
printd(f"First message includes both function call AND tool call: {response_message}")
return False
elif "function_call" in response_message and response_message.function_call is not None:
function_call = response_message.function_call
elif "tool_calls" in response_message and response_message.tool_calls is not None:
function_call = response_message.tool_calls[0].function
else:
printd(f"First message didn't include function call: {response_message}")
return False

assert not (response_message.function_call and response_message.tool_calls), response_message
function_call = response_message.function_call if response_message.function_call else response_message.tool_calls[0].function
function_name = function_call.name if function_call is not None else ""
if require_send_message and function_name != "send_message" and function_name != "archival_memory_search":
printd(f"First message function call wasn't send_message or archival_memory_search: {response_message}")
Expand Down

0 comments on commit 0206312

Please sign in to comment.