fix: various patches for Azure support + strip Box #982

Merged · 5 commits · Feb 9, 2024 · Changes from all commits
15 changes: 11 additions & 4 deletions memgpt/cli/cli_config.py
@@ -117,9 +117,11 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
            )
        else:
            credentials.azure_key = azure_creds["azure_key"]
            credentials.azure_endpoint = azure_creds["azure_endpoint"]
            credentials.azure_version = azure_creds["azure_version"]
            config.save()
            credentials.azure_embedding_version = azure_creds["azure_embedding_version"]
            credentials.azure_embedding_endpoint = azure_creds["azure_embedding_endpoint"]
            if "azure_embedding_deployment" in azure_creds:
                credentials.azure_embedding_deployment = azure_creds["azure_embedding_deployment"]
            credentials.save()

        model_endpoint_type = "azure"
        model_endpoint = azure_creds["azure_endpoint"]
@@ -417,7 +419,12 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden
            raise ValueError(
                "Missing environment variables for Azure (see https://memgpt.readme.io/docs/endpoints#azure-openai). Please set them, then run `memgpt configure` again."
            )
        # TODO we need to write these out to the config once we use them if we plan to ping for embedding lists with them
        credentials.azure_key = azure_creds["azure_key"]
        credentials.azure_version = azure_creds["azure_version"]
        credentials.azure_embedding_endpoint = azure_creds["azure_embedding_endpoint"]
        if "azure_deployment" in azure_creds:
            credentials.azure_deployment = azure_creds["azure_deployment"]
        credentials.save()

        embedding_endpoint_type = "azure"
        embedding_endpoint = azure_creds["azure_embedding_endpoint"]
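Note: `azure_creds` is populated from environment variables elsewhere in the CLI (this diff only consumes it). A minimal sketch of the shape the code above assumes; the helper name and the environment-variable names here are assumptions, not part of this diff:

```python
import os


def get_azure_credentials() -> dict:
    # Hypothetical sketch: collect the Azure fields the diff reads from azure_creds.
    # Optional keys (the deployments) are dropped when unset, which is why the
    # diff guards them with `if "azure_embedding_deployment" in azure_creds`.
    creds = {
        "azure_key": os.getenv("AZURE_OPENAI_KEY"),
        "azure_endpoint": os.getenv("AZURE_OPENAI_ENDPOINT"),
        "azure_version": os.getenv("AZURE_OPENAI_VERSION"),
        "azure_deployment": os.getenv("AZURE_OPENAI_DEPLOYMENT"),
        "azure_embedding_version": os.getenv("AZURE_OPENAI_EMBEDDING_VERSION"),
        "azure_embedding_endpoint": os.getenv("AZURE_OPENAI_EMBEDDING_ENDPOINT"),
        "azure_embedding_deployment": os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT"),
    }
    return {k: v for k, v in creds.items() if v is not None}
```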
14 changes: 11 additions & 3 deletions memgpt/credentials.py
@@ -34,9 +34,13 @@ class MemGPTCredentials:
    # azure config
    azure_auth_type: str = "api_key"
    azure_key: Optional[str] = None
    azure_endpoint: Optional[str] = None
    # base llm / model
    azure_version: Optional[str] = None
    azure_endpoint: Optional[str] = None
    azure_deployment: Optional[str] = None
    # embeddings
    azure_embedding_version: Optional[str] = None
    azure_embedding_endpoint: Optional[str] = None
    azure_embedding_deployment: Optional[str] = None

    # custom llm API config
@@ -63,9 +67,11 @@ def load(cls) -> "MemGPTCredentials":
            # azure
            "azure_auth_type": get_field(config, "azure", "auth_type"),
            "azure_key": get_field(config, "azure", "key"),
            "azure_endpoint": get_field(config, "azure", "endpoint"),
            "azure_version": get_field(config, "azure", "version"),
            "azure_endpoint": get_field(config, "azure", "endpoint"),
            "azure_deployment": get_field(config, "azure", "deployment"),
            "azure_embedding_version": get_field(config, "azure", "embedding_version"),
            "azure_embedding_endpoint": get_field(config, "azure", "embedding_endpoint"),
            "azure_embedding_deployment": get_field(config, "azure", "embedding_deployment"),
            # open llm
            "openllm_auth_type": get_field(config, "openllm", "auth_type"),
@@ -92,9 +98,11 @@ def save(self):
        # azure config
        set_field(config, "azure", "auth_type", self.azure_auth_type)
        set_field(config, "azure", "key", self.azure_key)
        set_field(config, "azure", "endpoint", self.azure_endpoint)
        set_field(config, "azure", "version", self.azure_version)
        set_field(config, "azure", "endpoint", self.azure_endpoint)
        set_field(config, "azure", "deployment", self.azure_deployment)
        set_field(config, "azure", "embedding_version", self.azure_embedding_version)
        set_field(config, "azure", "embedding_endpoint", self.azure_embedding_endpoint)
        set_field(config, "azure", "embedding_deployment", self.azure_embedding_deployment)

        # openai config
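Note: `get_field`/`set_field` are defined elsewhere in the repo; this diff only calls them. A minimal sketch of their behavior, assuming the credentials file is INI-style and parsed with `configparser` (signatures inferred from the call sites above, not from this diff):

```python
import configparser
from typing import Optional


def get_field(config: configparser.ConfigParser, section: str, field: str) -> Optional[str]:
    # Return None rather than raising when a section or key is absent,
    # so newly added fields (version, deployments) load cleanly from old files.
    if section not in config:
        return None
    return config[section].get(field)


def set_field(config: configparser.ConfigParser, section: str, field: str, value: Optional[str]) -> None:
    # Skip unset fields so optional values (e.g. azure_deployment) are not
    # written out as empty strings.
    if value is None:
        return
    if section not in config:
        config[section] = {}
    config[section][field] = value
```

With helpers like these, the `save()` above would leave the `[azure]` section of the credentials file holding the new `version`, `deployment`, and `embedding_*` keys alongside `key` and `endpoint`.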
16 changes: 11 additions & 5 deletions memgpt/llm_api_tools.py
@@ -5,12 +5,11 @@
from typing import Callable, TypeVar, Union
import urllib

from box import Box

from memgpt.credentials import MemGPTCredentials
from memgpt.local_llm.chat_completion_proxy import get_chat_completion
from memgpt.constants import CLI_WARNING_PREFIX
from memgpt.models.chat_completion_response import ChatCompletionResponse
from memgpt.models.embedding_response import EmbeddingResponse

from memgpt.data_types import AgentState

@@ -74,6 +73,8 @@ def smart_urljoin(base_url, relative_url):

def clean_azure_endpoint(raw_endpoint_name):
    """Make sure the endpoint is of format 'https://YOUR_RESOURCE_NAME.openai.azure.com'"""
    if raw_endpoint_name is None:
        raise ValueError("Expected an Azure endpoint URL, got None")
    endpoint_address = raw_endpoint_name.strip("/").replace(".openai.azure.com", "")
    endpoint_address = endpoint_address.replace("http://", "")
    endpoint_address = endpoint_address.replace("https://", "")
@@ -231,7 +232,7 @@ def openai_embeddings_request(url, api_key, data):
        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
        response = response.json()  # convert to dict from string
        printd(f"response.json = {response}")
        response = Box(response)  # convert to 'dot-dict' style which is the openai python client default
        response = EmbeddingResponse(**response)  # convert to the pydantic response model (keeps dot-access, adds validation)
        return response
    except requests.exceptions.HTTPError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
@@ -251,6 +252,11 @@ def azure_openai_chat_completions_request(resource_name, deployment_id, api_vers
"""https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions"""
from memgpt.utils import printd

assert resource_name is not None, "Missing required field when calling Azure OpenAI"
assert deployment_id is not None, "Missing required field when calling Azure OpenAI"
assert api_version is not None, "Missing required field when calling Azure OpenAI"
assert api_key is not None, "Missing required field when calling Azure OpenAI"

resource_name = clean_azure_endpoint(resource_name)
url = f"https://{resource_name}.openai.azure.com/openai/deployments/{deployment_id}/chat/completions?api-version={api_version}"
headers = {"Content-Type": "application/json", "api-key": f"{api_key}"}
@@ -274,7 +280,7 @@ def azure_openai_chat_completions_request(resource_name, deployment_id, api_vers
        # NOTE: azure openai does not include "content" in the response when it is None, so we need to add it
        if "content" not in response["choices"][0].get("message"):
            response["choices"][0]["message"]["content"] = None
        response = Box(response)  # convert to 'dot-dict' style which is the openai python client default
        response = ChatCompletionResponse(**response)  # convert to the pydantic response model (keeps dot-access, adds validation)
        return response
    except requests.exceptions.HTTPError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
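A minimal illustration of why the `content` backfill matters; the response dict below is a hand-written sample, not a real API payload, and assumes downstream code reads `message.content` unconditionally:

```python
# Hand-written sample: Azure omits "content" entirely when the assistant
# only makes a function call, so the key must be backfilled before parsing.
response = {
    "id": "chatcmpl-123",
    "choices": [
        {
            "index": 0,
            "finish_reason": "function_call",
            "message": {
                "role": "assistant",
                "function_call": {"name": "send_message", "arguments": "{}"},
            },
        }
    ],
}
if "content" not in response["choices"][0].get("message"):
    response["choices"][0]["message"]["content"] = None  # key now present, value None
```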
@@ -305,7 +311,7 @@ def azure_openai_embeddings_request(resource_name, deployment_id, api_version, a
        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
        response = response.json()  # convert to dict from string
        printd(f"response.json = {response}")
        response = Box(response)  # convert to 'dot-dict' style which is the openai python client default
        response = EmbeddingResponse(**response)  # convert to the pydantic response model (keeps dot-access, adds validation)
        return response
    except requests.exceptions.HTTPError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
3 changes: 0 additions & 3 deletions memgpt/local_llm/chat_completion_proxy.py
@@ -1,13 +1,10 @@
"""Key idea: create drop-in replacement for agent's ChatCompletion call that runs on an OpenLLM backend"""

import os
from datetime import datetime
import requests
import json
import uuid

from box import Box

from memgpt.local_llm.grammars.gbnf_grammar_generator import create_dynamic_model_from_function, generate_gbnf_grammar_and_documentation
from memgpt.local_llm.webui.api import get_webui_completion
from memgpt.local_llm.webui.legacy_api import get_webui_completion as get_webui_completion_legacy
10 changes: 10 additions & 0 deletions memgpt/models/embedding_response.py
@@ -0,0 +1,10 @@
from typing import List, Literal
from pydantic import BaseModel


class EmbeddingResponse(BaseModel):
"""OpenAI embedding response model: https://platform.openai.com/docs/api-reference/embeddings/object"""

index: int # the index of the embedding in the list of embeddings
embedding: List[float]
object: Literal["embedding"] = "embedding"
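Since this model replaces Box, attribute access keeps working the same way, now with validation. A small sketch (the sample vector is made up):

```python
from memgpt.models.embedding_response import EmbeddingResponse

# Made-up sample matching the single-embedding shape the model declares.
item = {"object": "embedding", "index": 0, "embedding": [0.0023, -0.0091, 0.0147]}
parsed = EmbeddingResponse(**item)

print(parsed.index)          # 0 -- dot-access, as Box used to provide
print(parsed.embedding[:2])  # [0.0023, -0.0091]
print(parsed.object)         # "embedding"
# Unlike Box, a missing or ill-typed field now raises pydantic.ValidationError.
```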
13 changes: 10 additions & 3 deletions memgpt/utils.py
@@ -668,12 +668,19 @@ def verify_first_message_correctness(
    response_message = response.choices[0].message

    # First message should be a call to send_message with a non-empty content
    if require_send_message and not (response_message.function_call or response_message.tool_calls):
    if ("function_call" in response_message and response_message.function_call is not None) and (
        "tool_calls" in response_message and response_message.tool_calls is not None
    ):
        printd(f"First message includes both function call AND tool call: {response_message}")
        return False
    elif "function_call" in response_message and response_message.function_call is not None:
        function_call = response_message.function_call
    elif "tool_calls" in response_message and response_message.tool_calls is not None:
        function_call = response_message.tool_calls[0].function
    else:
        printd(f"First message didn't include function call: {response_message}")
        return False

    assert not (response_message.function_call and response_message.tool_calls), response_message
    function_call = response_message.function_call if response_message.function_call else response_message.tool_calls[0].function
    function_name = function_call.name if function_call is not None else ""
    if require_send_message and function_name != "send_message" and function_name != "archival_memory_search":
        printd(f"First message function call wasn't send_message or archival_memory_search: {response_message}")