Skip to content

Commit

Permalink
feat: patch missing inner thoughts on new openai models (#1562)
Browse files · Browse the repository at this point in the history
  • Loading branch information
cpacker authored and sarahwooders committed Aug 17, 2024
1 parent 4983595 commit b1f9ed4
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 18 deletions.
25 changes: 21 additions & 4 deletions memgpt/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from memgpt.memory import ArchivalMemory, BaseMemory, RecallMemory, summarize_messages
from memgpt.metadata import MetadataStore
from memgpt.models import chat_completion_response
from memgpt.models.pydantic_models import ToolModel
from memgpt.models.pydantic_models import OptionState, ToolModel
from memgpt.persistence_manager import LocalStateManager
from memgpt.system import (
get_initial_boot_messages,
Expand Down Expand Up @@ -314,6 +314,7 @@ def _get_ai_reply(
function_call: str = "auto",
first_message: bool = False, # hint
stream: bool = False, # TODO move to config?
inner_thoughts_in_kwargs: OptionState = OptionState.DEFAULT,
) -> chat_completion_response.ChatCompletionResponse:
"""Get response from LLM API"""
try:
Expand All @@ -330,6 +331,8 @@ def _get_ai_reply(
# streaming
stream=stream,
stream_inferface=self.interface,
# putting inner thoughts in func args or not
inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
)
# special case for 'length'
if response.choices[0].finish_reason == "length":
Expand Down Expand Up @@ -401,7 +404,7 @@ def _handle_ai_response(
printd(f"Request to call function {function_name} with tool_call_id: {tool_call_id}")
try:
function_to_call = self.functions_python[function_name]
except KeyError as e:
except KeyError:
error_msg = f"No function named {function_name}"
function_response = package_function_response(False, error_msg)
messages.append(
Expand All @@ -424,7 +427,7 @@ def _handle_ai_response(
try:
raw_function_args = function_call.arguments
function_args = parse_json(raw_function_args)
except Exception as e:
except Exception:
error_msg = f"Error parsing JSON for function '{function_name}' arguments: {function_call.arguments}"
function_response = package_function_response(False, error_msg)
messages.append(
Expand Down Expand Up @@ -550,6 +553,7 @@ def step(
recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreated the 'created_at' field
stream: bool = False, # TODO move to config?
timestamp: Optional[datetime.datetime] = None,
inner_thoughts_in_kwargs: OptionState = OptionState.DEFAULT,
) -> Tuple[List[Union[dict, Message]], bool, bool, bool]:
"""Top-level event message handler for the MemGPT agent"""

Expand Down Expand Up @@ -634,6 +638,7 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:
message_sequence=input_message_sequence,
first_message=True, # passed through to the prompt formatter
stream=stream,
inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
)
if verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono):
break
Expand All @@ -646,6 +651,7 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:
response = self._get_ai_reply(
message_sequence=input_message_sequence,
stream=stream,
inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
)

# Step 2: check if LLM wanted to call a function
Expand Down Expand Up @@ -716,7 +722,18 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:
self.summarize_messages_inplace()

# Try step again
return self.step(user_message, first_message=first_message, return_dicts=return_dicts)
return self.step(
user_message,
first_message=first_message,
first_message_retry_limit=first_message_retry_limit,
skip_verify=skip_verify,
return_dicts=return_dicts,
recreate_message_timestamp=recreate_message_timestamp,
stream=stream,
timestamp=timestamp,
inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
)

else:
printd(f"step() failed with an unrecognized exception: '{str(e)}'")
raise e
Expand Down
13 changes: 12 additions & 1 deletion memgpt/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from memgpt.memory import ChatMemory
from memgpt.metadata import MetadataStore
from memgpt.migrate import migrate_all_agents, migrate_all_sources
from memgpt.models.pydantic_models import OptionState
from memgpt.server.constants import WS_DEFAULT_PORT
from memgpt.server.server import logger as server_logger

Expand Down Expand Up @@ -410,6 +411,10 @@ def run(
yes: Annotated[bool, typer.Option("-y", help="Skip confirmation prompt and use defaults")] = False,
# streaming
stream: Annotated[bool, typer.Option(help="Enables message streaming in the CLI (if the backend supports it)")] = False,
# whether or not to put the inner thoughts inside the function args
no_content: Annotated[
OptionState, typer.Option(help="Set to 'yes' for LLM APIs that omit the `content` field during tool calling")
] = OptionState.DEFAULT,
):
"""Start chatting with an MemGPT agent
Expand Down Expand Up @@ -671,7 +676,13 @@ def run(

print() # extra space
run_agent_loop(
memgpt_agent=memgpt_agent, config=config, first=first, ms=ms, no_verify=no_verify, stream=stream
memgpt_agent=memgpt_agent,
config=config,
first=first,
ms=ms,
no_verify=no_verify,
stream=stream,
inner_thoughts_in_kwargs=no_content,
) # TODO: add back no_verify


Expand Down
48 changes: 43 additions & 5 deletions memgpt/data_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
""" This module contains the data types used by MemGPT. Each data type must include a function to create a DB model. """

import copy
import json
import uuid
import warnings
from datetime import datetime, timezone
from typing import Dict, List, Optional, TypeVar

Expand Down Expand Up @@ -70,6 +72,27 @@ def to_dict(self):
}


def add_inner_thoughts_to_tool_call(
    tool_call: ToolCall,
    inner_thoughts: str,
    inner_thoughts_key: str,
) -> ToolCall:
    """Return a copy of `tool_call` with the inner thoughts injected as an extra function argument.

    Args:
        tool_call: Tool call whose JSON-encoded `function["arguments"]` string gets the new kwarg.
        inner_thoughts: The inner-monologue text to inject.
        inner_thoughts_key: Name of the argument the inner thoughts are stored under.

    Returns:
        A deep copy of `tool_call` with the argument added; the input object is not mutated.

    Raises:
        json.JSONDecodeError: If the existing arguments string is not valid JSON.
    """
    # because the kwargs are stored as strings, we need to load then write the JSON dicts
    try:
        # load the args dict from its JSON string form
        func_args = json.loads(tool_call.function["arguments"])
        if inner_thoughts_key in func_args:
            # surface silent clobbering of a pre-existing argument with the same name
            warnings.warn(f"Tool call arguments already contain the key '{inner_thoughts_key}'; overwriting it")
        # add the inner thoughts to the args dict
        func_args[inner_thoughts_key] = inner_thoughts
        # deep copy so the caller's original ToolCall is left untouched
        updated_tool_call = copy.deepcopy(tool_call)
        updated_tool_call.function["arguments"] = json.dumps(func_args, ensure_ascii=JSON_ENSURE_ASCII)
        return updated_tool_call
    except json.JSONDecodeError as e:
        warnings.warn(f"Failed to put inner thoughts in kwargs: {e}")
        raise  # bare raise preserves the original traceback (was `raise e`)


class Message(Record):
"""Representation of a message sent.
Expand Down Expand Up @@ -249,12 +272,16 @@ def dict_to_message(
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
)

def to_openai_dict_search_results(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict:
def to_openai_dict_search_results(self, max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN) -> dict:
    """Repackage this message as a search-result entry (timestamp plus content/role).

    Args:
        max_tool_id_length: Maximum tool-call ID length, forwarded to `to_openai_dict`.
            (Previously accepted but silently ignored; the truncation does not affect the
            `content`/`role` fields extracted here, so forwarding is behavior-neutral.)

    Returns:
        Dict with `timestamp` (message creation time) and a `message` sub-dict holding
        the OpenAI-format `content` and `role` of this message.
    """
    # forward the parameter instead of dropping it, so the signature is honest
    result_json = self.to_openai_dict(max_tool_id_length)
    search_result_json = {"timestamp": self.created_at, "message": {"content": result_json["content"], "role": result_json["role"]}}
    return search_result_json

def to_openai_dict(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict:
def to_openai_dict(
self,
max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN,
put_inner_thoughts_in_kwargs: bool = True,
) -> dict:
"""Go from Message class to ChatCompletion message object"""

# TODO change to pydantic casting, eg `return SystemMessageModel(self)`
Expand Down Expand Up @@ -282,14 +309,25 @@ def to_openai_dict(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict:
elif self.role == "assistant":
assert self.tool_calls is not None or self.text is not None
openai_message = {
"content": self.text,
"content": None if put_inner_thoughts_in_kwargs else self.text,
"role": self.role,
}
# Optional fields, do not include if null
if self.name is not None:
openai_message["name"] = self.name
if self.tool_calls is not None:
openai_message["tool_calls"] = [tool_call.to_dict() for tool_call in self.tool_calls]
if put_inner_thoughts_in_kwargs:
# put the inner thoughts inside the tool call before casting to a dict
openai_message["tool_calls"] = [
add_inner_thoughts_to_tool_call(
tool_call,
inner_thoughts=self.text,
inner_thoughts_key=INNER_THOUGHTS_KWARG,
).to_dict()
for tool_call in self.tool_calls
]
else:
openai_message["tool_calls"] = [tool_call.to_dict() for tool_call in self.tool_calls]
if max_tool_id_length:
for tool_call_dict in openai_message["tool_calls"]:
tool_call_dict["id"] = tool_call_dict["id"][:max_tool_id_length]
Expand Down Expand Up @@ -453,7 +491,7 @@ def to_google_ai_dict(self, put_inner_thoughts_in_kwargs: bool = True) -> dict:
assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self)

if self.name is None:
raise UserWarning(f"Couldn't find function name on tool call, defaulting to tool ID instead.")
warnings.warn(f"Couldn't find function name on tool call, defaulting to tool ID instead.")
function_name = self.tool_call_id
else:
function_name = self.name
Expand Down
Loading

0 comments on commit b1f9ed4

Please sign in to comment.